536 lines
20 KiB
Python
536 lines
20 KiB
Python
from argparse import ArgumentParser, Namespace
|
||
from collections.abc import Iterable
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from sys import path as sys_path
|
||
from typing import Any, Optional
|
||
|
||
import json
|
||
import sqlite3
|
||
|
||
from sqlalchemy import text
|
||
from sqlmodel import Session, SQLModel, create_engine, delete
|
||
|
||
ROOT_PATH = Path(__file__).resolve().parent.parent
|
||
if str(ROOT_PATH) not in sys_path:
|
||
sys_path.insert(0, str(ROOT_PATH))
|
||
|
||
from src.common.database.database_model import Expression, Jargon, ModifiedBy
|
||
|
||
|
||
def build_argument_parser() -> ArgumentParser:
|
||
"""构建命令行参数解析器。"""
|
||
parser = ArgumentParser(
|
||
description="将旧版 expression/jargon 数据迁移到新版 expressions/jargons 数据库。"
|
||
)
|
||
parser.add_argument("--source-db", dest="source_db", help="旧版 SQLite 数据库路径")
|
||
parser.add_argument("--target-db", dest="target_db", help="新版 SQLite 数据库路径")
|
||
parser.add_argument(
|
||
"--clear-target",
|
||
dest="clear_target",
|
||
action="store_true",
|
||
help="迁移前清空目标库中的 expressions 和 jargons 表",
|
||
)
|
||
return parser
|
||
|
||
|
||
def prompt_path(prompt_text: str, current_value: Optional[str] = None) -> Path:
|
||
"""读取数据库路径输入。"""
|
||
while True:
|
||
suffix = f" [{current_value}]" if current_value else ""
|
||
raw_text = input(f"{prompt_text}{suffix}: ").strip()
|
||
value = raw_text or current_value or ""
|
||
if not value:
|
||
print("路径不能为空,请重新输入。")
|
||
continue
|
||
return Path(value).expanduser().resolve()
|
||
|
||
|
||
def prompt_yes_no(prompt_text: str, default: bool = False) -> bool:
|
||
"""读取是否确认输入。"""
|
||
default_hint = "Y/n" if default else "y/N"
|
||
raw_text = input(f"{prompt_text} [{default_hint}]: ").strip().lower()
|
||
if not raw_text:
|
||
return default
|
||
return raw_text in {"y", "yes"}
|
||
|
||
|
||
def ensure_sqlite_file(path: Path, should_exist: bool) -> None:
|
||
"""校验 SQLite 文件路径。"""
|
||
if should_exist and not path.is_file():
|
||
raise FileNotFoundError(f"数据库文件不存在:{path}")
|
||
if not should_exist:
|
||
path.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
|
||
def connect_sqlite(path: Path) -> sqlite3.Connection:
|
||
"""创建 SQLite 连接。"""
|
||
connection = sqlite3.connect(path)
|
||
connection.row_factory = sqlite3.Row
|
||
return connection
|
||
|
||
|
||
def table_exists(connection: sqlite3.Connection, table_name: str) -> bool:
|
||
"""检查表是否存在。"""
|
||
result = connection.execute(
|
||
"SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? LIMIT 1",
|
||
(table_name,),
|
||
).fetchone()
|
||
return result is not None
|
||
|
||
|
||
def resolve_source_table_name(connection: sqlite3.Connection, candidates: list[str]) -> str:
|
||
"""从候选表名中解析实际存在的表名。"""
|
||
for table_name in candidates:
|
||
if table_exists(connection, table_name):
|
||
return table_name
|
||
raise ValueError(f"未找到候选表:{', '.join(candidates)}")
|
||
|
||
|
||
def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[str]:
|
||
"""获取表字段名集合。"""
|
||
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
|
||
return {str(row["name"]) for row in rows}
|
||
|
||
|
||
def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]:
|
||
"""获取表字段是否允许 NULL 的映射。"""
|
||
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
|
||
return {str(row["name"]): not bool(row["notnull"]) for row in rows}
|
||
|
||
|
||
def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]:
|
||
"""读取整张表的数据。"""
|
||
return connection.execute(f"SELECT * FROM {table_name}").fetchall()
|
||
|
||
|
||
def normalize_optional_text(raw_value: Any) -> Optional[str]:
|
||
"""标准化可空文本字段。"""
|
||
if raw_value is None:
|
||
return None
|
||
return str(raw_value)
|
||
|
||
|
||
def ensure_nullable_compatibility(
|
||
table_name: str,
|
||
column_name: str,
|
||
row_id: Any,
|
||
value: Any,
|
||
nullable_map: dict[str, bool],
|
||
) -> None:
|
||
"""检查待迁移值是否与目标表可空约束兼容。"""
|
||
if value is None and not nullable_map.get(column_name, True):
|
||
raise ValueError(
|
||
f"目标表 {table_name}.{column_name} 不允许 NULL,但源记录 id={row_id} 的该字段为 NULL。"
|
||
)
|
||
|
||
|
||
def normalize_string_list(raw_value: Any) -> list[str]:
|
||
"""将旧库中的 JSON/文本字段标准化为字符串列表。"""
|
||
if raw_value is None:
|
||
return []
|
||
if isinstance(raw_value, list):
|
||
return [str(item).strip() for item in raw_value if str(item).strip()]
|
||
if isinstance(raw_value, str):
|
||
raw_text = raw_value.strip()
|
||
if not raw_text:
|
||
return []
|
||
try:
|
||
parsed = json.loads(raw_text)
|
||
except json.JSONDecodeError:
|
||
return [raw_text]
|
||
if isinstance(parsed, list):
|
||
return [str(item).strip() for item in parsed if str(item).strip()]
|
||
if isinstance(parsed, str):
|
||
parsed_text = parsed.strip()
|
||
return [parsed_text] if parsed_text else []
|
||
if parsed is None:
|
||
return []
|
||
return [str(parsed).strip()]
|
||
return [str(raw_value).strip()]
|
||
|
||
|
||
def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]:
|
||
"""标准化审核来源字段。"""
|
||
if raw_value is None:
|
||
return None
|
||
|
||
normalized_raw_value = raw_value
|
||
if isinstance(raw_value, str):
|
||
raw_text = raw_value.strip()
|
||
if raw_text.startswith('"') and raw_text.endswith('"'):
|
||
try:
|
||
normalized_raw_value = json.loads(raw_text)
|
||
except json.JSONDecodeError:
|
||
normalized_raw_value = raw_text
|
||
else:
|
||
normalized_raw_value = raw_text
|
||
|
||
value = str(normalized_raw_value).strip().lower()
|
||
if value in {"", "none", "null"}:
|
||
return None
|
||
if value in {ModifiedBy.AI.value, ModifiedBy.AI.name.lower()}:
|
||
return ModifiedBy.AI
|
||
if value in {ModifiedBy.USER.value, ModifiedBy.USER.name.lower()}:
|
||
return ModifiedBy.USER
|
||
return None
|
||
|
||
|
||
def parse_optional_bool(raw_value: Any) -> Optional[bool]:
|
||
"""解析可空布尔值,兼容整数和字符串。"""
|
||
if raw_value is None:
|
||
return None
|
||
if isinstance(raw_value, bool):
|
||
return raw_value
|
||
if isinstance(raw_value, int):
|
||
return bool(raw_value)
|
||
if isinstance(raw_value, float):
|
||
return bool(int(raw_value))
|
||
|
||
value = str(raw_value).strip().lower()
|
||
if value in {"", "none", "null"}:
|
||
return None
|
||
if value in {"1", "true", "t", "yes", "y"}:
|
||
return True
|
||
if value in {"0", "false", "f", "no", "n"}:
|
||
return False
|
||
raise ValueError(f"无法解析布尔值:{raw_value}")
|
||
|
||
|
||
def parse_bool(raw_value: Any, default: bool = False) -> bool:
|
||
"""解析非空布尔值。"""
|
||
parsed = parse_optional_bool(raw_value)
|
||
return default if parsed is None else parsed
|
||
|
||
|
||
def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]:
|
||
"""将旧库中的 Unix 时间戳转换为 datetime。"""
|
||
if raw_value is None or raw_value == "":
|
||
return datetime.now() if fallback_now else None
|
||
if isinstance(raw_value, datetime):
|
||
return raw_value
|
||
try:
|
||
return datetime.fromtimestamp(float(raw_value))
|
||
except (TypeError, ValueError, OSError, OverflowError):
|
||
return datetime.now() if fallback_now else None
|
||
|
||
|
||
def build_session_id_dict(raw_chat_id: Any, fallback_count: int) -> str:
|
||
"""将旧版 jargon.chat_id 转换为新版 session_id_dict。"""
|
||
if raw_chat_id is None:
|
||
return json.dumps({}, ensure_ascii=False)
|
||
|
||
if isinstance(raw_chat_id, str):
|
||
raw_text = raw_chat_id.strip()
|
||
else:
|
||
raw_text = str(raw_chat_id).strip()
|
||
|
||
if not raw_text:
|
||
return json.dumps({}, ensure_ascii=False)
|
||
|
||
try:
|
||
parsed = json.loads(raw_text)
|
||
except json.JSONDecodeError:
|
||
return json.dumps({raw_text: max(fallback_count, 1)}, ensure_ascii=False)
|
||
|
||
if isinstance(parsed, str):
|
||
parsed_text = parsed.strip()
|
||
session_counts = {parsed_text: max(fallback_count, 1)} if parsed_text else {}
|
||
return json.dumps(session_counts, ensure_ascii=False)
|
||
|
||
if not isinstance(parsed, list):
|
||
return json.dumps({}, ensure_ascii=False)
|
||
|
||
session_counts: dict[str, int] = {}
|
||
for item in parsed:
|
||
if not isinstance(item, list) or not item:
|
||
continue
|
||
session_id = str(item[0]).strip()
|
||
if not session_id:
|
||
continue
|
||
item_count = 1
|
||
if len(item) > 1:
|
||
try:
|
||
item_count = int(item[1])
|
||
except (TypeError, ValueError):
|
||
item_count = 1
|
||
session_counts[session_id] = max(item_count, 1)
|
||
|
||
return json.dumps(session_counts, ensure_ascii=False)
|
||
|
||
|
||
def create_target_engine(target_db_path: Path):
|
||
"""创建目标数据库引擎。"""
|
||
return create_engine(
|
||
f"sqlite:///{target_db_path.as_posix()}",
|
||
echo=False,
|
||
connect_args={"check_same_thread": False},
|
||
)
|
||
|
||
|
||
def clear_target_tables(session: Session) -> None:
|
||
"""清空目标表。"""
|
||
session.exec(delete(Expression))
|
||
session.exec(delete(Jargon))
|
||
|
||
|
||
def migrate_expressions(
|
||
old_rows: Iterable[sqlite3.Row],
|
||
target_session: Session,
|
||
expression_columns: set[str],
|
||
) -> int:
|
||
"""迁移 expression 数据。"""
|
||
migrated_count = 0
|
||
modified_by_ai_count = 0
|
||
modified_by_user_count = 0
|
||
modified_by_null_count = 0
|
||
unknown_modified_by_values: dict[str, int] = {}
|
||
for row in old_rows:
|
||
create_time = timestamp_to_datetime(row["create_date"] if "create_date" in expression_columns else None, True)
|
||
last_active_time = timestamp_to_datetime(
|
||
row["last_active_time"] if "last_active_time" in expression_columns else None,
|
||
True,
|
||
)
|
||
content_list = normalize_string_list(row["content_list"] if "content_list" in expression_columns else None)
|
||
raw_modified_by = row["modified_by"] if "modified_by" in expression_columns else None
|
||
modified_by = normalize_modified_by(raw_modified_by)
|
||
if modified_by == ModifiedBy.AI:
|
||
modified_by_ai_count += 1
|
||
elif modified_by == ModifiedBy.USER:
|
||
modified_by_user_count += 1
|
||
else:
|
||
modified_by_null_count += 1
|
||
if raw_modified_by not in (None, "", "null", "NULL", "None"):
|
||
unknown_key = str(raw_modified_by)
|
||
unknown_modified_by_values[unknown_key] = unknown_modified_by_values.get(unknown_key, 0) + 1
|
||
|
||
target_session.execute(
|
||
text(
|
||
"""
|
||
INSERT INTO expressions (
|
||
id,
|
||
situation,
|
||
style,
|
||
content_list,
|
||
count,
|
||
last_active_time,
|
||
create_time,
|
||
session_id,
|
||
checked,
|
||
rejected,
|
||
modified_by
|
||
) VALUES (
|
||
:id,
|
||
:situation,
|
||
:style,
|
||
:content_list,
|
||
:count,
|
||
:last_active_time,
|
||
:create_time,
|
||
:session_id,
|
||
:checked,
|
||
:rejected,
|
||
:modified_by
|
||
)
|
||
"""
|
||
),
|
||
{
|
||
"id": int(row["id"]) if row["id"] is not None else None,
|
||
"situation": str(row["situation"]).strip(),
|
||
"style": str(row["style"]).strip(),
|
||
"content_list": json.dumps(content_list, ensure_ascii=False),
|
||
"count": int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
|
||
"last_active_time": last_active_time or datetime.now(),
|
||
"create_time": create_time or datetime.now(),
|
||
"session_id": str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
|
||
"checked": parse_bool(row["checked"] if "checked" in expression_columns else None, default=False),
|
||
"rejected": parse_bool(row["rejected"] if "rejected" in expression_columns else None, default=False),
|
||
"modified_by": modified_by.name if modified_by is not None else None,
|
||
},
|
||
)
|
||
migrated_count += 1
|
||
|
||
print(
|
||
"Expression modified_by 迁移统计:"
|
||
f" AI={modified_by_ai_count}, USER={modified_by_user_count}, NULL={modified_by_null_count}"
|
||
)
|
||
if unknown_modified_by_values:
|
||
preview_items = list(unknown_modified_by_values.items())[:10]
|
||
preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items)
|
||
print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}")
|
||
return migrated_count
|
||
|
||
|
||
def migrate_jargons(
|
||
old_rows: Iterable[sqlite3.Row],
|
||
target_session: Session,
|
||
jargon_columns: set[str],
|
||
jargon_nullable_map: dict[str, bool],
|
||
) -> int:
|
||
"""迁移 jargon 数据。"""
|
||
migrated_count = 0
|
||
coerced_meaning_null_count = 0
|
||
for row in old_rows:
|
||
count = int(row["count"]) if "count" in jargon_columns and row["count"] is not None else 0
|
||
raw_content_value = row["raw_content"] if "raw_content" in jargon_columns else None
|
||
raw_content_list = normalize_string_list(raw_content_value)
|
||
meaning_value = normalize_optional_text(row["meaning"] if "meaning" in jargon_columns else None)
|
||
is_jargon_value = parse_optional_bool(row["is_jargon"] if "is_jargon" in jargon_columns else None)
|
||
inference_content_key = (
|
||
"inference_content_only"
|
||
if "inference_content_only" in jargon_columns
|
||
else "inference_with_content_only"
|
||
if "inference_with_content_only" in jargon_columns
|
||
else None
|
||
)
|
||
|
||
ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon_value, jargon_nullable_map)
|
||
|
||
if meaning_value is None and not jargon_nullable_map.get("meaning", True):
|
||
meaning_value = ""
|
||
coerced_meaning_null_count += 1
|
||
|
||
# 显式执行 SQL,避免 ORM 在 None 场景下回填模型默认值。
|
||
target_session.execute(
|
||
text(
|
||
"""
|
||
INSERT INTO jargons (
|
||
id,
|
||
content,
|
||
raw_content,
|
||
meaning,
|
||
session_id_dict,
|
||
count,
|
||
is_jargon,
|
||
is_complete,
|
||
is_global,
|
||
last_inference_count,
|
||
inference_with_context,
|
||
inference_with_content_only
|
||
) VALUES (
|
||
:id,
|
||
:content,
|
||
:raw_content,
|
||
:meaning,
|
||
:session_id_dict,
|
||
:count,
|
||
:is_jargon,
|
||
:is_complete,
|
||
:is_global,
|
||
:last_inference_count,
|
||
:inference_with_context,
|
||
:inference_with_content_only
|
||
)
|
||
"""
|
||
),
|
||
{
|
||
"id": int(row["id"]) if row["id"] is not None else None,
|
||
"content": str(row["content"]).strip(),
|
||
"raw_content": json.dumps(raw_content_list, ensure_ascii=False) if raw_content_value is not None else None,
|
||
"meaning": meaning_value,
|
||
"session_id_dict": build_session_id_dict(
|
||
row["chat_id"] if "chat_id" in jargon_columns else None,
|
||
fallback_count=count,
|
||
),
|
||
"count": count,
|
||
"is_jargon": is_jargon_value,
|
||
"is_complete": parse_bool(row["is_complete"] if "is_complete" in jargon_columns else None, default=False),
|
||
"is_global": parse_bool(row["is_global"] if "is_global" in jargon_columns else None, default=False),
|
||
"last_inference_count": (
|
||
int(row["last_inference_count"])
|
||
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
|
||
else 0
|
||
),
|
||
"inference_with_context": (
|
||
str(row["inference_with_context"])
|
||
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
|
||
else None
|
||
),
|
||
"inference_with_content_only": (
|
||
str(row[inference_content_key])
|
||
if inference_content_key and row[inference_content_key] is not None
|
||
else None
|
||
),
|
||
},
|
||
)
|
||
migrated_count += 1
|
||
|
||
if coerced_meaning_null_count > 0:
|
||
print(
|
||
f"警告:目标表 jargons.meaning 不允许 NULL,已将 {coerced_meaning_null_count} 条旧记录的 NULL meaning 转为空字符串。"
|
||
)
|
||
return migrated_count
|
||
|
||
|
||
def confirm_target_replacement(target_db_path: Path, clear_target: bool) -> bool:
|
||
"""确认是否写入目标数据库。"""
|
||
if clear_target:
|
||
return prompt_yes_no(f"将清空目标库中的 expressions/jargons 后再迁移,确认继续吗?\n目标库:{target_db_path}")
|
||
return prompt_yes_no(f"将写入目标库,若主键冲突会导致迁移失败,确认继续吗?\n目标库:{target_db_path}")
|
||
|
||
|
||
def parse_arguments() -> Namespace:
|
||
"""解析参数。"""
|
||
return build_argument_parser().parse_args()
|
||
|
||
|
||
def main() -> None:
|
||
"""脚本入口。"""
|
||
args = parse_arguments()
|
||
|
||
print("旧版 expression/jargon -> 新版 expressions/jargons 迁移工具")
|
||
source_db_path = prompt_path("请输入旧版数据库路径", args.source_db)
|
||
target_db_path = prompt_path("请输入新版数据库路径", args.target_db)
|
||
clear_target = args.clear_target or prompt_yes_no("迁移前是否清空目标库中的 expressions 和 jargons 表?", False)
|
||
|
||
if source_db_path == target_db_path:
|
||
raise ValueError("旧版数据库路径和新版数据库路径不能相同。")
|
||
|
||
ensure_sqlite_file(source_db_path, should_exist=True)
|
||
ensure_sqlite_file(target_db_path, should_exist=False)
|
||
|
||
print(f"旧库:{source_db_path}")
|
||
print(f"新库:{target_db_path}")
|
||
print(f"清空目标表:{'是' if clear_target else '否'}")
|
||
|
||
if not confirm_target_replacement(target_db_path, clear_target):
|
||
print("已取消迁移。")
|
||
return
|
||
|
||
source_connection = connect_sqlite(source_db_path)
|
||
try:
|
||
expression_table_name = resolve_source_table_name(source_connection, ["expression", "expressions"])
|
||
jargon_table_name = resolve_source_table_name(source_connection, ["jargon", "jargons"])
|
||
expression_columns = get_table_columns(source_connection, expression_table_name)
|
||
jargon_columns = get_table_columns(source_connection, jargon_table_name)
|
||
expression_rows = load_rows(source_connection, expression_table_name)
|
||
jargon_rows = load_rows(source_connection, jargon_table_name)
|
||
finally:
|
||
source_connection.close()
|
||
|
||
target_engine = create_target_engine(target_db_path)
|
||
SQLModel.metadata.create_all(target_engine)
|
||
|
||
target_sqlite_connection = connect_sqlite(target_db_path)
|
||
try:
|
||
jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons")
|
||
finally:
|
||
target_sqlite_connection.close()
|
||
|
||
with Session(target_engine) as target_session:
|
||
if clear_target:
|
||
clear_target_tables(target_session)
|
||
target_session.commit()
|
||
|
||
expression_count = migrate_expressions(expression_rows, target_session, expression_columns)
|
||
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map)
|
||
target_session.commit()
|
||
|
||
print("迁移完成。")
|
||
print(f"已迁移 expression 记录:{expression_count}")
|
||
print(f"已迁移 jargon 记录:{jargon_count}")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|