feat:修复一些bug

This commit is contained in:
SengokuCola
2026-03-29 18:28:56 +08:00
parent 82bbf0fd52
commit 96844a9bf5
8 changed files with 898 additions and 444 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Optional
import json
import sqlite3
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, delete
ROOT_PATH = Path(__file__).resolve().parent.parent
@@ -92,11 +93,38 @@ def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[st
return {str(row["name"]) for row in rows}
def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]:
"""获取表字段是否允许 NULL 的映射。"""
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
return {str(row["name"]): not bool(row["notnull"]) for row in rows}
def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]:
"""读取整张表的数据。"""
return connection.execute(f"SELECT * FROM {table_name}").fetchall()
def normalize_optional_text(raw_value: Any) -> Optional[str]:
"""标准化可空文本字段。"""
if raw_value is None:
return None
return str(raw_value)
def ensure_nullable_compatibility(
table_name: str,
column_name: str,
row_id: Any,
value: Any,
nullable_map: dict[str, bool],
) -> None:
"""检查待迁移值是否与目标表可空约束兼容。"""
if value is None and not nullable_map.get(column_name, True):
raise ValueError(
f"目标表 {table_name}.{column_name} 不允许 NULL但源记录 id={row_id} 的该字段为 NULL。"
)
def normalize_string_list(raw_value: Any) -> list[str]:
"""将旧库中的 JSON/文本字段标准化为字符串列表。"""
if raw_value is None:
@@ -126,14 +154,55 @@ def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]:
"""标准化审核来源字段。"""
if raw_value is None:
return None
value = str(raw_value).strip().lower()
if value == ModifiedBy.AI.value:
normalized_raw_value = raw_value
if isinstance(raw_value, str):
raw_text = raw_value.strip()
if raw_text.startswith('"') and raw_text.endswith('"'):
try:
normalized_raw_value = json.loads(raw_text)
except json.JSONDecodeError:
normalized_raw_value = raw_text
else:
normalized_raw_value = raw_text
value = str(normalized_raw_value).strip().lower()
if value in {"", "none", "null"}:
return None
if value in {ModifiedBy.AI.value, ModifiedBy.AI.name.lower()}:
return ModifiedBy.AI
if value == ModifiedBy.USER.value:
if value in {ModifiedBy.USER.value, ModifiedBy.USER.name.lower()}:
return ModifiedBy.USER
return None
def parse_optional_bool(raw_value: Any) -> Optional[bool]:
"""解析可空布尔值,兼容整数和字符串。"""
if raw_value is None:
return None
if isinstance(raw_value, bool):
return raw_value
if isinstance(raw_value, int):
return bool(raw_value)
if isinstance(raw_value, float):
return bool(int(raw_value))
value = str(raw_value).strip().lower()
if value in {"", "none", "null"}:
return None
if value in {"1", "true", "t", "yes", "y"}:
return True
if value in {"0", "false", "f", "no", "n"}:
return False
raise ValueError(f"无法解析布尔值:{raw_value}")
def parse_bool(raw_value: Any, default: bool = False) -> bool:
"""解析非空布尔值。"""
parsed = parse_optional_bool(raw_value)
return default if parsed is None else parsed
def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]:
"""将旧库中的 Unix 时间戳转换为 datetime。"""
if raw_value is None or raw_value == "":
@@ -212,6 +281,10 @@ def migrate_expressions(
) -> int:
"""迁移 expression 数据。"""
migrated_count = 0
modified_by_ai_count = 0
modified_by_user_count = 0
modified_by_null_count = 0
unknown_modified_by_values: dict[str, int] = {}
for row in old_rows:
create_time = timestamp_to_datetime(row["create_date"] if "create_date" in expression_columns else None, True)
last_active_time = timestamp_to_datetime(
@@ -219,22 +292,72 @@ def migrate_expressions(
True,
)
content_list = normalize_string_list(row["content_list"] if "content_list" in expression_columns else None)
raw_modified_by = row["modified_by"] if "modified_by" in expression_columns else None
modified_by = normalize_modified_by(raw_modified_by)
if modified_by == ModifiedBy.AI:
modified_by_ai_count += 1
elif modified_by == ModifiedBy.USER:
modified_by_user_count += 1
else:
modified_by_null_count += 1
if raw_modified_by not in (None, "", "null", "NULL", "None"):
unknown_key = str(raw_modified_by)
unknown_modified_by_values[unknown_key] = unknown_modified_by_values.get(unknown_key, 0) + 1
expression = Expression(
id=int(row["id"]) if row["id"] is not None else None,
situation=str(row["situation"]).strip(),
style=str(row["style"]).strip(),
content_list=json.dumps(content_list, ensure_ascii=False),
count=int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
last_active_time=last_active_time or datetime.now(),
create_time=create_time or datetime.now(),
session_id=str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
checked=bool(row["checked"]) if "checked" in expression_columns and row["checked"] is not None else False,
rejected=bool(row["rejected"]) if "rejected" in expression_columns and row["rejected"] is not None else False,
modified_by=normalize_modified_by(row["modified_by"] if "modified_by" in expression_columns else None),
target_session.execute(
text(
"""
INSERT INTO expressions (
id,
situation,
style,
content_list,
count,
last_active_time,
create_time,
session_id,
checked,
rejected,
modified_by
) VALUES (
:id,
:situation,
:style,
:content_list,
:count,
:last_active_time,
:create_time,
:session_id,
:checked,
:rejected,
:modified_by
)
"""
),
{
"id": int(row["id"]) if row["id"] is not None else None,
"situation": str(row["situation"]).strip(),
"style": str(row["style"]).strip(),
"content_list": json.dumps(content_list, ensure_ascii=False),
"count": int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
"last_active_time": last_active_time or datetime.now(),
"create_time": create_time or datetime.now(),
"session_id": str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
"checked": parse_bool(row["checked"] if "checked" in expression_columns else None, default=False),
"rejected": parse_bool(row["rejected"] if "rejected" in expression_columns else None, default=False),
"modified_by": modified_by.name if modified_by is not None else None,
},
)
target_session.add(expression)
migrated_count += 1
print(
"Expression modified_by 迁移统计:"
f" AI={modified_by_ai_count}, USER={modified_by_user_count}, NULL={modified_by_null_count}"
)
if unknown_modified_by_values:
preview_items = list(unknown_modified_by_values.items())[:10]
preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items)
print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}")
return migrated_count
@@ -242,12 +365,17 @@ def migrate_jargons(
old_rows: Iterable[sqlite3.Row],
target_session: Session,
jargon_columns: set[str],
jargon_nullable_map: dict[str, bool],
) -> int:
"""迁移 jargon 数据。"""
migrated_count = 0
coerced_meaning_null_count = 0
for row in old_rows:
count = int(row["count"]) if "count" in jargon_columns and row["count"] is not None else 0
raw_content_list = normalize_string_list(row["raw_content"] if "raw_content" in jargon_columns else None)
raw_content_value = row["raw_content"] if "raw_content" in jargon_columns else None
raw_content_list = normalize_string_list(raw_content_value)
meaning_value = normalize_optional_text(row["meaning"] if "meaning" in jargon_columns else None)
is_jargon_value = parse_optional_bool(row["is_jargon"] if "is_jargon" in jargon_columns else None)
inference_content_key = (
"inference_content_only"
if "inference_content_only" in jargon_columns
@@ -256,35 +384,81 @@ def migrate_jargons(
else None
)
jargon = Jargon(
id=int(row["id"]) if row["id"] is not None else None,
content=str(row["content"]).strip(),
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
meaning=str(row["meaning"]).strip() if "meaning" in jargon_columns and row["meaning"] is not None else "",
session_id_dict=build_session_id_dict(
row["chat_id"] if "chat_id" in jargon_columns else None,
fallback_count=count,
),
count=count,
is_jargon=bool(row["is_jargon"]) if "is_jargon" in jargon_columns and row["is_jargon"] is not None else None,
is_complete=bool(row["is_complete"]) if "is_complete" in jargon_columns and row["is_complete"] is not None else False,
is_global=bool(row["is_global"]) if "is_global" in jargon_columns and row["is_global"] is not None else False,
last_inference_count=(
int(row["last_inference_count"])
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
else 0
),
inference_with_context=(
str(row["inference_with_context"])
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
else None
),
inference_with_content_only=(
str(row[inference_content_key]) if inference_content_key and row[inference_content_key] is not None else None
ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon_value, jargon_nullable_map)
if meaning_value is None and not jargon_nullable_map.get("meaning", True):
meaning_value = ""
coerced_meaning_null_count += 1
# 显式执行 SQL避免 ORM 在 None 场景下回填模型默认值。
target_session.execute(
text(
"""
INSERT INTO jargons (
id,
content,
raw_content,
meaning,
session_id_dict,
count,
is_jargon,
is_complete,
is_global,
last_inference_count,
inference_with_context,
inference_with_content_only
) VALUES (
:id,
:content,
:raw_content,
:meaning,
:session_id_dict,
:count,
:is_jargon,
:is_complete,
:is_global,
:last_inference_count,
:inference_with_context,
:inference_with_content_only
)
"""
),
{
"id": int(row["id"]) if row["id"] is not None else None,
"content": str(row["content"]).strip(),
"raw_content": json.dumps(raw_content_list, ensure_ascii=False) if raw_content_value is not None else None,
"meaning": meaning_value,
"session_id_dict": build_session_id_dict(
row["chat_id"] if "chat_id" in jargon_columns else None,
fallback_count=count,
),
"count": count,
"is_jargon": is_jargon_value,
"is_complete": parse_bool(row["is_complete"] if "is_complete" in jargon_columns else None, default=False),
"is_global": parse_bool(row["is_global"] if "is_global" in jargon_columns else None, default=False),
"last_inference_count": (
int(row["last_inference_count"])
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
else 0
),
"inference_with_context": (
str(row["inference_with_context"])
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
else None
),
"inference_with_content_only": (
str(row[inference_content_key])
if inference_content_key and row[inference_content_key] is not None
else None
),
},
)
target_session.add(jargon)
migrated_count += 1
if coerced_meaning_null_count > 0:
print(
f"警告:目标表 jargons.meaning 不允许 NULL已将 {coerced_meaning_null_count} 条旧记录的 NULL meaning 转为空字符串。"
)
return migrated_count
@@ -337,13 +511,19 @@ def main() -> None:
target_engine = create_target_engine(target_db_path)
SQLModel.metadata.create_all(target_engine)
target_sqlite_connection = connect_sqlite(target_db_path)
try:
jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons")
finally:
target_sqlite_connection.close()
with Session(target_engine) as target_session:
if clear_target:
clear_target_tables(target_session)
target_session.commit()
expression_count = migrate_expressions(expression_rows, target_session, expression_columns)
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns)
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map)
target_session.commit()
print("迁移完成。")