Files
mai-bot/code_scripts/migrate_expression_jargon_db.py

536 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from argparse import ArgumentParser, Namespace
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
from sys import path as sys_path
from typing import Any, Optional
import json
import sqlite3
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, delete
# Directory two levels above this script file.
ROOT_PATH = Path(__file__).resolve().parent.parent
# Put that directory at the front of the import path (only once) so the
# `src.` import that follows this block can be resolved.
if str(ROOT_PATH) not in sys_path:
    sys_path.insert(0, str(ROOT_PATH))
from src.common.database.database_model import Expression, Jargon, ModifiedBy # noqa: E402
def build_argument_parser() -> ArgumentParser:
    """Create the command-line parser for the migration script.

    Returns:
        An ``ArgumentParser`` that understands ``--source-db``,
        ``--target-db`` and the ``--clear-target`` switch.
    """
    cli = ArgumentParser(
        description="将旧版 expression/jargon 数据迁移到新版 expressions/jargons 数据库。"
    )
    # The two database paths share the same shape: flag, destination, help.
    path_options = (
        ("--source-db", "source_db", "旧版 SQLite 数据库路径"),
        ("--target-db", "target_db", "新版 SQLite 数据库路径"),
    )
    for flag, destination, help_text in path_options:
        cli.add_argument(flag, dest=destination, help=help_text)
    cli.add_argument(
        "--clear-target",
        dest="clear_target",
        action="store_true",
        help="迁移前清空目标库中的 expressions 和 jargons 表",
    )
    return cli
def prompt_path(prompt_text: str, current_value: Optional[str] = None) -> Path:
    """Ask for a database path until a non-empty one is available.

    Args:
        prompt_text: Text shown in front of the input cursor.
        current_value: Optional default; it is displayed in brackets and used
            when the user enters nothing.

    Returns:
        The chosen path with ``~`` expanded, resolved to an absolute path.
    """
    hint = f" [{current_value}]" if current_value else ""
    question = f"{prompt_text}{hint}: "
    while True:
        answer = input(question).strip()
        chosen = answer or current_value or ""
        if chosen:
            return Path(chosen).expanduser().resolve()
        print("路径不能为空,请重新输入。")
def prompt_yes_no(prompt_text: str, default: bool = False) -> bool:
    """Ask a yes/no question on the console.

    Args:
        prompt_text: Question shown to the user.
        default: Value returned when the user enters nothing.

    Returns:
        ``default`` for empty input; otherwise ``True`` only when the answer
        is ``y`` or ``yes`` (case-insensitive), ``False`` for anything else.
    """
    choices = "Y/n" if default else "y/N"
    answer = input(f"{prompt_text} [{choices}]: ").strip().lower()
    if answer:
        return answer in ("y", "yes")
    return default
def ensure_sqlite_file(path: Path, should_exist: bool) -> None:
    """Validate or prepare the location of a SQLite database file.

    Args:
        path: Location of the database file.
        should_exist: When true the file must already exist; when false only
            the parent directory is created so the file can be written later.

    Raises:
        FileNotFoundError: ``should_exist`` is true and ``path`` is not a file.
    """
    if should_exist:
        if not path.is_file():
            raise FileNotFoundError(f"数据库文件不存在:{path}")
        return
    path.parent.mkdir(parents=True, exist_ok=True)
def connect_sqlite(path: Path) -> sqlite3.Connection:
    """Open a SQLite connection whose rows can be indexed by column name.

    Args:
        path: Location of the database file.

    Returns:
        A connection with ``row_factory`` set to ``sqlite3.Row``.
    """
    database = sqlite3.connect(path)
    database.row_factory = sqlite3.Row
    return database
def table_exists(connection: sqlite3.Connection, table_name: str) -> bool:
    """Report whether ``sqlite_master`` lists a table named ``table_name``.

    Args:
        connection: Open SQLite connection to inspect.
        table_name: Exact table name to look for.

    Returns:
        ``True`` when a matching table entry exists, ``False`` otherwise.
    """
    lookup = "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? LIMIT 1"
    cursor = connection.execute(lookup, (table_name,))
    return cursor.fetchone() is not None
def resolve_source_table_name(connection: sqlite3.Connection, candidates: list[str]) -> str:
    """Pick the first candidate table name that exists in the database.

    Args:
        connection: Open SQLite connection to inspect.
        candidates: Table names to try, in order of preference.

    Returns:
        The first name in ``candidates`` for which a table exists.

    Raises:
        ValueError: None of the candidates exists.
    """
    existing = next((name for name in candidates if table_exists(connection, name)), None)
    if existing is None:
        raise ValueError(f"未找到候选表:{', '.join(candidates)}")
    return existing
def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[str]:
    """Collect the column names of a table via ``PRAGMA table_info``.

    Args:
        connection: Open SQLite connection; its rows must support access by
            column name (the ``name`` field is read by key).
        table_name: Table to describe.

    Returns:
        The set of column names; empty when the pragma returns no rows.
    """
    cursor = connection.execute(f"PRAGMA table_info('{table_name}')")
    return {str(column["name"]) for column in cursor.fetchall()}
def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]:
    """Map each column of a table to whether it accepts ``NULL``.

    Args:
        connection: Open SQLite connection; its rows must support access by
            column name (``name`` and ``notnull`` are read by key).
        table_name: Table to describe.

    Returns:
        ``{column_name: True}`` when the pragma's ``notnull`` flag is unset,
        ``{column_name: False}`` when it is set.
    """
    cursor = connection.execute(f"PRAGMA table_info('{table_name}')")
    nullable_by_column: dict[str, bool] = {}
    for column in cursor.fetchall():
        nullable_by_column[str(column["name"])] = not bool(column["notnull"])
    return nullable_by_column
def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]:
    """Fetch every row of a table.

    Args:
        connection: Open SQLite connection to read from.
        table_name: Table to read; it is interpolated into the statement
            unquoted.

    Returns:
        All rows of the table as returned by ``fetchall``.
    """
    cursor = connection.execute(f"SELECT * FROM {table_name}")
    return cursor.fetchall()
def normalize_optional_text(raw_value: Any) -> Optional[str]:
    """Convert a nullable value to text while keeping ``None`` as ``None``.

    Args:
        raw_value: Value read from the legacy database.

    Returns:
        ``None`` for ``None``, otherwise ``str(raw_value)`` without trimming.
    """
    return None if raw_value is None else str(raw_value)
def ensure_nullable_compatibility(
    table_name: str,
    column_name: str,
    row_id: Any,
    value: Any,
    nullable_map: dict[str, bool],
) -> None:
    """Reject a ``None`` value destined for a column that forbids ``NULL``.

    Args:
        table_name: Target table name, used in the error message.
        column_name: Target column being checked.
        row_id: Identifier of the source row, used in the error message.
        value: Value about to be written.
        nullable_map: Column name to "accepts NULL" flag; columns missing
            from the map are treated as nullable.

    Raises:
        ValueError: ``value`` is ``None`` and the column is marked non-nullable.
    """
    if value is not None:
        return
    if nullable_map.get(column_name, True):
        return
    raise ValueError(
        f"目标表 {table_name}.{column_name} 不允许 NULL但源记录 id={row_id} 的该字段为 NULL。"
    )
def _clean_string_items(items: Iterable[Any]) -> list[str]:
    """Stringify and trim ``items``, dropping ``None`` and blank entries."""
    trimmed = (str(item).strip() for item in items if item is not None)
    return [entry for entry in trimmed if entry]


def normalize_string_list(raw_value: Any) -> list[str]:
    """Normalize a legacy JSON/text field into a list of non-empty strings.

    Handling by input shape:

    * ``None`` or blank text gives ``[]``.
    * A list (given directly or encoded as a JSON array) gives its items,
      stringified and trimmed; ``None``/``null`` and blank items are dropped
      instead of turning into the literal text ``"None"``.
    * A JSON string literal gives its trimmed content.
    * The JSON literal ``null`` gives ``[]``.
    * Text that is not JSON, or that decodes to any other JSON value
      (number, boolean, object), is kept verbatim as a single entry, so
      content such as ``true``, ``1e3`` or ``{"a": 1}`` is not rewritten
      into its Python representation.
    * Any other non-string value is stringified into a single entry.

    Args:
        raw_value: Value read from the legacy database.

    Returns:
        A list of trimmed, non-empty strings.
    """
    if raw_value is None:
        return []
    if isinstance(raw_value, list):
        return _clean_string_items(raw_value)
    if not isinstance(raw_value, str):
        fallback_text = str(raw_value).strip()
        return [fallback_text] if fallback_text else []

    raw_text = raw_value.strip()
    if not raw_text:
        return []
    try:
        parsed = json.loads(raw_text)
    except ValueError:
        # json.JSONDecodeError is a ValueError; the broader class also covers
        # integer literals too long for int() to convert.
        return [raw_text]

    if parsed is None:
        return []
    if isinstance(parsed, list):
        return _clean_string_items(parsed)
    if isinstance(parsed, str):
        parsed_text = parsed.strip()
        return [parsed_text] if parsed_text else []
    # Scalars and objects: keep the original text rather than str(parsed),
    # which would change what was actually stored.
    return [raw_text]
def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]:
    """Map a legacy reviewer value onto the ``ModifiedBy`` enum.

    Text wrapped in double quotes is first decoded as a JSON string. The
    comparison is case-insensitive and accepts either a member's ``value`` or
    its lower-cased ``name``.

    Args:
        raw_value: Value read from the legacy database.

    Returns:
        ``ModifiedBy.AI`` or ``ModifiedBy.USER`` on a match; ``None`` for
        ``None``, blank text, ``none``/``null`` and any unrecognized value.
    """
    if raw_value is None:
        return None

    candidate = raw_value
    if isinstance(raw_value, str):
        candidate = raw_value.strip()
        if candidate.startswith('"') and candidate.endswith('"'):
            try:
                candidate = json.loads(candidate)
            except json.JSONDecodeError:
                # Not valid JSON after all: keep the trimmed text unchanged.
                pass

    token = str(candidate).strip().lower()
    if token in {"", "none", "null"}:
        return None
    for member in (ModifiedBy.AI, ModifiedBy.USER):
        if token in {member.value, member.name.lower()}:
            return member
    return None
def parse_optional_bool(raw_value: Any) -> Optional[bool]:
    """Parse a nullable boolean stored as bool, number or text.

    Args:
        raw_value: Value read from the legacy database.

    Returns:
        ``None`` for ``None`` and for the texts ``""``/``none``/``null``;
        booleans and integers by truthiness; floats by the truthiness of
        their integer part; otherwise the boolean named by the
        case-insensitive text.

    Raises:
        ValueError: The text is not one of the recognized spellings.
    """
    if raw_value is None:
        return None
    if isinstance(raw_value, (bool, int)):
        return bool(raw_value)
    if isinstance(raw_value, float):
        # Truncate first, so fractional values below 1 count as false.
        return bool(int(raw_value))

    token = str(raw_value).strip().lower()
    if token in ("", "none", "null"):
        return None
    spellings = {
        "1": True, "true": True, "t": True, "yes": True, "y": True,
        "0": False, "false": False, "f": False, "no": False, "n": False,
    }
    if token in spellings:
        return spellings[token]
    raise ValueError(f"无法解析布尔值:{raw_value}")
def parse_bool(raw_value: Any, default: bool = False) -> bool:
    """Parse a boolean, substituting ``default`` when the value is absent.

    Args:
        raw_value: Value read from the legacy database.
        default: Result used when ``parse_optional_bool`` yields ``None``.

    Returns:
        The parsed boolean, or ``default``.
    """
    outcome = parse_optional_bool(raw_value)
    if outcome is None:
        return default
    return outcome
def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]:
    """Convert a legacy Unix timestamp into a ``datetime``.

    Args:
        raw_value: Timestamp in seconds (number or numeric text), or an
            existing ``datetime``.
        fallback_now: When true, missing or unconvertible values yield the
            current time instead of ``None``.

    Returns:
        ``raw_value`` itself if it already is a ``datetime``; otherwise the
        result of ``datetime.fromtimestamp``; otherwise the fallback.
    """

    def _fallback() -> Optional[datetime]:
        return datetime.now() if fallback_now else None

    if raw_value is None or raw_value == "":
        return _fallback()
    if isinstance(raw_value, datetime):
        return raw_value
    try:
        seconds = float(raw_value)
        return datetime.fromtimestamp(seconds)
    except (TypeError, ValueError, OSError, OverflowError):
        return _fallback()
def build_session_id_dict(raw_chat_id: Any, fallback_count: int) -> str:
    """Convert a legacy ``jargon.chat_id`` value into ``session_id_dict`` JSON.

    Accepted legacy shapes:

    * ``None`` or blank text gives ``{}``.
    * Text that is not JSON, a JSON string literal, or a bare JSON scalar
      (number or boolean, e.g. a numeric chat id) is treated as one session
      id whose count is ``max(fallback_count, 1)``. Scalars keep their
      original text as the key instead of being discarded.
    * A JSON array of ``[session_id, count]`` pairs gives one entry per
      pair. Items that are not non-empty lists, and ids that are ``null`` or
      blank, are skipped. A missing or unparsable count becomes 1, and
      counts are clamped to at least 1. A repeated id keeps its last count.
    * Any other JSON value (``null``, object) gives ``{}``.

    Args:
        raw_chat_id: Value read from the legacy ``chat_id`` column.
        fallback_count: Count to use when the value names a single session.

    Returns:
        A JSON object string mapping session id to count, with non-ASCII
        characters left unescaped.
    """

    def _dump(counts: dict[str, int]) -> str:
        return json.dumps(counts, ensure_ascii=False)

    if raw_chat_id is None:
        return _dump({})
    raw_text = str(raw_chat_id).strip()
    if not raw_text:
        return _dump({})

    single_count = max(fallback_count, 1)
    try:
        parsed = json.loads(raw_text)
    except ValueError:
        # json.JSONDecodeError is a ValueError: plain, non-JSON session id.
        return _dump({raw_text: single_count})

    if isinstance(parsed, str):
        parsed_text = parsed.strip()
        return _dump({parsed_text: single_count} if parsed_text else {})
    if isinstance(parsed, (bool, int, float)):
        # A bare scalar is a session id that merely happens to be valid JSON.
        return _dump({raw_text: single_count})
    if not isinstance(parsed, list):
        return _dump({})

    session_counts: dict[str, int] = {}
    for item in parsed:
        if not isinstance(item, list) or not item or item[0] is None:
            continue
        session_id = str(item[0]).strip()
        if not session_id:
            continue
        item_count = 1
        if len(item) > 1:
            try:
                item_count = int(item[1])
            except (TypeError, ValueError, OverflowError):
                # OverflowError: json.loads accepts Infinity, int() does not.
                item_count = 1
        session_counts[session_id] = max(item_count, 1)
    return _dump(session_counts)
def create_target_engine(target_db_path: Path):
    """Build the engine for the target SQLite database.

    Args:
        target_db_path: Location of the target database file.

    Returns:
        The engine returned by ``create_engine`` for a ``sqlite:///`` URL
        built from the POSIX form of the path, with statement echo off and
        ``check_same_thread`` disabled.
    """
    database_url = f"sqlite:///{target_db_path.as_posix()}"
    sqlite_options = {"check_same_thread": False}
    return create_engine(database_url, echo=False, connect_args=sqlite_options)
def clear_target_tables(session: Session) -> None:
    """Delete every row of the ``Expression`` and ``Jargon`` models.

    The statements are executed on ``session``; nothing is committed here.

    Args:
        session: Session bound to the target database.
    """
    for model in (Expression, Jargon):
        session.exec(delete(model))
def migrate_expressions(
    old_rows: Iterable[sqlite3.Row],
    target_session: Session,
    expression_columns: set[str],
) -> int:
    """Insert legacy expression rows into the ``expressions`` table.

    Rows are written with an explicit ``INSERT`` statement executed on
    ``target_session``; nothing is committed here. The source columns
    ``id``, ``situation`` and ``style`` are read unconditionally. Every other
    column is optional: when it is missing from ``expression_columns`` its
    value is treated as ``None`` and the per-field default applies. After
    the loop a summary of how ``modified_by`` values were mapped is printed,
    followed by a warning listing up to ten unrecognized values.

    Args:
        old_rows: Rows read from the legacy expression table.
        target_session: Session used to execute the inserts.
        expression_columns: Column names present in the legacy table.

    Returns:
        The number of rows inserted.
    """
    # The statement is identical for every row, so it is built once.
    insert_statement = text(
        """
        INSERT INTO expressions (
            id,
            situation,
            style,
            content_list,
            count,
            last_active_time,
            create_time,
            session_id,
            checked,
            rejected,
            modified_by
        ) VALUES (
            :id,
            :situation,
            :style,
            :content_list,
            :count,
            :last_active_time,
            :create_time,
            :session_id,
            :checked,
            :rejected,
            :modified_by
        )
        """
    )

    def optional_value(row: sqlite3.Row, column_name: str) -> Any:
        """Return the row's value for a column, or ``None`` if the column is absent."""
        return row[column_name] if column_name in expression_columns else None

    tallies = {"AI": 0, "USER": 0, "NULL": 0}
    unrecognized: dict[str, int] = {}
    inserted = 0

    for row in old_rows:
        created_at = timestamp_to_datetime(optional_value(row, "create_date"), True)
        active_at = timestamp_to_datetime(optional_value(row, "last_active_time"), True)
        contents = normalize_string_list(optional_value(row, "content_list"))

        legacy_modified_by = optional_value(row, "modified_by")
        reviewer = normalize_modified_by(legacy_modified_by)
        if reviewer == ModifiedBy.AI:
            tallies["AI"] += 1
        elif reviewer == ModifiedBy.USER:
            tallies["USER"] += 1
        else:
            tallies["NULL"] += 1
            # Remember values that were present but could not be mapped.
            if legacy_modified_by not in (None, "", "null", "NULL", "None"):
                label = str(legacy_modified_by)
                unrecognized[label] = unrecognized.get(label, 0) + 1

        legacy_count = optional_value(row, "count")
        legacy_chat_id = optional_value(row, "chat_id")
        parameters = {
            "id": int(row["id"]) if row["id"] is not None else None,
            "situation": str(row["situation"]).strip(),
            "style": str(row["style"]).strip(),
            "content_list": json.dumps(contents, ensure_ascii=False),
            "count": int(legacy_count) if legacy_count is not None else 1,
            "last_active_time": active_at or datetime.now(),
            "create_time": created_at or datetime.now(),
            "session_id": str(legacy_chat_id).strip() if legacy_chat_id else None,
            "checked": parse_bool(optional_value(row, "checked"), default=False),
            "rejected": parse_bool(optional_value(row, "rejected"), default=False),
            # The enum member's name is what gets written to the column.
            "modified_by": reviewer.name if reviewer is not None else None,
        }
        target_session.execute(insert_statement, parameters)
        inserted += 1

    print(
        "Expression modified_by 迁移统计:"
        f" AI={tallies['AI']}, USER={tallies['USER']}, NULL={tallies['NULL']}"
    )
    if unrecognized:
        preview_items = list(unrecognized.items())[:10]
        preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items)
        print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}")
    return inserted
def migrate_jargons(
    old_rows: Iterable[sqlite3.Row],
    target_session: Session,
    jargon_columns: set[str],
    jargon_nullable_map: dict[str, bool],
) -> int:
    """Insert legacy jargon rows into the ``jargons`` table.

    Rows are written with an explicit ``INSERT`` statement executed on
    ``target_session``; nothing is committed here. The source columns ``id``
    and ``content`` are read unconditionally. Every other column is
    optional: when it is missing from ``jargon_columns`` its value is
    treated as ``None`` and the per-field default applies. The legacy
    content-only inference text is taken from ``inference_content_only`` if
    that column exists, otherwise from ``inference_with_content_only``.

    Args:
        old_rows: Rows read from the legacy jargon table.
        target_session: Session used to execute the inserts.
        jargon_columns: Column names present in the legacy table.
        jargon_nullable_map: Target-table column name to "accepts NULL" flag.

    Returns:
        The number of rows inserted.

    Raises:
        ValueError: A row's ``is_jargon`` is ``None`` while the target column
            forbids ``NULL`` (raised by ``ensure_nullable_compatibility``).
    """
    # Execute SQL explicitly so the ORM does not back-fill model defaults
    # where the value is None. The statement is the same for every row.
    insert_statement = text(
        """
        INSERT INTO jargons (
            id,
            content,
            raw_content,
            meaning,
            session_id_dict,
            count,
            is_jargon,
            is_complete,
            is_global,
            last_inference_count,
            inference_with_context,
            inference_with_content_only
        ) VALUES (
            :id,
            :content,
            :raw_content,
            :meaning,
            :session_id_dict,
            :count,
            :is_jargon,
            :is_complete,
            :is_global,
            :last_inference_count,
            :inference_with_context,
            :inference_with_content_only
        )
        """
    )

    def optional_value(row: sqlite3.Row, column_name: str) -> Any:
        """Return the row's value for a column, or ``None`` if the column is absent."""
        return row[column_name] if column_name in jargon_columns else None

    # Which legacy column (if any) carries the content-only inference text.
    if "inference_content_only" in jargon_columns:
        inference_source: Optional[str] = "inference_content_only"
    elif "inference_with_content_only" in jargon_columns:
        inference_source = "inference_with_content_only"
    else:
        inference_source = None
    meaning_forbids_null = not jargon_nullable_map.get("meaning", True)

    inserted = 0
    blank_meanings = 0
    for row in old_rows:
        legacy_count = optional_value(row, "count")
        usage_count = int(legacy_count) if legacy_count is not None else 0
        legacy_raw_content = optional_value(row, "raw_content")
        raw_contents = normalize_string_list(legacy_raw_content)
        meaning = normalize_optional_text(optional_value(row, "meaning"))
        is_jargon = parse_optional_bool(optional_value(row, "is_jargon"))

        ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon, jargon_nullable_map)
        if meaning is None and meaning_forbids_null:
            meaning = ""
            blank_meanings += 1

        legacy_last_inference = optional_value(row, "last_inference_count")
        legacy_context = optional_value(row, "inference_with_context")
        legacy_content_only = row[inference_source] if inference_source else None
        parameters = {
            "id": int(row["id"]) if row["id"] is not None else None,
            "content": str(row["content"]).strip(),
            "raw_content": (
                json.dumps(raw_contents, ensure_ascii=False) if legacy_raw_content is not None else None
            ),
            "meaning": meaning,
            "session_id_dict": build_session_id_dict(
                optional_value(row, "chat_id"),
                fallback_count=usage_count,
            ),
            "count": usage_count,
            "is_jargon": is_jargon,
            "is_complete": parse_bool(optional_value(row, "is_complete"), default=False),
            "is_global": parse_bool(optional_value(row, "is_global"), default=False),
            "last_inference_count": int(legacy_last_inference) if legacy_last_inference is not None else 0,
            "inference_with_context": str(legacy_context) if legacy_context is not None else None,
            "inference_with_content_only": (
                str(legacy_content_only) if legacy_content_only is not None else None
            ),
        }
        target_session.execute(insert_statement, parameters)
        inserted += 1

    if blank_meanings > 0:
        print(
            f"警告:目标表 jargons.meaning 不允许 NULL已将 {blank_meanings} 条旧记录的 NULL meaning 转为空字符串。"
        )
    return inserted
def confirm_target_replacement(target_db_path: Path, clear_target: bool) -> bool:
    """Ask the user to confirm writing to the target database.

    Args:
        target_db_path: Target database location, shown in the question.
        clear_target: Whether the target tables will be emptied first; this
            selects which wording of the question is used.

    Returns:
        The answer from ``prompt_yes_no`` called with its default.
    """
    if clear_target:
        question = f"将清空目标库中的 expressions/jargons 后再迁移,确认继续吗?\n目标库:{target_db_path}"
    else:
        question = f"将写入目标库,若主键冲突会导致迁移失败,确认继续吗?\n目标库:{target_db_path}"
    return prompt_yes_no(question)
def parse_arguments() -> Namespace:
    """Parse the process's command-line arguments.

    Returns:
        The namespace produced by the parser from ``build_argument_parser``.
    """
    parser = build_argument_parser()
    return parser.parse_args()
def main() -> None:
    """Run the interactive expression/jargon migration.

    Steps:

    1. Read the command-line arguments, then prompt for the source path, the
       target path and (unless ``--clear-target`` was given) whether to
       empty the target tables first.
    2. Validate the paths and ask for a final confirmation.
    3. Load every expression and jargon row from the source database.
    4. Create the target schema, then clear (if requested) and insert in a
       single transaction that is committed once at the end.

    Raises:
        ValueError: The source and target paths are identical, or a source
            table cannot be found under any of its candidate names.
        FileNotFoundError: The source database file does not exist.
    """
    args = parse_arguments()
    print("旧版 expression/jargon -> 新版 expressions/jargons 迁移工具")
    source_db_path = prompt_path("请输入旧版数据库路径", args.source_db)
    target_db_path = prompt_path("请输入新版数据库路径", args.target_db)
    clear_target = args.clear_target or prompt_yes_no("迁移前是否清空目标库中的 expressions 和 jargons 表?", False)

    if source_db_path == target_db_path:
        raise ValueError("旧版数据库路径和新版数据库路径不能相同。")
    ensure_sqlite_file(source_db_path, should_exist=True)
    ensure_sqlite_file(target_db_path, should_exist=False)

    print(f"旧库:{source_db_path}")
    print(f"新库:{target_db_path}")
    # Show the actual choice; both branches used to print an empty string.
    print(f"清空目标表:{'是' if clear_target else '否'}")
    if not confirm_target_replacement(target_db_path, clear_target):
        print("已取消迁移。")
        return

    # Read everything from the source up front so its connection can be
    # closed before the target is touched.
    source_connection = connect_sqlite(source_db_path)
    try:
        expression_table_name = resolve_source_table_name(source_connection, ["expression", "expressions"])
        jargon_table_name = resolve_source_table_name(source_connection, ["jargon", "jargons"])
        expression_columns = get_table_columns(source_connection, expression_table_name)
        jargon_columns = get_table_columns(source_connection, jargon_table_name)
        expression_rows = load_rows(source_connection, expression_table_name)
        jargon_rows = load_rows(source_connection, jargon_table_name)
    finally:
        source_connection.close()

    target_engine = create_target_engine(target_db_path)
    SQLModel.metadata.create_all(target_engine)

    # Inspect the created target schema to learn which jargon columns accept NULL.
    target_sqlite_connection = connect_sqlite(target_db_path)
    try:
        jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons")
    finally:
        target_sqlite_connection.close()

    with Session(target_engine) as target_session:
        # Clearing and inserting share one transaction: if any insert fails,
        # nothing is committed and the existing target rows are not lost.
        if clear_target:
            clear_target_tables(target_session)
        expression_count = migrate_expressions(expression_rows, target_session, expression_columns)
        jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map)
        target_session.commit()

    print("迁移完成。")
    print(f"已迁移 expression 记录:{expression_count}")
    print(f"已迁移 jargon 记录:{jargon_count}")
# Run the migration only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()