diff --git a/pytests/common_test/test_database_migration_foundation.py b/pytests/common_test/test_database_migration_foundation.py index 9c930744..ffec2b6a 100644 --- a/pytests/common_test/test_database_migration_foundation.py +++ b/pytests/common_test/test_database_migration_foundation.py @@ -68,7 +68,7 @@ class FakeMigrationProgressReporter(BaseMigrationProgressReporter): def __init__(self) -> None: """初始化测试用进度上报器。""" - self.events: List[Tuple[str, Optional[int], Optional[str], Optional[str]]] = [] + self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = [] def open(self) -> None: """记录打开事件。""" @@ -80,27 +80,38 @@ class FakeMigrationProgressReporter(BaseMigrationProgressReporter): def start( self, - total: int, + total_records: int, + total_tables: int, description: str = "总迁移进度", - unit_name: str = "表", + table_unit_name: str = "表", + record_unit_name: str = "记录", ) -> None: """记录启动事件。 Args: - total: 任务总数。 + total_records: 任务记录总数。 + total_tables: 任务表总数。 description: 任务描述。 - unit_name: 进度单位名称。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 """ - self.events.append(("start", total, description, unit_name)) + del table_unit_name, record_unit_name + self.events.append(("start", total_records, total_tables, description)) - def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: """记录推进事件。 Args: - advance: 推进步数。 + records: 推进的记录数。 + completed_tables: 已完成的表数。 item_name: 当前完成的项目名称。 """ - self.events.append(("advance", advance, item_name, None)) + self.events.append(("advance", records, completed_tables, item_name)) def _create_sqlite_engine(database_file: Path) -> Engine: @@ -550,10 +561,10 @@ def test_manager_can_report_step_progress(tmp_path: Path) -> None: Args: context: 当前迁移步骤执行上下文。 """ - context.start_progress(total=3, description="总迁移进度", unit_name="表") - context.advance_progress(item_name="chat_sessions") - context.advance_progress(item_name="mai_messages") - context.advance_progress(item_name="tool_records") + context.start_progress(total_tables=3, total_records=30, description="总迁移进度") + context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions") + context.advance_progress(records=10, completed_tables=1, item_name="mai_messages") + context.advance_progress(records=10, completed_tables=1, item_name="tool_records") context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")) with engine.begin() as connection: @@ -582,10 +593,10 @@ def test_manager_can_report_step_progress(tmp_path: Path) -> None: assert len(reporter_instances) == 1 assert reporter_instances[0].events == [ ("open", None, None, None), - ("start", 3, "总迁移进度", "表"), - ("advance", 1, "chat_sessions", None), - ("advance", 1, "mai_messages", None), - ("advance", 1, "tool_records", None), + ("start", 30, 3, "总迁移进度"), + ("advance", 10, 1, "chat_sessions"), + ("advance", 10, 1, "mai_messages"), + ("advance", 10, 1, "tool_records"), ("close", None, None, None), ] @@ -842,11 +853,12 @@ def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None: reporter_events = reporter_instances[0].events assert reporter_events[0] == ("open", None, None, None) - assert reporter_events[1] == ("start", 12, "总迁移进度", "表") + assert reporter_events[1] == ("start", 6, 12, "总迁移进度") assert reporter_events[-1] == ("close", None, None, None) - assert reporter_events.count(("advance", 1, "chat_sessions", None)) == 1 - assert reporter_events.count(("advance", 1, "thinking_questions", None)) == 1 - assert len([event for event in reporter_events if event[0] == "advance"]) == 12 + assert reporter_events.count(("advance", 1, 0, None)) == 6 + assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1 + assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1 + assert len([event for event in reporter_events if event[0] == "advance"]) == 18 def test_initialize_database_calls_bootstrapper_before_create_all( diff --git a/src/common/database/migrations/legacy_v1_to_v2.py b/src/common/database/migrations/legacy_v1_to_v2.py index 284da330..c1f88dd0 100644 --- a/src/common/database/migrations/legacy_v1_to_v2.py +++ b/src/common/database/migrations/legacy_v1_to_v2.py @@ -66,7 +66,7 @@ def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None: _rename_legacy_v1_tables(context.connection, snapshot) SQLModel.metadata.create_all(context.connection) - table_migration_jobs: List[Tuple[str, Callable[[Connection], int]]] = [ + table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [ ("chat_sessions", _migrate_chat_sessions), ("llm_usage", _migrate_model_usage), ("images", _migrate_images), @@ -81,10 +81,16 @@ def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None: ("thinking_questions", _migrate_thinking_questions), ] migrated_counts: Dict[str, int] = {} - context.start_progress(total=len(table_migration_jobs), description="总迁移进度", unit_name="表") + total_record_count = _estimate_total_record_count(context.connection) + context.start_progress( + total_tables=len(table_migration_jobs), + total_records=total_record_count, + description="总迁移进度", + table_unit_name="表", + record_unit_name="记录", + ) for table_name, migration_handler in table_migration_jobs: - migrated_counts[table_name] = migration_handler(context.connection) - context.advance_progress(item_name=table_name) + migrated_counts[table_name] = migration_handler(context) summary_text = ", ".join(f"{table_name}={count}" for table_name, count in migrated_counts.items()) logger.info(f"旧版数据库迁移完成: {summary_text}") @@ -535,17 +541,77 @@ def _deduce_image_type_name(value: Any) -> str: return "IMAGE" -def _migrate_chat_sessions(connection: Connection) -> int: - """迁移旧版 ``chat_streams`` 到新版 ``chat_sessions``。 +def _count_legacy_table_rows(connection: Connection, original_table_name: str) -> int: + """统计单张旧版备份表中的记录总数。 + + Args: + connection: 当前数据库连接。 + original_table_name: 旧版原始表名。 + + Returns: + int: 备份表中的记录数;若表不存在则返回 ``0``。 + """ + backup_table_name = _legacy_backup_table_name(original_table_name) + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, backup_table_name): + return 0 + row = connection.execute( + text(f"SELECT COUNT(*) FROM {_quote_identifier(backup_table_name)}") + ).first() + if row is None: + return 0 + return _normalize_int(row[0], default=0) + + +def _estimate_total_record_count(connection: Connection) -> int: + """估算旧版迁移步骤需要处理的总记录数。 Args: connection: 当前数据库连接。 + Returns: + int: 本次迁移预计处理的总记录数。 + """ + return ( + _count_legacy_table_rows(connection, "chat_streams") + + _count_legacy_table_rows(connection, "llm_usage") + + _count_legacy_table_rows(connection, "emoji") + + _count_legacy_table_rows(connection, "images") + + _count_legacy_table_rows(connection, "messages") + + _count_legacy_table_rows(connection, "action_records") + + _count_legacy_table_rows(connection, "action_records") + + _count_legacy_table_rows(connection, "online_time") + + _count_legacy_table_rows(connection, "person_info") + + _count_legacy_table_rows(connection, "expression") + + _count_legacy_table_rows(connection, "jargon") + + _count_legacy_table_rows(connection, "chat_history") + + _count_legacy_table_rows(connection, "thinking_back") + ) + + +def _complete_table_progress(context: MigrationExecutionContext, table_name: str) -> None: + """标记单张表的迁移已经完成。 + + Args: + context: 当前迁移步骤执行上下文。 + table_name: 已完成迁移的表名。 + """ + context.advance_progress(completed_tables=1, item_name=table_name) + + +def _migrate_chat_sessions(context: MigrationExecutionContext) -> int: + """迁移旧版 ``chat_streams`` 到新版 ``chat_sessions``。 + + Args: + context: 当前迁移步骤执行上下文。 + Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "chat_streams") if legacy_table is None: + _complete_table_progress(context, "chat_sessions") return 0 migrated_count = 0 @@ -570,34 +636,37 @@ def _migrate_chat_sessions(connection: Connection) -> int: ) for row in legacy_table.rows: session_id = _normalize_required_text(row.get("stream_id")) - if not session_id: - continue - connection.execute( - insert_sql, - { - "session_id": session_id, - "created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True), - "last_active_timestamp": _coerce_datetime(row.get("last_active_time"), fallback_now=True), - "user_id": _normalize_optional_text(row.get("user_id")), - "group_id": _normalize_optional_text(row.get("group_id")), - "platform": _normalize_required_text(row.get("platform"), default="unknown"), - }, - ) - migrated_count += 1 + if session_id: + connection.execute( + insert_sql, + { + "session_id": session_id, + "created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True), + "last_active_timestamp": _coerce_datetime(row.get("last_active_time"), fallback_now=True), + "user_id": _normalize_optional_text(row.get("user_id")), + "group_id": _normalize_optional_text(row.get("group_id")), + "platform": _normalize_required_text(row.get("platform"), default="unknown"), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "chat_sessions") return migrated_count -def _migrate_model_usage(connection: Connection) -> int: +def _migrate_model_usage(context: MigrationExecutionContext) -> int: """迁移旧版 ``llm_usage`` 到新版 ``llm_usage``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "llm_usage") if legacy_table is None: + _complete_table_progress(context, "llm_usage") return 0 migrated_count = 0 @@ -654,18 +723,21 @@ def _migrate_model_usage(connection: Connection) -> int: }, ) migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "llm_usage") return migrated_count -def _migrate_images(connection: Connection) -> int: +def _migrate_images(context: MigrationExecutionContext) -> int: """迁移旧版 ``emoji`` 与 ``images`` 到新版 ``images``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection migrated_count = 0 existing_keys: Set[Tuple[str, str, str]] = set() existing_rows = connection.execute( @@ -719,28 +791,28 @@ def _migrate_images(connection: Connection) -> int: full_path = _normalize_required_text(row.get("full_path")) image_hash = _normalize_required_text(row.get("emoji_hash")) dedupe_key = (full_path, image_hash, "EMOJI") - if not full_path or dedupe_key in existing_keys: - continue - connection.execute( - insert_sql, - { - "image_hash": image_hash, - "description": _normalize_required_text(row.get("description")), - "full_path": full_path, - "image_type": "EMOJI", - "emotion": _normalize_optional_text(row.get("emotion")), - "query_count": _normalize_int(row.get("query_count"), default=0), - "is_registered": _normalize_bool(row.get("is_registered"), default=False), - "is_banned": _normalize_bool(row.get("is_banned"), default=False), - "no_file_flag": False, - "record_time": _coerce_datetime(row.get("record_time"), fallback_now=True), - "register_time": _coerce_datetime(row.get("register_time")), - "last_used_time": _coerce_datetime(row.get("last_used_time")), - "vlm_processed": False, - }, - ) - existing_keys.add(dedupe_key) - migrated_count += 1 + if full_path and dedupe_key not in existing_keys: + connection.execute( + insert_sql, + { + "image_hash": image_hash, + "description": _normalize_required_text(row.get("description")), + "full_path": full_path, + "image_type": "EMOJI", + "emotion": _normalize_optional_text(row.get("emotion")), + "query_count": _normalize_int(row.get("query_count"), default=0), + "is_registered": _normalize_bool(row.get("is_registered"), default=False), + "is_banned": _normalize_bool(row.get("is_banned"), default=False), + "no_file_flag": False, + "record_time": _coerce_datetime(row.get("record_time"), fallback_now=True), + "register_time": _coerce_datetime(row.get("register_time")), + "last_used_time": _coerce_datetime(row.get("last_used_time")), + "vlm_processed": False, + }, + ) + existing_keys.add(dedupe_key) + migrated_count += 1 + context.advance_progress(records=1) legacy_images_table = _load_legacy_table_data(connection, "images") if legacy_images_table is not None: @@ -749,43 +821,46 @@ def _migrate_images(connection: Connection) -> int: image_hash = _normalize_required_text(row.get("emoji_hash")) image_type = _deduce_image_type_name(row.get("type")) dedupe_key = (full_path, image_hash, image_type) - if not full_path or dedupe_key in existing_keys: - continue - connection.execute( - insert_sql, - { - "image_hash": image_hash, - "description": _normalize_required_text(row.get("description")), - "full_path": full_path, - "image_type": image_type, - "emotion": None, - "query_count": _normalize_int(row.get("count"), default=0), - "is_registered": False, - "is_banned": False, - "no_file_flag": False, - "record_time": _coerce_datetime(row.get("timestamp"), fallback_now=True), - "register_time": None, - "last_used_time": None, - "vlm_processed": _normalize_bool(row.get("vlm_processed"), default=False), - }, - ) - existing_keys.add(dedupe_key) - migrated_count += 1 + if full_path and dedupe_key not in existing_keys: + connection.execute( + insert_sql, + { + "image_hash": image_hash, + "description": _normalize_required_text(row.get("description")), + "full_path": full_path, + "image_type": image_type, + "emotion": None, + "query_count": _normalize_int(row.get("count"), default=0), + "is_registered": False, + "is_banned": False, + "no_file_flag": False, + "record_time": _coerce_datetime(row.get("timestamp"), fallback_now=True), + "register_time": None, + "last_used_time": None, + "vlm_processed": _normalize_bool(row.get("vlm_processed"), default=False), + }, + ) + existing_keys.add(dedupe_key) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "images") return migrated_count -def _migrate_messages(connection: Connection) -> int: +def _migrate_messages(context: MigrationExecutionContext) -> int: """迁移旧版 ``messages`` 到新版 ``mai_messages``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "messages") if legacy_table is None: + _complete_table_progress(context, "mai_messages") return 0 migrated_count = 0 @@ -840,62 +915,65 @@ def _migrate_messages(connection: Connection) -> int: ) for row in legacy_table.rows: session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id")) - if not session_id: - continue - processed_plain_text = _normalize_optional_text(row.get("processed_plain_text")) - display_message = _normalize_optional_text(row.get("display_message")) - connection.execute( - insert_sql, - { - "id": row.get("id"), - "message_id": _normalize_required_text(row.get("message_id"), default=""), - "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), - "platform": _normalize_required_text( - row.get("chat_info_platform") or row.get("user_platform"), - default="unknown", - ), - "user_id": _normalize_required_text( - row.get("user_id") or row.get("chat_info_user_id"), - default="", - ), - "user_nickname": _normalize_required_text( - row.get("user_nickname") or row.get("chat_info_user_nickname"), - default="", - ), - "user_cardname": _normalize_optional_text( - row.get("user_cardname") or row.get("chat_info_user_cardname") - ), - "group_id": _normalize_optional_text(row.get("chat_info_group_id")), - "group_name": _normalize_optional_text(row.get("chat_info_group_name")), - "is_mentioned": _normalize_bool(row.get("is_mentioned"), default=False), - "is_at": _normalize_bool(row.get("is_at"), default=False), - "session_id": session_id, - "reply_to": _normalize_optional_text(row.get("reply_to")), - "is_emoji": _normalize_bool(row.get("is_emoji"), default=False), - "is_picture": _normalize_bool(row.get("is_picid"), default=False), - "is_command": _normalize_bool(row.get("is_command"), default=False), - "is_notify": _normalize_bool(row.get("is_notify"), default=False), - "raw_content": _build_message_raw_content(processed_plain_text, display_message), - "processed_plain_text": processed_plain_text, - "display_message": display_message, - "additional_config": _build_legacy_message_additional_config(row), - }, - ) - migrated_count += 1 + if session_id: + processed_plain_text = _normalize_optional_text(row.get("processed_plain_text")) + display_message = _normalize_optional_text(row.get("display_message")) + connection.execute( + insert_sql, + { + "id": row.get("id"), + "message_id": _normalize_required_text(row.get("message_id"), default=""), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "platform": _normalize_required_text( + row.get("chat_info_platform") or row.get("user_platform"), + default="unknown", + ), + "user_id": _normalize_required_text( + row.get("user_id") or row.get("chat_info_user_id"), + default="", + ), + "user_nickname": _normalize_required_text( + row.get("user_nickname") or row.get("chat_info_user_nickname"), + default="", + ), + "user_cardname": _normalize_optional_text( + row.get("user_cardname") or row.get("chat_info_user_cardname") + ), + "group_id": _normalize_optional_text(row.get("chat_info_group_id")), + "group_name": _normalize_optional_text(row.get("chat_info_group_name")), + "is_mentioned": _normalize_bool(row.get("is_mentioned"), default=False), + "is_at": _normalize_bool(row.get("is_at"), default=False), + "session_id": session_id, + "reply_to": _normalize_optional_text(row.get("reply_to")), + "is_emoji": _normalize_bool(row.get("is_emoji"), default=False), + "is_picture": _normalize_bool(row.get("is_picid"), default=False), + "is_command": _normalize_bool(row.get("is_command"), default=False), + "is_notify": _normalize_bool(row.get("is_notify"), default=False), + "raw_content": _build_message_raw_content(processed_plain_text, display_message), + "processed_plain_text": processed_plain_text, + "display_message": display_message, + "additional_config": _build_legacy_message_additional_config(row), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "mai_messages") return migrated_count -def _migrate_action_records(connection: Connection) -> int: +def _migrate_action_records(context: MigrationExecutionContext) -> int: """迁移旧版 ``action_records`` 到新版 ``action_records``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "action_records") if legacy_table is None: + _complete_table_progress(context, "action_records") return 0 migrated_count = 0 @@ -926,37 +1004,40 @@ def _migrate_action_records(connection: Connection) -> int: ) for row in legacy_table.rows: session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id")) - if not session_id: - continue - connection.execute( - insert_sql, - { - "id": row.get("id"), - "action_id": _normalize_required_text(row.get("action_id")), - "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), - "session_id": session_id, - "action_name": _normalize_required_text(row.get("action_name"), default="unknown"), - "action_reasoning": _normalize_optional_text(row.get("action_reasoning")), - "action_data": _normalize_optional_text(row.get("action_data")), - "action_builtin_prompt": None, - "action_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), - }, - ) - migrated_count += 1 + if session_id: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "action_id": _normalize_required_text(row.get("action_id")), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "session_id": session_id, + "action_name": _normalize_required_text(row.get("action_name"), default="unknown"), + "action_reasoning": _normalize_optional_text(row.get("action_reasoning")), + "action_data": _normalize_optional_text(row.get("action_data")), + "action_builtin_prompt": None, + "action_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "action_records") return migrated_count -def _migrate_tool_records(connection: Connection) -> int: +def _migrate_tool_records(context: MigrationExecutionContext) -> int: """迁移旧版 ``action_records`` 到新版 ``tool_records``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "action_records") if legacy_table is None: + _complete_table_progress(context, "tool_records") return 0 migrated_count = 0 @@ -987,37 +1068,40 @@ def _migrate_tool_records(connection: Connection) -> int: ) for row in legacy_table.rows: session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id")) - if not session_id: - continue - connection.execute( - insert_sql, - { - "id": row.get("id"), - "tool_id": _normalize_required_text(row.get("action_id")), - "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), - "session_id": session_id, - "tool_name": _normalize_required_text(row.get("action_name"), default="unknown"), - "tool_reasoning": _normalize_optional_text(row.get("action_reasoning")), - "tool_data": _normalize_optional_text(row.get("action_data")), - "tool_builtin_prompt": None, - "tool_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), - }, - ) - migrated_count += 1 + if session_id: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "tool_id": _normalize_required_text(row.get("action_id")), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "session_id": session_id, + "tool_name": _normalize_required_text(row.get("action_name"), default="unknown"), + "tool_reasoning": _normalize_optional_text(row.get("action_reasoning")), + "tool_data": _normalize_optional_text(row.get("action_data")), + "tool_builtin_prompt": None, + "tool_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "tool_records") return migrated_count -def _migrate_online_time(connection: Connection) -> int: +def _migrate_online_time(context: MigrationExecutionContext) -> int: """迁移旧版 ``online_time`` 到新版 ``online_time``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "online_time") if legacy_table is None: + _complete_table_progress(context, "online_time") return 0 migrated_count = 0 @@ -1050,20 +1134,24 @@ def _migrate_online_time(connection: Connection) -> int: }, ) migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "online_time") return migrated_count -def _migrate_person_info(connection: Connection) -> int: +def _migrate_person_info(context: MigrationExecutionContext) -> int: """迁移旧版 ``person_info`` 到新版 ``person_info``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "person_info") if legacy_table is None: + _complete_table_progress(context, "person_info") return 0 migrated_count = 0 @@ -1123,20 +1211,24 @@ def _migrate_person_info(connection: Connection) -> int: }, ) migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "person_info") return migrated_count -def _migrate_expressions(connection: Connection) -> int: +def _migrate_expressions(context: MigrationExecutionContext) -> int: """迁移旧版 ``expression`` 到新版 ``expressions``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "expression") if legacy_table is None: + _complete_table_progress(context, "expressions") return 0 migrated_count = 0 @@ -1187,20 +1279,24 @@ def _migrate_expressions(connection: Connection) -> int: }, ) migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "expressions") return migrated_count -def _migrate_jargons(connection: Connection) -> int: +def _migrate_jargons(context: MigrationExecutionContext) -> int: """迁移旧版 ``jargon`` 到新版 ``jargons``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "jargon") if legacy_table is None: + _complete_table_progress(context, "jargons") return 0 migrated_count = 0 @@ -1259,20 +1355,24 @@ def _migrate_jargons(connection: Connection) -> int: }, ) migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "jargons") return migrated_count -def _migrate_chat_history(connection: Connection) -> int: +def _migrate_chat_history(context: MigrationExecutionContext) -> int: """迁移旧版 ``chat_history`` 到新版 ``chat_history``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "chat_history") if legacy_table is None: + _complete_table_progress(context, "chat_history") return 0 migrated_count = 0 @@ -1307,39 +1407,42 @@ def _migrate_chat_history(connection: Connection) -> int: ) for row in legacy_table.rows: session_id = _normalize_required_text(row.get("chat_id")) - if not session_id: - continue - connection.execute( - insert_sql, - { - "id": row.get("id"), - "session_id": session_id, - "start_timestamp": _coerce_datetime(row.get("start_time"), fallback_now=True), - "end_timestamp": _coerce_datetime(row.get("end_time"), fallback_now=True), - "query_count": _normalize_int(row.get("count"), default=0), - "query_forget_count": _normalize_int(row.get("forget_times"), default=0), - "original_messages": _normalize_required_text(row.get("original_text")), - "participants": _normalize_required_text(row.get("participants"), default="[]"), - "theme": _normalize_required_text(row.get("theme"), default=""), - "keywords": _normalize_required_text(row.get("keywords"), default="[]"), - "summary": _normalize_required_text(row.get("summary"), default=""), - }, - ) - migrated_count += 1 + if session_id: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "session_id": session_id, + "start_timestamp": _coerce_datetime(row.get("start_time"), fallback_now=True), + "end_timestamp": _coerce_datetime(row.get("end_time"), fallback_now=True), + "query_count": _normalize_int(row.get("count"), default=0), + "query_forget_count": _normalize_int(row.get("forget_times"), default=0), + "original_messages": _normalize_required_text(row.get("original_text")), + "participants": _normalize_required_text(row.get("participants"), default="[]"), + "theme": _normalize_required_text(row.get("theme"), default=""), + "keywords": _normalize_required_text(row.get("keywords"), default="[]"), + "summary": _normalize_required_text(row.get("summary"), default=""), + }, + ) + migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "chat_history") return migrated_count -def _migrate_thinking_questions(connection: Connection) -> int: +def _migrate_thinking_questions(context: MigrationExecutionContext) -> int: """迁移旧版 ``thinking_back`` 到新版 ``thinking_questions``。 Args: - connection: 当前数据库连接。 + context: 当前迁移步骤执行上下文。 Returns: int: 迁移成功的记录数。 """ + connection = context.connection legacy_table = _load_legacy_table_data(connection, "thinking_back") if legacy_table is None: + _complete_table_progress(context, "thinking_questions") return 0 migrated_count = 0 @@ -1381,4 +1484,6 @@ def _migrate_thinking_questions(connection: Connection) -> int: }, ) migrated_count += 1 + context.advance_progress(records=1) + _complete_table_progress(context, "thinking_questions") return migrated_count diff --git a/src/common/database/migrations/models.py b/src/common/database/migrations/models.py index bc8cf488..1bf39346 100644 --- a/src/common/database/migrations/models.py +++ b/src/common/database/migrations/models.py @@ -172,31 +172,51 @@ class MigrationExecutionContext: def start_progress( self, - total: int, + total_tables: int, + total_records: int, description: str = "总迁移进度", - unit_name: str = "表", + table_unit_name: str = "表", + record_unit_name: str = "记录", ) -> None: """启动当前迁移步骤的进度展示。 Args: - total: 当前步骤需要处理的总项目数。 + total_tables: 当前步骤需要处理的总表数。 + total_records: 当前步骤需要处理的总记录数。 description: 进度描述文本。 - unit_name: 进度单位名称。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 """ if self.progress_reporter is None: return - self.progress_reporter.start(total=total, description=description, unit_name=unit_name) + self.progress_reporter.start( + total_records=total_records, + total_tables=total_tables, + description=description, + table_unit_name=table_unit_name, + record_unit_name=record_unit_name, + ) - def advance_progress(self, advance: int = 1, item_name: Optional[str] = None) -> None: + def advance_progress( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: """推进当前迁移步骤的进度展示。 Args: - advance: 本次推进的步数。 + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 item_name: 当前完成的项目名称。 """ if self.progress_reporter is None: return - self.progress_reporter.advance(advance=advance, item_name=item_name) + self.progress_reporter.advance( + records=records, + completed_tables=completed_tables, + item_name=item_name, + ) MigrationHandler = Callable[[MigrationExecutionContext], None] diff --git a/src/common/database/migrations/progress.py b/src/common/database/migrations/progress.py index 4e358ed7..4aff8d38 100644 --- a/src/common/database/migrations/progress.py +++ b/src/common/database/migrations/progress.py @@ -38,9 +38,10 @@ class MigrationSummaryColumn(ProgressColumn): Returns: Text: 渲染后的摘要文本。 """ - display_total = task.fields.get("display_total", task.total) - total_text = "?" if display_total is None else str(int(display_total)) - completed_text = str(int(task.completed)) + completed_tables = int(task.fields.get("completed_tables", 0)) + display_table_total = task.fields.get("display_table_total") + total_text = "?" if display_table_total is None else str(int(display_table_total)) + completed_text = str(completed_tables) return Text(f"总迁移进度({completed_text}/{total_text})") @@ -56,7 +57,7 @@ class MigrationSpeedColumn(ProgressColumn): Returns: Text: 渲染后的速度文本。 """ - unit_name = str(task.fields.get("unit_name", "项")) + unit_name = str(task.fields.get("progress_unit_name", "项")) if task.speed is None or task.speed <= 0: return Text(f"-- {unit_name}/s") return Text(f"{task.speed:.2f} {unit_name}/s") @@ -126,24 +127,34 @@ class BaseMigrationProgressReporter(ABC): @abstractmethod def start( self, - total: int, + total_records: int, + total_tables: int, description: str = "总迁移进度", - unit_name: str = "表", + table_unit_name: str = "表", + record_unit_name: str = "记录", ) -> None: """启动一个新的迁移进度任务。 Args: - total: 任务总数。 + total_records: 任务记录总数。 + total_tables: 任务表总数。 description: 任务描述。 - unit_name: 进度单位名称。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 """ @abstractmethod - def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: """推进当前迁移进度任务。 Args: - advance: 本次推进的步数。 + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 item_name: 当前完成的项目名称。 """ @@ -151,32 +162,45 @@ class BaseMigrationProgressReporter(ABC): class NullMigrationProgressReporter(BaseMigrationProgressReporter): """不输出任何内容的空进度上报器。""" + def open(self) -> None: + """打开空进度上报器。""" + def close(self) -> None: """关闭空进度上报器。""" def start( self, - total: int, + total_records: int, + total_tables: int, description: str = "总迁移进度", - unit_name: str = "表", + table_unit_name: str = "表", + record_unit_name: str = "记录", ) -> None: """启动空进度任务。 Args: - total: 任务总数。 + total_records: 任务记录总数。 + total_tables: 任务表总数。 description: 任务描述。 - unit_name: 进度单位名称。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 """ - del total, description, unit_name + del total_records, total_tables, description, table_unit_name, record_unit_name - def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: """推进空进度任务。 Args: - advance: 本次推进的步数。 + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 item_name: 当前完成的项目名称。 """ - del advance, item_name + del records, completed_tables, item_name class RichMigrationProgressReporter(BaseMigrationProgressReporter): @@ -228,39 +252,62 @@ class RichMigrationProgressReporter(BaseMigrationProgressReporter): def start( self, - total: int, + total_records: int, + total_tables: int, description: str = "总迁移进度", - unit_name: str = "表", + table_unit_name: str = "表", + record_unit_name: str = "记录", ) -> None: """启动一个新的 ``rich`` 迁移进度任务。 Args: - total: 任务总数。 + total_records: 任务记录总数。 + total_tables: 任务表总数。 description: 任务描述。 - unit_name: 进度单位名称。 + table_unit_name: 表级进度单位名称。 + record_unit_name: 记录级进度单位名称。 """ if self._progress is None: self.open() assert self._progress is not None - effective_total = max(total, 1) + use_record_progress = total_records > 0 + effective_total = total_records if use_record_progress else total_tables + effective_total = max(effective_total, 1) + progress_unit_name = record_unit_name if use_record_progress else table_unit_name self._task_id = self._progress.add_task( description, total=effective_total, - display_total=total, - unit_name=unit_name, + completed_tables=0, + display_table_total=total_tables, + progress_unit_name=progress_unit_name, + use_record_progress=use_record_progress, ) - def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + def advance( + self, + records: int = 0, + completed_tables: int = 0, + item_name: Optional[str] = None, + ) -> None: """推进当前 ``rich`` 迁移进度任务。 Args: - advance: 本次推进的步数。 + records: 本次推进的记录数。 + completed_tables: 本次完成的表数。 item_name: 当前完成的项目名称。 """ del item_name if self._progress is None or self._task_id is None: return - self._progress.update(self._task_id, advance=advance) + task = self._progress.tasks[self._task_id] + use_record_progress = bool(task.fields.get("use_record_progress", False)) + progress_advance = records if use_record_progress else completed_tables + updated_completed_tables = int(task.fields.get("completed_tables", 0)) + completed_tables + self._progress.update( + self._task_id, + advance=progress_advance, + completed_tables=updated_completed_tables, + ) def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter: