feat(migration): enhance migration progress reporting with detailed record and table tracking

This commit is contained in:
DrSmoothl
2026-03-31 09:29:58 +08:00
parent c2c992ff01
commit 5d410171d2
4 changed files with 422 additions and 238 deletions

View File

@@ -66,7 +66,7 @@ def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
_rename_legacy_v1_tables(context.connection, snapshot)
SQLModel.metadata.create_all(context.connection)
table_migration_jobs: List[Tuple[str, Callable[[Connection], int]]] = [
table_migration_jobs: List[Tuple[str, Callable[[MigrationExecutionContext], int]]] = [
("chat_sessions", _migrate_chat_sessions),
("llm_usage", _migrate_model_usage),
("images", _migrate_images),
@@ -81,10 +81,16 @@ def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
("thinking_questions", _migrate_thinking_questions),
]
migrated_counts: Dict[str, int] = {}
context.start_progress(total=len(table_migration_jobs), description="总迁移进度", unit_name="")
total_record_count = _estimate_total_record_count(context.connection)
context.start_progress(
total_tables=len(table_migration_jobs),
total_records=total_record_count,
description="总迁移进度",
table_unit_name="",
record_unit_name="记录",
)
for table_name, migration_handler in table_migration_jobs:
migrated_counts[table_name] = migration_handler(context.connection)
context.advance_progress(item_name=table_name)
migrated_counts[table_name] = migration_handler(context)
summary_text = ", ".join(f"{table_name}={count}" for table_name, count in migrated_counts.items())
logger.info(f"旧版数据库迁移完成: {summary_text}")
@@ -535,17 +541,77 @@ def _deduce_image_type_name(value: Any) -> str:
return "IMAGE"
def _migrate_chat_sessions(connection: Connection) -> int:
"""迁移旧版 ``chat_streams`` 到新版 ``chat_sessions``
def _count_legacy_table_rows(connection: Connection, original_table_name: str) -> int:
"""统计单张旧版备份表中的记录总数
Args:
connection: 当前数据库连接。
original_table_name: 旧版原始表名。
Returns:
int: 备份表中的记录数;若表不存在则返回 ``0``。
"""
backup_table_name = _legacy_backup_table_name(original_table_name)
schema_inspector = SQLiteSchemaInspector()
if not schema_inspector.table_exists(connection, backup_table_name):
return 0
row = connection.execute(
text(f"SELECT COUNT(*) FROM {_quote_identifier(backup_table_name)}")
).first()
if row is None:
return 0
return _normalize_int(row[0], default=0)
def _estimate_total_record_count(connection: Connection) -> int:
"""估算旧版迁移步骤需要处理的总记录数。
Args:
connection: 当前数据库连接。
Returns:
int: 本次迁移预计处理的总记录数。
"""
return (
_count_legacy_table_rows(connection, "chat_streams")
+ _count_legacy_table_rows(connection, "llm_usage")
+ _count_legacy_table_rows(connection, "emoji")
+ _count_legacy_table_rows(connection, "images")
+ _count_legacy_table_rows(connection, "messages")
+ _count_legacy_table_rows(connection, "action_records")
+ _count_legacy_table_rows(connection, "action_records")
+ _count_legacy_table_rows(connection, "online_time")
+ _count_legacy_table_rows(connection, "person_info")
+ _count_legacy_table_rows(connection, "expression")
+ _count_legacy_table_rows(connection, "jargon")
+ _count_legacy_table_rows(connection, "chat_history")
+ _count_legacy_table_rows(connection, "thinking_back")
)
def _complete_table_progress(context: MigrationExecutionContext, table_name: str) -> None:
"""标记单张表的迁移已经完成。
Args:
context: 当前迁移步骤执行上下文。
table_name: 已完成迁移的表名。
"""
context.advance_progress(completed_tables=1, item_name=table_name)
def _migrate_chat_sessions(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``chat_streams`` 到新版 ``chat_sessions``。
Args:
context: 当前迁移步骤执行上下文。
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "chat_streams")
if legacy_table is None:
_complete_table_progress(context, "chat_sessions")
return 0
migrated_count = 0
@@ -570,34 +636,37 @@ def _migrate_chat_sessions(connection: Connection) -> int:
)
for row in legacy_table.rows:
session_id = _normalize_required_text(row.get("stream_id"))
if not session_id:
continue
connection.execute(
insert_sql,
{
"session_id": session_id,
"created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True),
"last_active_timestamp": _coerce_datetime(row.get("last_active_time"), fallback_now=True),
"user_id": _normalize_optional_text(row.get("user_id")),
"group_id": _normalize_optional_text(row.get("group_id")),
"platform": _normalize_required_text(row.get("platform"), default="unknown"),
},
)
migrated_count += 1
if session_id:
connection.execute(
insert_sql,
{
"session_id": session_id,
"created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True),
"last_active_timestamp": _coerce_datetime(row.get("last_active_time"), fallback_now=True),
"user_id": _normalize_optional_text(row.get("user_id")),
"group_id": _normalize_optional_text(row.get("group_id")),
"platform": _normalize_required_text(row.get("platform"), default="unknown"),
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "chat_sessions")
return migrated_count
def _migrate_model_usage(connection: Connection) -> int:
def _migrate_model_usage(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``llm_usage`` 到新版 ``llm_usage``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "llm_usage")
if legacy_table is None:
_complete_table_progress(context, "llm_usage")
return 0
migrated_count = 0
@@ -654,18 +723,21 @@ def _migrate_model_usage(connection: Connection) -> int:
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "llm_usage")
return migrated_count
def _migrate_images(connection: Connection) -> int:
def _migrate_images(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``emoji`` 与 ``images`` 到新版 ``images``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
migrated_count = 0
existing_keys: Set[Tuple[str, str, str]] = set()
existing_rows = connection.execute(
@@ -719,28 +791,28 @@ def _migrate_images(connection: Connection) -> int:
full_path = _normalize_required_text(row.get("full_path"))
image_hash = _normalize_required_text(row.get("emoji_hash"))
dedupe_key = (full_path, image_hash, "EMOJI")
if not full_path or dedupe_key in existing_keys:
continue
connection.execute(
insert_sql,
{
"image_hash": image_hash,
"description": _normalize_required_text(row.get("description")),
"full_path": full_path,
"image_type": "EMOJI",
"emotion": _normalize_optional_text(row.get("emotion")),
"query_count": _normalize_int(row.get("query_count"), default=0),
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
"is_banned": _normalize_bool(row.get("is_banned"), default=False),
"no_file_flag": False,
"record_time": _coerce_datetime(row.get("record_time"), fallback_now=True),
"register_time": _coerce_datetime(row.get("register_time")),
"last_used_time": _coerce_datetime(row.get("last_used_time")),
"vlm_processed": False,
},
)
existing_keys.add(dedupe_key)
migrated_count += 1
if full_path and dedupe_key not in existing_keys:
connection.execute(
insert_sql,
{
"image_hash": image_hash,
"description": _normalize_required_text(row.get("description")),
"full_path": full_path,
"image_type": "EMOJI",
"emotion": _normalize_optional_text(row.get("emotion")),
"query_count": _normalize_int(row.get("query_count"), default=0),
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
"is_banned": _normalize_bool(row.get("is_banned"), default=False),
"no_file_flag": False,
"record_time": _coerce_datetime(row.get("record_time"), fallback_now=True),
"register_time": _coerce_datetime(row.get("register_time")),
"last_used_time": _coerce_datetime(row.get("last_used_time")),
"vlm_processed": False,
},
)
existing_keys.add(dedupe_key)
migrated_count += 1
context.advance_progress(records=1)
legacy_images_table = _load_legacy_table_data(connection, "images")
if legacy_images_table is not None:
@@ -749,43 +821,46 @@ def _migrate_images(connection: Connection) -> int:
image_hash = _normalize_required_text(row.get("emoji_hash"))
image_type = _deduce_image_type_name(row.get("type"))
dedupe_key = (full_path, image_hash, image_type)
if not full_path or dedupe_key in existing_keys:
continue
connection.execute(
insert_sql,
{
"image_hash": image_hash,
"description": _normalize_required_text(row.get("description")),
"full_path": full_path,
"image_type": image_type,
"emotion": None,
"query_count": _normalize_int(row.get("count"), default=0),
"is_registered": False,
"is_banned": False,
"no_file_flag": False,
"record_time": _coerce_datetime(row.get("timestamp"), fallback_now=True),
"register_time": None,
"last_used_time": None,
"vlm_processed": _normalize_bool(row.get("vlm_processed"), default=False),
},
)
existing_keys.add(dedupe_key)
migrated_count += 1
if full_path and dedupe_key not in existing_keys:
connection.execute(
insert_sql,
{
"image_hash": image_hash,
"description": _normalize_required_text(row.get("description")),
"full_path": full_path,
"image_type": image_type,
"emotion": None,
"query_count": _normalize_int(row.get("count"), default=0),
"is_registered": False,
"is_banned": False,
"no_file_flag": False,
"record_time": _coerce_datetime(row.get("timestamp"), fallback_now=True),
"register_time": None,
"last_used_time": None,
"vlm_processed": _normalize_bool(row.get("vlm_processed"), default=False),
},
)
existing_keys.add(dedupe_key)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "images")
return migrated_count
def _migrate_messages(connection: Connection) -> int:
def _migrate_messages(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``messages`` 到新版 ``mai_messages``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "messages")
if legacy_table is None:
_complete_table_progress(context, "mai_messages")
return 0
migrated_count = 0
@@ -840,62 +915,65 @@ def _migrate_messages(connection: Connection) -> int:
)
for row in legacy_table.rows:
session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id"))
if not session_id:
continue
processed_plain_text = _normalize_optional_text(row.get("processed_plain_text"))
display_message = _normalize_optional_text(row.get("display_message"))
connection.execute(
insert_sql,
{
"id": row.get("id"),
"message_id": _normalize_required_text(row.get("message_id"), default=""),
"timestamp": _coerce_datetime(row.get("time"), fallback_now=True),
"platform": _normalize_required_text(
row.get("chat_info_platform") or row.get("user_platform"),
default="unknown",
),
"user_id": _normalize_required_text(
row.get("user_id") or row.get("chat_info_user_id"),
default="",
),
"user_nickname": _normalize_required_text(
row.get("user_nickname") or row.get("chat_info_user_nickname"),
default="",
),
"user_cardname": _normalize_optional_text(
row.get("user_cardname") or row.get("chat_info_user_cardname")
),
"group_id": _normalize_optional_text(row.get("chat_info_group_id")),
"group_name": _normalize_optional_text(row.get("chat_info_group_name")),
"is_mentioned": _normalize_bool(row.get("is_mentioned"), default=False),
"is_at": _normalize_bool(row.get("is_at"), default=False),
"session_id": session_id,
"reply_to": _normalize_optional_text(row.get("reply_to")),
"is_emoji": _normalize_bool(row.get("is_emoji"), default=False),
"is_picture": _normalize_bool(row.get("is_picid"), default=False),
"is_command": _normalize_bool(row.get("is_command"), default=False),
"is_notify": _normalize_bool(row.get("is_notify"), default=False),
"raw_content": _build_message_raw_content(processed_plain_text, display_message),
"processed_plain_text": processed_plain_text,
"display_message": display_message,
"additional_config": _build_legacy_message_additional_config(row),
},
)
migrated_count += 1
if session_id:
processed_plain_text = _normalize_optional_text(row.get("processed_plain_text"))
display_message = _normalize_optional_text(row.get("display_message"))
connection.execute(
insert_sql,
{
"id": row.get("id"),
"message_id": _normalize_required_text(row.get("message_id"), default=""),
"timestamp": _coerce_datetime(row.get("time"), fallback_now=True),
"platform": _normalize_required_text(
row.get("chat_info_platform") or row.get("user_platform"),
default="unknown",
),
"user_id": _normalize_required_text(
row.get("user_id") or row.get("chat_info_user_id"),
default="",
),
"user_nickname": _normalize_required_text(
row.get("user_nickname") or row.get("chat_info_user_nickname"),
default="",
),
"user_cardname": _normalize_optional_text(
row.get("user_cardname") or row.get("chat_info_user_cardname")
),
"group_id": _normalize_optional_text(row.get("chat_info_group_id")),
"group_name": _normalize_optional_text(row.get("chat_info_group_name")),
"is_mentioned": _normalize_bool(row.get("is_mentioned"), default=False),
"is_at": _normalize_bool(row.get("is_at"), default=False),
"session_id": session_id,
"reply_to": _normalize_optional_text(row.get("reply_to")),
"is_emoji": _normalize_bool(row.get("is_emoji"), default=False),
"is_picture": _normalize_bool(row.get("is_picid"), default=False),
"is_command": _normalize_bool(row.get("is_command"), default=False),
"is_notify": _normalize_bool(row.get("is_notify"), default=False),
"raw_content": _build_message_raw_content(processed_plain_text, display_message),
"processed_plain_text": processed_plain_text,
"display_message": display_message,
"additional_config": _build_legacy_message_additional_config(row),
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "mai_messages")
return migrated_count
def _migrate_action_records(connection: Connection) -> int:
def _migrate_action_records(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``action_records`` 到新版 ``action_records``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "action_records")
if legacy_table is None:
_complete_table_progress(context, "action_records")
return 0
migrated_count = 0
@@ -926,37 +1004,40 @@ def _migrate_action_records(connection: Connection) -> int:
)
for row in legacy_table.rows:
session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id"))
if not session_id:
continue
connection.execute(
insert_sql,
{
"id": row.get("id"),
"action_id": _normalize_required_text(row.get("action_id")),
"timestamp": _coerce_datetime(row.get("time"), fallback_now=True),
"session_id": session_id,
"action_name": _normalize_required_text(row.get("action_name"), default="unknown"),
"action_reasoning": _normalize_optional_text(row.get("action_reasoning")),
"action_data": _normalize_optional_text(row.get("action_data")),
"action_builtin_prompt": None,
"action_display_prompt": _normalize_optional_text(row.get("action_prompt_display")),
},
)
migrated_count += 1
if session_id:
connection.execute(
insert_sql,
{
"id": row.get("id"),
"action_id": _normalize_required_text(row.get("action_id")),
"timestamp": _coerce_datetime(row.get("time"), fallback_now=True),
"session_id": session_id,
"action_name": _normalize_required_text(row.get("action_name"), default="unknown"),
"action_reasoning": _normalize_optional_text(row.get("action_reasoning")),
"action_data": _normalize_optional_text(row.get("action_data")),
"action_builtin_prompt": None,
"action_display_prompt": _normalize_optional_text(row.get("action_prompt_display")),
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "action_records")
return migrated_count
def _migrate_tool_records(connection: Connection) -> int:
def _migrate_tool_records(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``action_records`` 到新版 ``tool_records``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "action_records")
if legacy_table is None:
_complete_table_progress(context, "tool_records")
return 0
migrated_count = 0
@@ -987,37 +1068,40 @@ def _migrate_tool_records(connection: Connection) -> int:
)
for row in legacy_table.rows:
session_id = _normalize_optional_text(row.get("chat_id")) or _normalize_optional_text(row.get("chat_info_stream_id"))
if not session_id:
continue
connection.execute(
insert_sql,
{
"id": row.get("id"),
"tool_id": _normalize_required_text(row.get("action_id")),
"timestamp": _coerce_datetime(row.get("time"), fallback_now=True),
"session_id": session_id,
"tool_name": _normalize_required_text(row.get("action_name"), default="unknown"),
"tool_reasoning": _normalize_optional_text(row.get("action_reasoning")),
"tool_data": _normalize_optional_text(row.get("action_data")),
"tool_builtin_prompt": None,
"tool_display_prompt": _normalize_optional_text(row.get("action_prompt_display")),
},
)
migrated_count += 1
if session_id:
connection.execute(
insert_sql,
{
"id": row.get("id"),
"tool_id": _normalize_required_text(row.get("action_id")),
"timestamp": _coerce_datetime(row.get("time"), fallback_now=True),
"session_id": session_id,
"tool_name": _normalize_required_text(row.get("action_name"), default="unknown"),
"tool_reasoning": _normalize_optional_text(row.get("action_reasoning")),
"tool_data": _normalize_optional_text(row.get("action_data")),
"tool_builtin_prompt": None,
"tool_display_prompt": _normalize_optional_text(row.get("action_prompt_display")),
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "tool_records")
return migrated_count
def _migrate_online_time(connection: Connection) -> int:
def _migrate_online_time(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``online_time`` 到新版 ``online_time``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "online_time")
if legacy_table is None:
_complete_table_progress(context, "online_time")
return 0
migrated_count = 0
@@ -1050,20 +1134,24 @@ def _migrate_online_time(connection: Connection) -> int:
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "online_time")
return migrated_count
def _migrate_person_info(connection: Connection) -> int:
def _migrate_person_info(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``person_info`` 到新版 ``person_info``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "person_info")
if legacy_table is None:
_complete_table_progress(context, "person_info")
return 0
migrated_count = 0
@@ -1123,20 +1211,24 @@ def _migrate_person_info(connection: Connection) -> int:
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "person_info")
return migrated_count
def _migrate_expressions(connection: Connection) -> int:
def _migrate_expressions(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``expression`` 到新版 ``expressions``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "expression")
if legacy_table is None:
_complete_table_progress(context, "expressions")
return 0
migrated_count = 0
@@ -1187,20 +1279,24 @@ def _migrate_expressions(connection: Connection) -> int:
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "expressions")
return migrated_count
def _migrate_jargons(connection: Connection) -> int:
def _migrate_jargons(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``jargon`` 到新版 ``jargons``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "jargon")
if legacy_table is None:
_complete_table_progress(context, "jargons")
return 0
migrated_count = 0
@@ -1259,20 +1355,24 @@ def _migrate_jargons(connection: Connection) -> int:
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "jargons")
return migrated_count
def _migrate_chat_history(connection: Connection) -> int:
def _migrate_chat_history(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``chat_history`` 到新版 ``chat_history``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "chat_history")
if legacy_table is None:
_complete_table_progress(context, "chat_history")
return 0
migrated_count = 0
@@ -1307,39 +1407,42 @@ def _migrate_chat_history(connection: Connection) -> int:
)
for row in legacy_table.rows:
session_id = _normalize_required_text(row.get("chat_id"))
if not session_id:
continue
connection.execute(
insert_sql,
{
"id": row.get("id"),
"session_id": session_id,
"start_timestamp": _coerce_datetime(row.get("start_time"), fallback_now=True),
"end_timestamp": _coerce_datetime(row.get("end_time"), fallback_now=True),
"query_count": _normalize_int(row.get("count"), default=0),
"query_forget_count": _normalize_int(row.get("forget_times"), default=0),
"original_messages": _normalize_required_text(row.get("original_text")),
"participants": _normalize_required_text(row.get("participants"), default="[]"),
"theme": _normalize_required_text(row.get("theme"), default=""),
"keywords": _normalize_required_text(row.get("keywords"), default="[]"),
"summary": _normalize_required_text(row.get("summary"), default=""),
},
)
migrated_count += 1
if session_id:
connection.execute(
insert_sql,
{
"id": row.get("id"),
"session_id": session_id,
"start_timestamp": _coerce_datetime(row.get("start_time"), fallback_now=True),
"end_timestamp": _coerce_datetime(row.get("end_time"), fallback_now=True),
"query_count": _normalize_int(row.get("count"), default=0),
"query_forget_count": _normalize_int(row.get("forget_times"), default=0),
"original_messages": _normalize_required_text(row.get("original_text")),
"participants": _normalize_required_text(row.get("participants"), default="[]"),
"theme": _normalize_required_text(row.get("theme"), default=""),
"keywords": _normalize_required_text(row.get("keywords"), default="[]"),
"summary": _normalize_required_text(row.get("summary"), default=""),
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "chat_history")
return migrated_count
def _migrate_thinking_questions(connection: Connection) -> int:
def _migrate_thinking_questions(context: MigrationExecutionContext) -> int:
"""迁移旧版 ``thinking_back`` 到新版 ``thinking_questions``。
Args:
connection: 当前数据库连接
context: 当前迁移步骤执行上下文
Returns:
int: 迁移成功的记录数。
"""
connection = context.connection
legacy_table = _load_legacy_table_data(connection, "thinking_back")
if legacy_table is None:
_complete_table_progress(context, "thinking_questions")
return 0
migrated_count = 0
@@ -1381,4 +1484,6 @@ def _migrate_thinking_questions(connection: Connection) -> int:
},
)
migrated_count += 1
context.advance_progress(records=1)
_complete_table_progress(context, "thinking_questions")
return migrated_count

View File

@@ -172,31 +172,51 @@ class MigrationExecutionContext:
def start_progress(
self,
total: int,
total_tables: int,
total_records: int,
description: str = "总迁移进度",
unit_name: str = "",
table_unit_name: str = "",
record_unit_name: str = "记录",
) -> None:
"""启动当前迁移步骤的进度展示。
Args:
total: 当前步骤需要处理的总项目数。
total_tables: 当前步骤需要处理的总数。
total_records: 当前步骤需要处理的总记录数。
description: 进度描述文本。
unit_name: 进度单位名称。
table_unit_name: 表级进度单位名称。
record_unit_name: 记录级进度单位名称。
"""
if self.progress_reporter is None:
return
self.progress_reporter.start(total=total, description=description, unit_name=unit_name)
self.progress_reporter.start(
total_records=total_records,
total_tables=total_tables,
description=description,
table_unit_name=table_unit_name,
record_unit_name=record_unit_name,
)
def advance_progress(self, advance: int = 1, item_name: Optional[str] = None) -> None:
def advance_progress(
self,
records: int = 0,
completed_tables: int = 0,
item_name: Optional[str] = None,
) -> None:
"""推进当前迁移步骤的进度展示。
Args:
advance: 本次推进的数。
records: 本次推进的记录数。
completed_tables: 本次完成的表数。
item_name: 当前完成的项目名称。
"""
if self.progress_reporter is None:
return
self.progress_reporter.advance(advance=advance, item_name=item_name)
self.progress_reporter.advance(
records=records,
completed_tables=completed_tables,
item_name=item_name,
)
MigrationHandler = Callable[[MigrationExecutionContext], None]

View File

@@ -38,9 +38,10 @@ class MigrationSummaryColumn(ProgressColumn):
Returns:
Text: 渲染后的摘要文本。
"""
display_total = task.fields.get("display_total", task.total)
total_text = "?" if display_total is None else str(int(display_total))
completed_text = str(int(task.completed))
completed_tables = int(task.fields.get("completed_tables", 0))
display_table_total = task.fields.get("display_table_total")
total_text = "?" if display_table_total is None else str(int(display_table_total))
completed_text = str(completed_tables)
return Text(f"总迁移进度({completed_text}/{total_text}")
@@ -56,7 +57,7 @@ class MigrationSpeedColumn(ProgressColumn):
Returns:
Text: 渲染后的速度文本。
"""
unit_name = str(task.fields.get("unit_name", ""))
unit_name = str(task.fields.get("progress_unit_name", ""))
if task.speed is None or task.speed <= 0:
return Text(f"-- {unit_name}/s")
return Text(f"{task.speed:.2f} {unit_name}/s")
@@ -126,24 +127,34 @@ class BaseMigrationProgressReporter(ABC):
@abstractmethod
def start(
self,
total: int,
total_records: int,
total_tables: int,
description: str = "总迁移进度",
unit_name: str = "",
table_unit_name: str = "",
record_unit_name: str = "记录",
) -> None:
"""启动一个新的迁移进度任务。
Args:
total: 任务总数。
total_records: 任务记录总数。
total_tables: 任务表总数。
description: 任务描述。
unit_name: 进度单位名称。
table_unit_name: 表级进度单位名称。
record_unit_name: 记录级进度单位名称。
"""
@abstractmethod
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
def advance(
self,
records: int = 0,
completed_tables: int = 0,
item_name: Optional[str] = None,
) -> None:
"""推进当前迁移进度任务。
Args:
advance: 本次推进的数。
records: 本次推进的记录数。
completed_tables: 本次完成的表数。
item_name: 当前完成的项目名称。
"""
@@ -151,32 +162,45 @@ class BaseMigrationProgressReporter(ABC):
class NullMigrationProgressReporter(BaseMigrationProgressReporter):
"""不输出任何内容的空进度上报器。"""
def open(self) -> None:
"""打开空进度上报器。"""
def close(self) -> None:
"""关闭空进度上报器。"""
def start(
self,
total: int,
total_records: int,
total_tables: int,
description: str = "总迁移进度",
unit_name: str = "",
table_unit_name: str = "",
record_unit_name: str = "记录",
) -> None:
"""启动空进度任务。
Args:
total: 任务总数。
total_records: 任务记录总数。
total_tables: 任务表总数。
description: 任务描述。
unit_name: 进度单位名称。
table_unit_name: 表级进度单位名称。
record_unit_name: 记录级进度单位名称。
"""
del total, description, unit_name
del total_records, total_tables, description, table_unit_name, record_unit_name
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
def advance(
self,
records: int = 0,
completed_tables: int = 0,
item_name: Optional[str] = None,
) -> None:
"""推进空进度任务。
Args:
advance: 本次推进的数。
records: 本次推进的记录数。
completed_tables: 本次完成的表数。
item_name: 当前完成的项目名称。
"""
del advance, item_name
del records, completed_tables, item_name
class RichMigrationProgressReporter(BaseMigrationProgressReporter):
@@ -228,39 +252,62 @@ class RichMigrationProgressReporter(BaseMigrationProgressReporter):
def start(
self,
total: int,
total_records: int,
total_tables: int,
description: str = "总迁移进度",
unit_name: str = "",
table_unit_name: str = "",
record_unit_name: str = "记录",
) -> None:
"""启动一个新的 ``rich`` 迁移进度任务。
Args:
total: 任务总数。
total_records: 任务记录总数。
total_tables: 任务表总数。
description: 任务描述。
unit_name: 进度单位名称。
table_unit_name: 表级进度单位名称。
record_unit_name: 记录级进度单位名称。
"""
if self._progress is None:
self.open()
assert self._progress is not None
effective_total = max(total, 1)
use_record_progress = total_records > 0
effective_total = total_records if use_record_progress else total_tables
effective_total = max(effective_total, 1)
progress_unit_name = record_unit_name if use_record_progress else table_unit_name
self._task_id = self._progress.add_task(
description,
total=effective_total,
display_total=total,
unit_name=unit_name,
completed_tables=0,
display_table_total=total_tables,
progress_unit_name=progress_unit_name,
use_record_progress=use_record_progress,
)
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
def advance(
self,
records: int = 0,
completed_tables: int = 0,
item_name: Optional[str] = None,
) -> None:
"""推进当前 ``rich`` 迁移进度任务。
Args:
advance: 本次推进的数。
records: 本次推进的记录数。
completed_tables: 本次完成的表数。
item_name: 当前完成的项目名称。
"""
del item_name
if self._progress is None or self._task_id is None:
return
self._progress.update(self._task_id, advance=advance)
task = self._progress.tasks[self._task_id]
use_record_progress = bool(task.fields.get("use_record_progress", False))
progress_advance = records if use_record_progress else completed_tables
updated_completed_tables = int(task.fields.get("completed_tables", 0)) + completed_tables
self._progress.update(
self._task_id,
advance=progress_advance,
completed_tables=updated_completed_tables,
)
def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter: