chore: import deployable mai-bot source tree
This commit is contained in:
89
pytests/common_test/test_chat_config_utils.py
Normal file
89
pytests/common_test/test_chat_config_utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.common.utils.utils_config import ChatConfigUtils, ExpressionConfigUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def test_get_chat_prompt_for_chat_merges_multiple_matching_prompts(monkeypatch):
|
||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828")
|
||||
monkeypatch.setattr(
|
||||
global_config.chat,
|
||||
"chat_prompts",
|
||||
[
|
||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "你也是群管理员,可以适当进行管理"},
|
||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "这个群是技术实验群,请你专心讨论技术"},
|
||||
{"platform": "qq", "item_id": "other", "rule_type": "group", "prompt": "不应该生效"},
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(chat_manager, "get_session_by_session_id", lambda _session_id: None)
|
||||
|
||||
result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True)
|
||||
|
||||
assert result == "你也是群管理员,可以适当进行管理\n这个群是技术实验群,请你专心讨论技术"
|
||||
|
||||
|
||||
def test_get_chat_prompt_for_chat_matches_routed_session_by_chat_stream(monkeypatch):
|
||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
|
||||
monkeypatch.setattr(
|
||||
global_config.chat,
|
||||
"chat_prompts",
|
||||
[
|
||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "prompt": "路由会话也应该生效"},
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
|
||||
)
|
||||
|
||||
result = ChatConfigUtils.get_chat_prompt_for_chat(session_id, True)
|
||||
|
||||
assert result == "路由会话也应该生效"
|
||||
|
||||
|
||||
def test_expression_learning_list_matches_routed_session_by_chat_stream(monkeypatch):
|
||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
|
||||
monkeypatch.setattr(
|
||||
global_config.expression,
|
||||
"learning_list",
|
||||
[
|
||||
{
|
||||
"platform": "qq",
|
||||
"item_id": "1036092828",
|
||||
"rule_type": "group",
|
||||
"use_expression": False,
|
||||
"enable_learning": False,
|
||||
"enable_jargon_learning": True,
|
||||
}
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
|
||||
)
|
||||
|
||||
assert ExpressionConfigUtils.get_expression_config_for_chat(session_id) == (False, False, True)
|
||||
|
||||
|
||||
def test_talk_value_rules_match_routed_session_by_chat_stream(monkeypatch):
|
||||
session_id = SessionUtils.calculate_session_id("qq", group_id="1036092828", account_id="bot-a")
|
||||
monkeypatch.setattr(global_config.chat, "talk_value", 0.1)
|
||||
monkeypatch.setattr(global_config.chat, "enable_talk_value_rules", True)
|
||||
monkeypatch.setattr(
|
||||
global_config.chat,
|
||||
"talk_value_rules",
|
||||
[
|
||||
{"platform": "qq", "item_id": "1036092828", "rule_type": "group", "time": "00:00-23:59", "value": 0.7}
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
chat_manager,
|
||||
"get_session_by_session_id",
|
||||
lambda _session_id: SimpleNamespace(platform="qq", group_id="1036092828", user_id=None),
|
||||
)
|
||||
|
||||
assert ChatConfigUtils.get_talk_value(session_id, True) == 0.7
|
||||
908
pytests/common_test/test_database_migration_foundation.py
Normal file
908
pytests/common_test/test_database_migration_foundation.py
Normal file
@@ -0,0 +1,908 @@
|
||||
"""数据库迁移基础设施测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
|
||||
import json
|
||||
import msgpack
|
||||
import pytest
|
||||
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import (
|
||||
BaseSchemaVersionDetector,
|
||||
BaseMigrationProgressReporter,
|
||||
DatabaseSchemaSnapshot,
|
||||
DatabaseMigrationBootstrapper,
|
||||
DatabaseMigrationState,
|
||||
DatabaseMigrationManager,
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
MigrationExecutionContext,
|
||||
MigrationPlan,
|
||||
MigrationRegistry,
|
||||
MigrationStep,
|
||||
ResolvedSchemaVersion,
|
||||
SchemaVersionResolver,
|
||||
SchemaVersionSource,
|
||||
SQLiteSchemaInspector,
|
||||
SQLiteUserVersionStore,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
create_database_migration_bootstrapper,
|
||||
)
|
||||
|
||||
|
||||
class FixedVersionDetector(BaseSchemaVersionDetector):
|
||||
"""测试用固定版本探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回测试探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 探测器名称。
|
||||
"""
|
||||
return "fixed_version_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""根据测试表是否存在返回固定版本。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
|
||||
"""
|
||||
if snapshot.has_table("legacy_records"):
|
||||
return 2
|
||||
return None
|
||||
|
||||
|
||||
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""测试用迁移进度上报器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试用进度上报器。"""
|
||||
self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []
|
||||
|
||||
def open(self) -> None:
|
||||
"""记录打开事件。"""
|
||||
self.events.append(("open", None, None, None))
|
||||
|
||||
def close(self) -> None:
|
||||
"""记录关闭事件。"""
|
||||
self.events.append(("close", None, None, None))
|
||||
|
||||
def start(
|
||||
self,
|
||||
total_records: int,
|
||||
total_tables: int,
|
||||
description: str = "总迁移进度",
|
||||
table_unit_name: str = "表",
|
||||
record_unit_name: str = "记录",
|
||||
) -> None:
|
||||
"""记录启动事件。
|
||||
|
||||
Args:
|
||||
total_records: 任务记录总数。
|
||||
total_tables: 任务表总数。
|
||||
description: 任务描述。
|
||||
table_unit_name: 表级进度单位名称。
|
||||
record_unit_name: 记录级进度单位名称。
|
||||
"""
|
||||
del table_unit_name, record_unit_name
|
||||
self.events.append(("start", total_records, total_tables, description))
|
||||
|
||||
def advance(
|
||||
self,
|
||||
records: int = 0,
|
||||
completed_tables: int = 0,
|
||||
item_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""记录推进事件。
|
||||
|
||||
Args:
|
||||
records: 推进的记录数。
|
||||
completed_tables: 已完成的表数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
self.events.append(("advance", records, completed_tables, item_name))
|
||||
|
||||
|
||||
def _create_sqlite_engine(database_file: Path) -> Engine:
|
||||
"""创建测试用 SQLite 引擎。
|
||||
|
||||
Args:
|
||||
database_file: 测试数据库文件路径。
|
||||
|
||||
Returns:
|
||||
Engine: SQLite 引擎实例。
|
||||
"""
|
||||
return create_engine(
|
||||
f"sqlite:///{database_file}",
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
def _create_current_schema(connection: Connection) -> None:
|
||||
"""创建当前最新版本的数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
SQLModel.metadata.create_all(connection)
|
||||
|
||||
|
||||
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
|
||||
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE chat_streams (
|
||||
id INTEGER PRIMARY KEY,
|
||||
stream_id TEXT NOT NULL,
|
||||
create_time REAL NOT NULL,
|
||||
last_active_time REAL NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
group_id TEXT,
|
||||
group_name TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE messages (
|
||||
id INTEGER PRIMARY KEY,
|
||||
message_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
chat_info_platform TEXT,
|
||||
user_id TEXT,
|
||||
user_nickname TEXT,
|
||||
chat_info_group_id TEXT,
|
||||
chat_info_group_name TEXT,
|
||||
is_mentioned INTEGER,
|
||||
is_at INTEGER,
|
||||
processed_plain_text TEXT,
|
||||
display_message TEXT,
|
||||
is_emoji INTEGER,
|
||||
is_picid INTEGER,
|
||||
is_command INTEGER,
|
||||
is_notify INTEGER,
|
||||
additional_config TEXT,
|
||||
priority_mode TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE action_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
action_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
action_reasoning TEXT,
|
||||
action_name TEXT NOT NULL,
|
||||
action_data TEXT,
|
||||
action_prompt_display TEXT,
|
||||
chat_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE expression (
|
||||
id INTEGER PRIMARY KEY,
|
||||
situation TEXT NOT NULL,
|
||||
style TEXT NOT NULL,
|
||||
content_list TEXT,
|
||||
count INTEGER,
|
||||
last_active_time REAL NOT NULL,
|
||||
chat_id TEXT,
|
||||
create_date REAL,
|
||||
checked INTEGER,
|
||||
rejected INTEGER,
|
||||
modified_by TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE jargon (
|
||||
id INTEGER PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
raw_content TEXT,
|
||||
meaning TEXT,
|
||||
chat_id TEXT,
|
||||
is_global INTEGER,
|
||||
count INTEGER,
|
||||
is_jargon INTEGER,
|
||||
last_inference_count INTEGER,
|
||||
is_complete INTEGER,
|
||||
inference_with_context TEXT,
|
||||
inference_content_only TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO chat_streams (
|
||||
id,
|
||||
stream_id,
|
||||
create_time,
|
||||
last_active_time,
|
||||
platform,
|
||||
user_id,
|
||||
group_id,
|
||||
group_name
|
||||
) VALUES (
|
||||
1,
|
||||
'session-1',
|
||||
1710000000.0,
|
||||
1710000300.0,
|
||||
'qq',
|
||||
'user-1',
|
||||
'group-1',
|
||||
'测试群'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
message_id,
|
||||
time,
|
||||
chat_id,
|
||||
chat_info_platform,
|
||||
user_id,
|
||||
user_nickname,
|
||||
chat_info_group_id,
|
||||
chat_info_group_name,
|
||||
is_mentioned,
|
||||
is_at,
|
||||
processed_plain_text,
|
||||
display_message,
|
||||
is_emoji,
|
||||
is_picid,
|
||||
is_command,
|
||||
is_notify,
|
||||
additional_config,
|
||||
priority_mode
|
||||
) VALUES (
|
||||
1,
|
||||
'msg-1',
|
||||
1710000010.0,
|
||||
'session-1',
|
||||
'qq',
|
||||
'user-1',
|
||||
'测试用户',
|
||||
'group-1',
|
||||
'测试群',
|
||||
1,
|
||||
0,
|
||||
'你好',
|
||||
'你好呀',
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
'{"source":"legacy"}',
|
||||
'high'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO action_records (
|
||||
id,
|
||||
action_id,
|
||||
time,
|
||||
action_reasoning,
|
||||
action_name,
|
||||
action_data,
|
||||
action_prompt_display,
|
||||
chat_id
|
||||
) VALUES (
|
||||
1,
|
||||
'action-1',
|
||||
1710000020.0,
|
||||
'需要调用工具',
|
||||
'search',
|
||||
'{"query":"MaiBot"}',
|
||||
'执行搜索',
|
||||
'session-1'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO expression (
|
||||
id,
|
||||
situation,
|
||||
style,
|
||||
content_list,
|
||||
count,
|
||||
last_active_time,
|
||||
chat_id,
|
||||
create_date,
|
||||
checked,
|
||||
rejected,
|
||||
modified_by
|
||||
) VALUES (
|
||||
1,
|
||||
'打招呼',
|
||||
'可爱',
|
||||
'["你好呀","早上好"]',
|
||||
3,
|
||||
1710000030.0,
|
||||
'session-1',
|
||||
1710000040.0,
|
||||
1,
|
||||
0,
|
||||
'ai'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO jargon (
|
||||
id,
|
||||
content,
|
||||
raw_content,
|
||||
meaning,
|
||||
chat_id,
|
||||
is_global,
|
||||
count,
|
||||
is_jargon,
|
||||
last_inference_count,
|
||||
is_complete,
|
||||
inference_with_context,
|
||||
inference_content_only
|
||||
) VALUES (
|
||||
1,
|
||||
'上分',
|
||||
'["上分"]',
|
||||
'提高排名',
|
||||
'session-1',
|
||||
0,
|
||||
5,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
'{"guess":"context"}',
|
||||
'{"guess":"content"}'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
|
||||
"""应支持读取与写入 SQLite ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "version_store.db")
|
||||
version_store = SQLiteUserVersionStore()
|
||||
|
||||
with engine.begin() as connection:
|
||||
assert version_store.read_version(connection) == 0
|
||||
version_store.write_version(connection, 7)
|
||||
|
||||
with engine.connect() as connection:
|
||||
assert version_store.read_version(connection) == 7
|
||||
|
||||
|
||||
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
|
||||
"""应能提取 SQLite 数据库的表与列结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
|
||||
inspector = SQLiteSchemaInspector()
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE legacy_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = inspector.inspect(connection)
|
||||
|
||||
assert snapshot.has_table("legacy_records")
|
||||
assert snapshot.has_column("legacy_records", "payload")
|
||||
assert not snapshot.has_column("legacy_records", "missing_column")
|
||||
table_schema = snapshot.get_table("legacy_records")
|
||||
|
||||
assert table_schema is not None
|
||||
assert table_schema.column_names() == ["created_at", "id", "payload"]
|
||||
|
||||
|
||||
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
|
||||
"""空数据库应被解析为版本 ``0``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
|
||||
resolver = SchemaVersionResolver()
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 0
|
||||
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
|
||||
assert resolved_version.snapshot is not None
|
||||
assert resolved_version.snapshot.is_empty()
|
||||
|
||||
|
||||
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
|
||||
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
|
||||
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 2
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "fixed_version_detector"
|
||||
|
||||
|
||||
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
|
||||
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager.db")
|
||||
executed_steps: List[str] = []
|
||||
|
||||
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 0 -> 1。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
|
||||
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 1 -> 2。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
|
||||
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=0,
|
||||
version_to=1,
|
||||
name="create_sample_records",
|
||||
description="创建示例表。",
|
||||
handler=migrate_0_to_1,
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="add_sample_email",
|
||||
description="为示例表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
),
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(engine=engine, registry=registry)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 2
|
||||
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
version_store = SQLiteUserVersionStore()
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = version_store.read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("sample_records")
|
||||
assert snapshot.has_column("sample_records", "email")
|
||||
|
||||
|
||||
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
|
||||
"""迁移编排器应支持通过上下文上报步骤进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
|
||||
context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
|
||||
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="progress_step",
|
||||
description="测试进度上报。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=registry,
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
assert reporter_instances[0].events == [
|
||||
("open", None, None, None),
|
||||
("start", 30, 3, "总迁移进度"),
|
||||
("advance", 10, 1, "chat_sessions"),
|
||||
("advance", 10, 1, "mai_messages"),
|
||||
("advance", 10, 1, "tool_records"),
|
||||
("close", None, None, None),
|
||||
]
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "latest_schema_detector"
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "legacy_v1_schema_detector"
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
|
||||
"""空库在建表完成后应回写最新 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
|
||||
assert migration_state.target_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
|
||||
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
|
||||
execution_marks: List[str] = []
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试桥接器迁移步骤 ``1 -> 2``。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
execution_marks.append(f"step={context.step_name},index={context.step_index}")
|
||||
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
|
||||
)
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="bootstrap_add_email",
|
||||
description="为桥接器测试表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
bootstrapper = DatabaseMigrationBootstrapper(
|
||||
manager=DatabaseMigrationManager(engine=engine, registry=registry),
|
||||
latest_schema_version=2,
|
||||
)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert migration_state.resolved_version.version == 2
|
||||
assert migration_state.target_version == 2
|
||||
assert execution_marks == ["step=bootstrap_add_email,index=1"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("bootstrap_records")
|
||||
assert snapshot.has_column("bootstrap_records", "email")
|
||||
|
||||
|
||||
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
message_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, processed_plain_text, additional_config, raw_content
|
||||
FROM mai_messages
|
||||
WHERE message_id = 'msg-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
tool_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, tool_name, tool_display_prompt
|
||||
FROM tool_records
|
||||
WHERE tool_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
expression_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, content_list, modified_by
|
||||
FROM expressions
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
jargon_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id_dict, raw_content, inference_with_content_only
|
||||
FROM jargons
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
assert snapshot.has_table("__legacy_v1_messages")
|
||||
assert snapshot.has_table("chat_sessions")
|
||||
assert snapshot.has_table("mai_messages")
|
||||
assert snapshot.has_table("tool_records")
|
||||
assert not snapshot.has_table("action_records")
|
||||
assert not snapshot.has_column("mai_messages", "display_message")
|
||||
|
||||
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
|
||||
additional_config = json.loads(message_row["additional_config"])
|
||||
expression_content_list = json.loads(expression_row["content_list"])
|
||||
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
|
||||
jargon_raw_content = json.loads(jargon_row["raw_content"])
|
||||
|
||||
assert message_row["session_id"] == "session-1"
|
||||
assert message_row["processed_plain_text"] == "你好"
|
||||
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
|
||||
assert additional_config == {"priority_mode": "high", "source": "legacy"}
|
||||
assert tool_row["session_id"] == "session-1"
|
||||
assert tool_row["tool_name"] == "search"
|
||||
assert tool_row["tool_display_prompt"] == "执行搜索"
|
||||
assert expression_row["session_id"] == "session-1"
|
||||
assert expression_row["modified_by"] == "AI"
|
||||
assert expression_content_list == ["你好呀", "早上好"]
|
||||
assert jargon_session_id_dict == {"session-1": 5}
|
||||
assert jargon_raw_content == ["上分"]
|
||||
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
|
||||
|
||||
|
||||
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
|
||||
"""旧版迁移步骤应按目标表数量推进总进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=build_default_migration_registry(),
|
||||
resolver=build_default_schema_version_resolver(),
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
|
||||
|
||||
assert migration_plan.step_count() == 3
|
||||
assert len(reporter_instances) == 3
|
||||
reporter_events = reporter_instances[0].events
|
||||
|
||||
assert reporter_events[0] == ("open", None, None, None)
|
||||
assert reporter_events[1] == ("start", 6, 12, "总迁移进度")
|
||||
assert reporter_events[-1] == ("close", None, None, None)
|
||||
assert reporter_events.count(("advance", 1, 0, None)) == 6
|
||||
assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1
|
||||
assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1
|
||||
assert len([event for event in reporter_events if event[0] == "advance"]) == 18
|
||||
|
||||
|
||||
def test_initialize_database_calls_bootstrapper_before_create_all(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
|
||||
call_order: List[str] = []
|
||||
|
||||
def _fake_prepare_database() -> DatabaseMigrationState:
|
||||
"""返回测试用迁移状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 不包含迁移步骤的测试状态。
|
||||
"""
|
||||
call_order.append("prepare_database")
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
plan=MigrationPlan(
|
||||
current_version=EMPTY_SCHEMA_VERSION,
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
steps=[],
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_create_all(bind) -> None:
|
||||
"""记录建表调用。
|
||||
|
||||
Args:
|
||||
bind: 传入的数据库绑定对象。
|
||||
"""
|
||||
del bind
|
||||
call_order.append("create_all")
|
||||
|
||||
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
|
||||
"""记录迁移收尾调用。
|
||||
|
||||
Args:
|
||||
migration_state: 当前数据库迁移状态。
|
||||
"""
|
||||
del migration_state
|
||||
call_order.append("finalize_database")
|
||||
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False)
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
|
||||
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
|
||||
|
||||
database_module.initialize_database()
|
||||
|
||||
assert call_order == [
|
||||
"prepare_database",
|
||||
"create_all",
|
||||
"finalize_database",
|
||||
]
|
||||
81
pytests/common_test/test_expression_learner.py
Normal file
81
pytests/common_test/test_expression_learner.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""测试表达方式学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.bw_learner.expression_learner import ExpressionLearner
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_learner_engine")
|
||||
def expression_learner_engine_fixture() -> Generator:
|
||||
"""创建用于表达方式学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_find_similar_expression_uses_read_only_session_and_history_content(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
expression_learner_engine,
|
||||
) -> None:
|
||||
"""查找相似表达方式时,应能在离开会话后安全使用结果,并比较历史情景内容。"""
|
||||
import src.bw_learner.expression_learner as expression_learner_module
|
||||
|
||||
with Session(expression_learner_engine) as session:
|
||||
session.add(
|
||||
Expression(
|
||||
situation="发送汗滴表情",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(expression_learner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(expression_learner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
learner = ExpressionLearner(session_id="session-a")
|
||||
result = learner._find_similar_expression("表达情绪高涨或生理反应")
|
||||
|
||||
assert result is not None
|
||||
expression, similarity = result
|
||||
assert expression.item_id is not None
|
||||
assert expression.style == "发送💦表情符号"
|
||||
assert similarity == pytest.approx(1.0)
|
||||
78
pytests/common_test/test_expression_schema.py
Normal file
78
pytests/common_test/test_expression_schema.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""测试表达方式表结构和基础插入行为。"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import Expression
|
||||
|
||||
|
||||
@pytest.fixture(name="expression_engine")
|
||||
def expression_engine_fixture() -> Generator:
|
||||
"""创建仅用于表达方式表测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_expression_insert_assigns_auto_increment_id(expression_engine) -> None:
|
||||
"""表达方式表在新库中应能自动分配自增主键。"""
|
||||
with Session(expression_engine) as session:
|
||||
expression = Expression(
|
||||
situation="表达情绪高涨或生理反应",
|
||||
style="发送💦表情符号",
|
||||
content_list='["表达情绪高涨或生理反应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
session.add(expression)
|
||||
session.commit()
|
||||
session.refresh(expression)
|
||||
|
||||
assert expression.id is not None
|
||||
assert expression.id > 0
|
||||
|
||||
|
||||
def test_expression_insert_allows_same_situation_style(expression_engine) -> None:
|
||||
"""相同情景和风格的表达方式记录不应再被错误绑定到复合主键。"""
|
||||
with Session(expression_engine) as session:
|
||||
first_expression = Expression(
|
||||
situation="对重复行为的默契响应",
|
||||
style="持续性跟发相同内容",
|
||||
content_list='["对重复行为的默契响应"]',
|
||||
count=1,
|
||||
session_id="session-a",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
second_expression = Expression(
|
||||
situation="对重复行为的默契响应",
|
||||
style="持续性跟发相同内容",
|
||||
content_list='["对重复行为的默契响应-变体"]',
|
||||
count=2,
|
||||
session_id="session-b",
|
||||
checked=False,
|
||||
rejected=False,
|
||||
)
|
||||
|
||||
session.add(first_expression)
|
||||
session.add(second_expression)
|
||||
session.commit()
|
||||
session.refresh(first_expression)
|
||||
session.refresh(second_expression)
|
||||
|
||||
assert first_expression.id is not None
|
||||
assert second_expression.id is not None
|
||||
assert first_expression.id != second_expression.id
|
||||
90
pytests/common_test/test_jargon_miner.py
Normal file
90
pytests/common_test/test_jargon_miner.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""测试黑话学习器的数据库读取行为。"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine, select
|
||||
|
||||
from src.bw_learner.jargon_miner import JargonMiner
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_miner_engine")
|
||||
def jargon_miner_engine_fixture() -> Generator:
|
||||
"""创建用于黑话学习器测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_extracted_entries_updates_existing_jargon_without_detached_session(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
jargon_miner_engine,
|
||||
) -> None:
|
||||
"""更新已有黑话时,不应因会话关闭导致 ORM 实例失效。"""
|
||||
import src.bw_learner.jargon_miner as jargon_miner_module
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
session.add(
|
||||
Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=0,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@contextmanager
|
||||
def fake_get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
"""构造带自动提交语义的测试会话工厂。
|
||||
|
||||
Args:
|
||||
auto_commit: 退出上下文时是否自动提交。
|
||||
|
||||
Yields:
|
||||
Generator[Session, None, None]: SQLModel 会话对象。
|
||||
"""
|
||||
session = Session(jargon_miner_engine)
|
||||
try:
|
||||
yield session
|
||||
if auto_commit:
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
monkeypatch.setattr(jargon_miner_module, "get_db_session", fake_get_db_session)
|
||||
|
||||
jargon_miner = JargonMiner(session_id="session-a", session_name="测试群")
|
||||
await jargon_miner.process_extracted_entries(
|
||||
[{"content": "VF8V4L", "raw_content": {"[2] second"}}],
|
||||
)
|
||||
|
||||
with Session(jargon_miner_engine) as session:
|
||||
db_jargon = session.exec(select(Jargon).where(Jargon.content == "VF8V4L")).one()
|
||||
|
||||
assert db_jargon.count == 1
|
||||
assert db_jargon.session_id_dict == '{"session-a": 2}'
|
||||
assert sorted(db_jargon.raw_content and __import__("json").loads(db_jargon.raw_content)) == [
|
||||
"[1] first",
|
||||
"[2] second",
|
||||
]
|
||||
84
pytests/common_test/test_jargon_schema.py
Normal file
84
pytests/common_test/test_jargon_schema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""测试黑话表结构和基础插入行为。"""
|
||||
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from src.common.database.database_model import Jargon
|
||||
|
||||
|
||||
@pytest.fixture(name="jargon_engine")
|
||||
def jargon_engine_fixture() -> Generator:
|
||||
"""创建仅用于黑话表测试的内存数据库引擎。
|
||||
|
||||
Yields:
|
||||
Generator: 供测试使用的 SQLite 内存引擎。
|
||||
"""
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield engine
|
||||
|
||||
|
||||
def test_jargon_insert_assigns_auto_increment_id(jargon_engine) -> None:
|
||||
"""黑话表在新库中应能自动分配自增主键。"""
|
||||
with Session(jargon_engine) as session:
|
||||
jargon = Jargon(
|
||||
content="VF8V4L",
|
||||
raw_content='["[1] test"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=True,
|
||||
last_inference_count=0,
|
||||
)
|
||||
session.add(jargon)
|
||||
session.commit()
|
||||
session.refresh(jargon)
|
||||
|
||||
assert jargon.id is not None
|
||||
assert jargon.id > 0
|
||||
|
||||
|
||||
def test_jargon_insert_allows_same_content_with_different_rows(jargon_engine) -> None:
|
||||
"""黑话内容不应再被错误地绑成复合主键的一部分。"""
|
||||
with Session(jargon_engine) as session:
|
||||
first_jargon = Jargon(
|
||||
content="表情1",
|
||||
raw_content='["[1] first"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-a": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
second_jargon = Jargon(
|
||||
content="表情1",
|
||||
raw_content='["[1] second"]',
|
||||
meaning="",
|
||||
session_id_dict='{"session-b": 1}',
|
||||
count=1,
|
||||
is_jargon=True,
|
||||
is_complete=False,
|
||||
is_global=False,
|
||||
last_inference_count=0,
|
||||
)
|
||||
|
||||
session.add(first_jargon)
|
||||
session.add(second_jargon)
|
||||
session.commit()
|
||||
session.refresh(first_jargon)
|
||||
session.refresh(second_jargon)
|
||||
|
||||
assert first_jargon.id is not None
|
||||
assert second_jargon.id is not None
|
||||
assert first_jargon.id != second_jargon.id
|
||||
135
pytests/common_test/test_maisaka_expression_selector.py
Normal file
135
pytests/common_test/test_maisaka_expression_selector.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import src.chat.replyer.maisaka_expression_selector as selector_module
|
||||
from src.chat.replyer.maisaka_expression_selector import MaisakaExpressionSelector
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
|
||||
def _build_target(platform: str, item_id: str, rule_type: str = "group") -> SimpleNamespace:
|
||||
return SimpleNamespace(platform=platform, item_id=item_id, rule_type=rule_type)
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_returns_related_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("qq", "10001"),
|
||||
_build_target("qq", "10002"),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id, related_session_id}
|
||||
assert has_global_share is False
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_matches_routed_sessions(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001", account_id="bot-a")
|
||||
related_session_id = SessionUtils.calculate_session_id("qq", group_id="10002", account_id="bot-a")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("qq", "10001"),
|
||||
_build_target("qq", "10002"),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
selector_module.ChatConfigUtils,
|
||||
"_get_chat_stream",
|
||||
lambda session_id: SimpleNamespace(platform="qq", group_id="10001", user_id=None)
|
||||
if session_id == current_session_id
|
||||
else None,
|
||||
)
|
||||
target_session_ids = {
|
||||
"10001": current_session_id,
|
||||
"10002": related_session_id,
|
||||
}
|
||||
monkeypatch.setattr(
|
||||
selector_module.ChatConfigUtils,
|
||||
"get_target_session_ids",
|
||||
lambda target_item: {target_session_ids[target_item.item_id]},
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id, related_session_id}
|
||||
assert has_global_share is False
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_uses_star_as_global_share(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("*", "*"),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id}
|
||||
assert has_global_share is True
|
||||
|
||||
|
||||
def test_resolve_expression_group_scope_does_not_treat_empty_target_as_global(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
current_session_id = SessionUtils.calculate_session_id("qq", group_id="10001")
|
||||
|
||||
monkeypatch.setattr(
|
||||
selector_module,
|
||||
"global_config",
|
||||
SimpleNamespace(
|
||||
expression=SimpleNamespace(
|
||||
expression_groups=[
|
||||
SimpleNamespace(
|
||||
expression_groups=[
|
||||
_build_target("", ""),
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
selector = MaisakaExpressionSelector()
|
||||
related_session_ids, has_global_share = selector._resolve_expression_group_scope(current_session_id)
|
||||
|
||||
assert related_session_ids == {current_session_id}
|
||||
assert has_global_share is False
|
||||
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
355
pytests/common_test/test_person_info_group_cardname.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""人物信息群名片字段兼容测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.person_info_data_model import dump_group_cardname_records, parse_group_cardname_json
|
||||
|
||||
|
||||
class _DummyLogger:
|
||||
"""模拟日志记录器。"""
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""记录调试日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""记录信息日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""记录警告日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""记录错误日志。
|
||||
|
||||
Args:
|
||||
message: 日志内容。
|
||||
"""
|
||||
del message
|
||||
|
||||
|
||||
class _DummyStatement:
|
||||
"""模拟 SQL 查询语句对象。"""
|
||||
|
||||
def where(self, condition: Any) -> "_DummyStatement":
|
||||
"""附加过滤条件。
|
||||
|
||||
Args:
|
||||
condition: 过滤条件。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del condition
|
||||
return self
|
||||
|
||||
def limit(self, value: int) -> "_DummyStatement":
|
||||
"""限制返回条数。
|
||||
|
||||
Args:
|
||||
value: 条数限制。
|
||||
|
||||
Returns:
|
||||
_DummyStatement: 当前语句对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
|
||||
class _DummyColumn:
|
||||
"""模拟 SQLModel 列对象。"""
|
||||
|
||||
def is_not(self, value: Any) -> "_DummyColumn":
|
||||
"""模拟 `IS NOT` 条件构造。
|
||||
|
||||
Args:
|
||||
value: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del value
|
||||
return self
|
||||
|
||||
def __eq__(self, other: Any) -> "_DummyColumn":
|
||||
"""模拟等值条件构造。
|
||||
|
||||
Args:
|
||||
other: 比较值。
|
||||
|
||||
Returns:
|
||||
_DummyColumn: 当前列对象。
|
||||
"""
|
||||
del other
|
||||
return self
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟数据库查询结果。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化查询结果。
|
||||
|
||||
Args:
|
||||
record: 待返回的首条记录。
|
||||
"""
|
||||
self._record = record
|
||||
|
||||
def first(self) -> Any:
|
||||
"""返回第一条记录。
|
||||
|
||||
Returns:
|
||||
Any: 首条记录。
|
||||
"""
|
||||
return self._record
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回全部结果。
|
||||
|
||||
Returns:
|
||||
list[Any]: 结果列表。
|
||||
"""
|
||||
if self._record is None:
|
||||
return []
|
||||
return self._record if isinstance(self._record, list) else [self._record]
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def __init__(self, record: Any) -> None:
|
||||
"""初始化 Session。
|
||||
|
||||
Args:
|
||||
record: `first()` 应返回的记录。
|
||||
"""
|
||||
self.record = record
|
||||
self.added_records: list[Any] = []
|
||||
|
||||
def __enter__(self) -> "_DummySession":
|
||||
"""进入上下文管理器。
|
||||
|
||||
Returns:
|
||||
_DummySession: 当前 Session。
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""退出上下文管理器。
|
||||
|
||||
Args:
|
||||
exc_type: 异常类型。
|
||||
exc_val: 异常值。
|
||||
exc_tb: 异常回溯。
|
||||
"""
|
||||
del exc_type
|
||||
del exc_val
|
||||
del exc_tb
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询。
|
||||
|
||||
Args:
|
||||
statement: 查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 模拟结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult(self.record)
|
||||
|
||||
def add(self, record: Any) -> None:
|
||||
"""记录被添加的对象。
|
||||
|
||||
Args:
|
||||
record: 被写入 Session 的对象。
|
||||
"""
|
||||
self.added_records.append(record)
|
||||
|
||||
|
||||
class _DummyPersonInfoRecord:
|
||||
"""模拟 `PersonInfo` ORM 模型。"""
|
||||
|
||||
person_id = "person_id"
|
||||
person_name = "person_name"
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""使用关键字参数初始化记录对象。
|
||||
|
||||
Args:
|
||||
**kwargs: 字段值。
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def _load_person_module(monkeypatch: pytest.MonkeyPatch, session: _DummySession) -> ModuleType:
|
||||
"""加载带依赖桩的 `person_info` 模块。
|
||||
|
||||
Args:
|
||||
monkeypatch: Pytest monkeypatch 工具。
|
||||
session: 提供给模块使用的假数据库 Session。
|
||||
|
||||
Returns:
|
||||
ModuleType: 加载后的模块对象。
|
||||
"""
|
||||
logger_module = ModuleType("src.common.logger")
|
||||
logger_module.get_logger = lambda name: _DummyLogger()
|
||||
monkeypatch.setitem(sys.modules, "src.common.logger", logger_module)
|
||||
|
||||
database_module = ModuleType("src.common.database.database")
|
||||
database_module.get_db_session = lambda: session
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database", database_module)
|
||||
|
||||
database_model_module = ModuleType("src.common.database.database_model")
|
||||
database_model_module.PersonInfo = _DummyPersonInfoRecord
|
||||
monkeypatch.setitem(sys.modules, "src.common.database.database_model", database_model_module)
|
||||
|
||||
llm_module = ModuleType("src.llm_models.utils_model")
|
||||
|
||||
class _DummyLLMRequest:
|
||||
"""模拟 LLMRequest。"""
|
||||
|
||||
def __init__(self, model_set: Any, request_type: str) -> None:
|
||||
"""初始化假请求对象。
|
||||
|
||||
Args:
|
||||
model_set: 模型配置。
|
||||
request_type: 请求类型。
|
||||
"""
|
||||
del model_set
|
||||
del request_type
|
||||
|
||||
llm_module.LLMRequest = _DummyLLMRequest
|
||||
monkeypatch.setitem(sys.modules, "src.llm_models.utils_model", llm_module)
|
||||
|
||||
config_module = ModuleType("src.config.config")
|
||||
config_module.global_config = SimpleNamespace(bot=SimpleNamespace(nickname="MaiBot"))
|
||||
config_module.model_config = SimpleNamespace(model_task_config=SimpleNamespace(tool_use="tool_use", utils="utils"))
|
||||
monkeypatch.setitem(sys.modules, "src.config.config", config_module)
|
||||
|
||||
chat_manager_module = ModuleType("src.chat.message_receive.chat_manager")
|
||||
chat_manager_module.chat_manager = SimpleNamespace()
|
||||
monkeypatch.setitem(sys.modules, "src.chat.message_receive.chat_manager", chat_manager_module)
|
||||
|
||||
module_path = Path(__file__).resolve().parents[2] / "src" / "person_info" / "person_info.py"
|
||||
spec = spec_from_file_location("person_info_group_cardname_test_module", module_path)
|
||||
assert spec is not None and spec.loader is not None
|
||||
|
||||
module = module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, spec.name, module)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
monkeypatch.setattr(module, "select", lambda *args: _DummyStatement())
|
||||
monkeypatch.setattr(module, "col", lambda field: _DummyColumn())
|
||||
return module
|
||||
|
||||
|
||||
def test_parse_group_cardname_json_uses_canonical_key() -> None:
|
||||
"""群名片 JSON 解析应只使用 `group_cardname` 键名。"""
|
||||
parsed = parse_group_cardname_json(
|
||||
json.dumps(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "现行字段"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert parsed is not None
|
||||
assert [(item.group_id, item.group_cardname) for item in parsed] == [
|
||||
("1001", "现行字段"),
|
||||
]
|
||||
|
||||
|
||||
def test_dump_group_cardname_records_uses_canonical_key() -> None:
|
||||
"""群名片序列化应输出 `group_cardname` 键名。"""
|
||||
dumped = dump_group_cardname_records(
|
||||
[
|
||||
{"group_id": "1001", "group_cardname": "群昵称"},
|
||||
]
|
||||
)
|
||||
|
||||
assert json.loads(dumped) == [{"group_id": "1001", "group_cardname": "群昵称"}]
|
||||
|
||||
|
||||
def test_person_sync_to_database_uses_group_cardname_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""同步人物信息时应写入数据库模型的 `group_cardname` 字段。"""
|
||||
record = _DummyPersonInfoRecord()
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.is_known = True
|
||||
person.person_id = "person-1"
|
||||
person.platform = "qq"
|
||||
person.user_id = "10001"
|
||||
person.nickname = "看番的龙"
|
||||
person.person_name = "看番的龙"
|
||||
person.name_reason = "测试"
|
||||
person.know_times = 1
|
||||
person.know_since = 1700000000.0
|
||||
person.last_know = 1700000100.0
|
||||
person.memory_points = ["喜好:番剧:0.8"]
|
||||
person.group_cardname_list = [{"group_id": "20001", "group_cardname": "白泽大人"}]
|
||||
|
||||
person.sync_to_database()
|
||||
|
||||
assert record.group_cardname == '[{"group_id": "20001", "group_cardname": "白泽大人"}]'
|
||||
assert not hasattr(record, "group_nickname")
|
||||
|
||||
|
||||
def test_person_load_from_database_normalizes_group_cardname_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""从数据库加载人物信息时应读取标准 `group_cardname` 结构。"""
|
||||
record = _DummyPersonInfoRecord(
|
||||
user_id="10001",
|
||||
platform="qq",
|
||||
is_known=True,
|
||||
user_nickname="看番的龙",
|
||||
person_name="看番的龙",
|
||||
name_reason=None,
|
||||
know_counts=2,
|
||||
memory_points='["喜好:番剧:0.8"]',
|
||||
group_cardname=json.dumps(
|
||||
[
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
session = _DummySession(record)
|
||||
module = _load_person_module(monkeypatch, session)
|
||||
|
||||
person = module.Person.__new__(module.Person)
|
||||
person.person_id = "person-1"
|
||||
person.memory_points = []
|
||||
person.group_cardname_list = []
|
||||
|
||||
person.load_from_database()
|
||||
|
||||
assert person.group_cardname_list == [
|
||||
{"group_id": "20001", "group_cardname": "白泽大人"},
|
||||
]
|
||||
Reference in New Issue
Block a user