chore: import deployable mai-bot source tree

This commit is contained in:
2026-05-11 00:51:12 +00:00
parent 4813699b3e
commit 7a54015f94
1009 changed files with 312999 additions and 16 deletions

View File

@@ -0,0 +1,127 @@
from pathlib import Path
import ast
import subprocess
import sys
base_file_path = Path(__file__).parent.parent.absolute().resolve() / "src" / "common" / "database" / "database_model.py"
target_file_path = (
Path(__file__).parent.parent.absolute().resolve() / "src" / "common" / "database" / "database_datamodel.py"
)
with open(base_file_path, "r", encoding="utf-8") as f:
source_text = f.read()
source_lines = source_text.splitlines()
try:
tree = ast.parse(source_text)
except SyntaxError as e:
raise e
code_lines = [
"from typing import Optional",
"from pydantic import BaseModel",
"from datetime import datetime",
"from .database_model import ModelUser, ImageType",
]
def src(node):
seg = ast.get_source_segment(source_text, node)
return seg if seg is not None else ast.unparse(node)
for node in tree.body:
if not isinstance(node, ast.ClassDef):
continue
# 判断是否 SQLModel 且 table=True
has_sqlmodel = any(
(isinstance(b, ast.Name) and b.id == "SQLModel") or (isinstance(b, ast.Attribute) and b.attr == "SQLModel")
for b in node.bases
)
has_table_kw = any(
(kw.arg == "table" and isinstance(kw.value, ast.Constant) and kw.value.value is True) for kw in node.keywords
)
if not (has_sqlmodel and has_table_kw):
continue
class_name = node.name
code_lines.append("")
code_lines.append(f"class {class_name}(BaseModel):")
fields_added = 0
for item in node.body:
# 跳过 __tablename__ 等
if isinstance(item, ast.Assign):
if len(item.targets) != 1 or not isinstance(item.targets[0], ast.Name):
continue
name = item.targets[0].id
if name == "__tablename__":
continue
value_src = src(item.value)
line = f" {name} = {value_src}"
fields_added += 1
lineno = getattr(item, "lineno", None)
elif isinstance(item, ast.AnnAssign):
# 注解赋值
if not isinstance(item.target, ast.Name):
continue
name = item.target.id
ann = src(item.annotation) if item.annotation is not None else None
if item.value is None:
line = f" {name}: {ann}" if ann else f" {name}"
elif isinstance(item.value, ast.Call) and (
(isinstance(item.value.func, ast.Name) and item.value.func.id == "Field")
or (isinstance(item.value.func, ast.Attribute) and item.value.func.attr == "Field")
):
default_kw = next((kw for kw in item.value.keywords if kw.arg == "default"), None)
if default_kw is None:
# 没有 default保留类型但不赋值
line = f" {name}: {ann}" if ann else f" {name}"
else:
default_src = src(default_kw.value)
line = f" {name}: {ann} = {default_src}"
else:
value_src = src(item.value)
line = f" {name}: {ann} = {value_src}" if ann else f" {name} = {value_src}"
fields_added += 1
lineno = getattr(item, "lineno", None)
else:
continue
# 提取同一行的行内注释作为字段说明(如果存在)
comment = None
if lineno is not None:
src_line = source_lines[lineno - 1]
if "#" in src_line:
# 取第一个 #
comment = src_line.split("#", 1)[1].strip()
# 避免三引号冲突
comment = comment.replace('"""', '\\"""')
code_lines.append(line)
if comment:
code_lines.append(f' """{comment}"""')
else:
print(f"Warning: No comment found for field '{name}' in class '{class_name}'.")
if fields_added == 0:
code_lines.append(" pass")
with open(target_file_path, "w", encoding="utf-8") as f:
f.write("\n".join(code_lines) + "\n")
try:
result = subprocess.run(["ruff", "format", str(target_file_path)], capture_output=True, text=True)
except FileNotFoundError:
print("ruff 未找到,请安装 ruff 并确保其在 PATH 中例如pip install ruff", file=sys.stderr)
sys.exit(127)
# 输出 ruff 的 stdout/stderr
if result.stdout:
print(result.stdout, end="")
if result.stderr:
print(result.stderr, file=sys.stderr, end="")
if result.returncode != 0:
print(f"ruff 检查失败,退出码:{result.returncode}", file=sys.stderr)
sys.exit(result.returncode)

View File

@@ -0,0 +1,535 @@
from argparse import ArgumentParser, Namespace
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
from sys import path as sys_path
from typing import Any, Optional
import json
import sqlite3
from sqlalchemy import text
from sqlmodel import Session, SQLModel, create_engine, delete
ROOT_PATH = Path(__file__).resolve().parent.parent
if str(ROOT_PATH) not in sys_path:
sys_path.insert(0, str(ROOT_PATH))
from src.common.database.database_model import Expression, Jargon, ModifiedBy # noqa: E402
def build_argument_parser() -> ArgumentParser:
"""构建命令行参数解析器。"""
parser = ArgumentParser(
description="将旧版 expression/jargon 数据迁移到新版 expressions/jargons 数据库。"
)
parser.add_argument("--source-db", dest="source_db", help="旧版 SQLite 数据库路径")
parser.add_argument("--target-db", dest="target_db", help="新版 SQLite 数据库路径")
parser.add_argument(
"--clear-target",
dest="clear_target",
action="store_true",
help="迁移前清空目标库中的 expressions 和 jargons 表",
)
return parser
def prompt_path(prompt_text: str, current_value: Optional[str] = None) -> Path:
"""读取数据库路径输入。"""
while True:
suffix = f" [{current_value}]" if current_value else ""
raw_text = input(f"{prompt_text}{suffix}: ").strip()
value = raw_text or current_value or ""
if not value:
print("路径不能为空,请重新输入。")
continue
return Path(value).expanduser().resolve()
def prompt_yes_no(prompt_text: str, default: bool = False) -> bool:
"""读取是否确认输入。"""
default_hint = "Y/n" if default else "y/N"
raw_text = input(f"{prompt_text} [{default_hint}]: ").strip().lower()
if not raw_text:
return default
return raw_text in {"y", "yes"}
def ensure_sqlite_file(path: Path, should_exist: bool) -> None:
"""校验 SQLite 文件路径。"""
if should_exist and not path.is_file():
raise FileNotFoundError(f"数据库文件不存在:{path}")
if not should_exist:
path.parent.mkdir(parents=True, exist_ok=True)
def connect_sqlite(path: Path) -> sqlite3.Connection:
"""创建 SQLite 连接。"""
connection = sqlite3.connect(path)
connection.row_factory = sqlite3.Row
return connection
def table_exists(connection: sqlite3.Connection, table_name: str) -> bool:
"""检查表是否存在。"""
result = connection.execute(
"SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ? LIMIT 1",
(table_name,),
).fetchone()
return result is not None
def resolve_source_table_name(connection: sqlite3.Connection, candidates: list[str]) -> str:
"""从候选表名中解析实际存在的表名。"""
for table_name in candidates:
if table_exists(connection, table_name):
return table_name
raise ValueError(f"未找到候选表:{', '.join(candidates)}")
def get_table_columns(connection: sqlite3.Connection, table_name: str) -> set[str]:
"""获取表字段名集合。"""
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
return {str(row["name"]) for row in rows}
def get_table_nullable_map(connection: sqlite3.Connection, table_name: str) -> dict[str, bool]:
"""获取表字段是否允许 NULL 的映射。"""
rows = connection.execute(f"PRAGMA table_info('{table_name}')").fetchall()
return {str(row["name"]): not bool(row["notnull"]) for row in rows}
def load_rows(connection: sqlite3.Connection, table_name: str) -> list[sqlite3.Row]:
"""读取整张表的数据。"""
return connection.execute(f"SELECT * FROM {table_name}").fetchall()
def normalize_optional_text(raw_value: Any) -> Optional[str]:
"""标准化可空文本字段。"""
if raw_value is None:
return None
return str(raw_value)
def ensure_nullable_compatibility(
table_name: str,
column_name: str,
row_id: Any,
value: Any,
nullable_map: dict[str, bool],
) -> None:
"""检查待迁移值是否与目标表可空约束兼容。"""
if value is None and not nullable_map.get(column_name, True):
raise ValueError(
f"目标表 {table_name}.{column_name} 不允许 NULL但源记录 id={row_id} 的该字段为 NULL。"
)
def normalize_string_list(raw_value: Any) -> list[str]:
"""将旧库中的 JSON/文本字段标准化为字符串列表。"""
if raw_value is None:
return []
if isinstance(raw_value, list):
return [str(item).strip() for item in raw_value if str(item).strip()]
if isinstance(raw_value, str):
raw_text = raw_value.strip()
if not raw_text:
return []
try:
parsed = json.loads(raw_text)
except json.JSONDecodeError:
return [raw_text]
if isinstance(parsed, list):
return [str(item).strip() for item in parsed if str(item).strip()]
if isinstance(parsed, str):
parsed_text = parsed.strip()
return [parsed_text] if parsed_text else []
if parsed is None:
return []
return [str(parsed).strip()]
return [str(raw_value).strip()]
def normalize_modified_by(raw_value: Any) -> Optional[ModifiedBy]:
"""标准化审核来源字段。"""
if raw_value is None:
return None
normalized_raw_value = raw_value
if isinstance(raw_value, str):
raw_text = raw_value.strip()
if raw_text.startswith('"') and raw_text.endswith('"'):
try:
normalized_raw_value = json.loads(raw_text)
except json.JSONDecodeError:
normalized_raw_value = raw_text
else:
normalized_raw_value = raw_text
value = str(normalized_raw_value).strip().lower()
if value in {"", "none", "null"}:
return None
if value in {ModifiedBy.AI.value, ModifiedBy.AI.name.lower()}:
return ModifiedBy.AI
if value in {ModifiedBy.USER.value, ModifiedBy.USER.name.lower()}:
return ModifiedBy.USER
return None
def parse_optional_bool(raw_value: Any) -> Optional[bool]:
"""解析可空布尔值,兼容整数和字符串。"""
if raw_value is None:
return None
if isinstance(raw_value, bool):
return raw_value
if isinstance(raw_value, int):
return bool(raw_value)
if isinstance(raw_value, float):
return bool(int(raw_value))
value = str(raw_value).strip().lower()
if value in {"", "none", "null"}:
return None
if value in {"1", "true", "t", "yes", "y"}:
return True
if value in {"0", "false", "f", "no", "n"}:
return False
raise ValueError(f"无法解析布尔值:{raw_value}")
def parse_bool(raw_value: Any, default: bool = False) -> bool:
"""解析非空布尔值。"""
parsed = parse_optional_bool(raw_value)
return default if parsed is None else parsed
def timestamp_to_datetime(raw_value: Any, fallback_now: bool) -> Optional[datetime]:
"""将旧库中的 Unix 时间戳转换为 datetime。"""
if raw_value is None or raw_value == "":
return datetime.now() if fallback_now else None
if isinstance(raw_value, datetime):
return raw_value
try:
return datetime.fromtimestamp(float(raw_value))
except (TypeError, ValueError, OSError, OverflowError):
return datetime.now() if fallback_now else None
def build_session_id_dict(raw_chat_id: Any, fallback_count: int) -> str:
"""将旧版 jargon.chat_id 转换为新版 session_id_dict。"""
if raw_chat_id is None:
return json.dumps({}, ensure_ascii=False)
if isinstance(raw_chat_id, str):
raw_text = raw_chat_id.strip()
else:
raw_text = str(raw_chat_id).strip()
if not raw_text:
return json.dumps({}, ensure_ascii=False)
try:
parsed = json.loads(raw_text)
except json.JSONDecodeError:
return json.dumps({raw_text: max(fallback_count, 1)}, ensure_ascii=False)
if isinstance(parsed, str):
parsed_text = parsed.strip()
session_counts = {parsed_text: max(fallback_count, 1)} if parsed_text else {}
return json.dumps(session_counts, ensure_ascii=False)
if not isinstance(parsed, list):
return json.dumps({}, ensure_ascii=False)
session_counts: dict[str, int] = {}
for item in parsed:
if not isinstance(item, list) or not item:
continue
session_id = str(item[0]).strip()
if not session_id:
continue
item_count = 1
if len(item) > 1:
try:
item_count = int(item[1])
except (TypeError, ValueError):
item_count = 1
session_counts[session_id] = max(item_count, 1)
return json.dumps(session_counts, ensure_ascii=False)
def create_target_engine(target_db_path: Path):
"""创建目标数据库引擎。"""
return create_engine(
f"sqlite:///{target_db_path.as_posix()}",
echo=False,
connect_args={"check_same_thread": False},
)
def clear_target_tables(session: Session) -> None:
"""清空目标表。"""
session.exec(delete(Expression))
session.exec(delete(Jargon))
def migrate_expressions(
old_rows: Iterable[sqlite3.Row],
target_session: Session,
expression_columns: set[str],
) -> int:
"""迁移 expression 数据。"""
migrated_count = 0
modified_by_ai_count = 0
modified_by_user_count = 0
modified_by_null_count = 0
unknown_modified_by_values: dict[str, int] = {}
for row in old_rows:
create_time = timestamp_to_datetime(row["create_date"] if "create_date" in expression_columns else None, True)
last_active_time = timestamp_to_datetime(
row["last_active_time"] if "last_active_time" in expression_columns else None,
True,
)
content_list = normalize_string_list(row["content_list"] if "content_list" in expression_columns else None)
raw_modified_by = row["modified_by"] if "modified_by" in expression_columns else None
modified_by = normalize_modified_by(raw_modified_by)
if modified_by == ModifiedBy.AI:
modified_by_ai_count += 1
elif modified_by == ModifiedBy.USER:
modified_by_user_count += 1
else:
modified_by_null_count += 1
if raw_modified_by not in (None, "", "null", "NULL", "None"):
unknown_key = str(raw_modified_by)
unknown_modified_by_values[unknown_key] = unknown_modified_by_values.get(unknown_key, 0) + 1
target_session.execute(
text(
"""
INSERT INTO expressions (
id,
situation,
style,
content_list,
count,
last_active_time,
create_time,
session_id,
checked,
rejected,
modified_by
) VALUES (
:id,
:situation,
:style,
:content_list,
:count,
:last_active_time,
:create_time,
:session_id,
:checked,
:rejected,
:modified_by
)
"""
),
{
"id": int(row["id"]) if row["id"] is not None else None,
"situation": str(row["situation"]).strip(),
"style": str(row["style"]).strip(),
"content_list": json.dumps(content_list, ensure_ascii=False),
"count": int(row["count"]) if "count" in expression_columns and row["count"] is not None else 1,
"last_active_time": last_active_time or datetime.now(),
"create_time": create_time or datetime.now(),
"session_id": str(row["chat_id"]).strip() if "chat_id" in expression_columns and row["chat_id"] else None,
"checked": parse_bool(row["checked"] if "checked" in expression_columns else None, default=False),
"rejected": parse_bool(row["rejected"] if "rejected" in expression_columns else None, default=False),
"modified_by": modified_by.name if modified_by is not None else None,
},
)
migrated_count += 1
print(
"Expression modified_by 迁移统计:"
f" AI={modified_by_ai_count}, USER={modified_by_user_count}, NULL={modified_by_null_count}"
)
if unknown_modified_by_values:
preview_items = list(unknown_modified_by_values.items())[:10]
preview_text = ", ".join(f"{value!r} x{count}" for value, count in preview_items)
print(f"警告:以下旧 modified_by 值未识别,已按 NULL 迁移:{preview_text}")
return migrated_count
def migrate_jargons(
old_rows: Iterable[sqlite3.Row],
target_session: Session,
jargon_columns: set[str],
jargon_nullable_map: dict[str, bool],
) -> int:
"""迁移 jargon 数据。"""
migrated_count = 0
coerced_meaning_null_count = 0
for row in old_rows:
count = int(row["count"]) if "count" in jargon_columns and row["count"] is not None else 0
raw_content_value = row["raw_content"] if "raw_content" in jargon_columns else None
raw_content_list = normalize_string_list(raw_content_value)
meaning_value = normalize_optional_text(row["meaning"] if "meaning" in jargon_columns else None)
is_jargon_value = parse_optional_bool(row["is_jargon"] if "is_jargon" in jargon_columns else None)
inference_content_key = (
"inference_content_only"
if "inference_content_only" in jargon_columns
else "inference_with_content_only"
if "inference_with_content_only" in jargon_columns
else None
)
ensure_nullable_compatibility("jargons", "is_jargon", row["id"], is_jargon_value, jargon_nullable_map)
if meaning_value is None and not jargon_nullable_map.get("meaning", True):
meaning_value = ""
coerced_meaning_null_count += 1
# 显式执行 SQL避免 ORM 在 None 场景下回填模型默认值。
target_session.execute(
text(
"""
INSERT INTO jargons (
id,
content,
raw_content,
meaning,
session_id_dict,
count,
is_jargon,
is_complete,
is_global,
last_inference_count,
inference_with_context,
inference_with_content_only
) VALUES (
:id,
:content,
:raw_content,
:meaning,
:session_id_dict,
:count,
:is_jargon,
:is_complete,
:is_global,
:last_inference_count,
:inference_with_context,
:inference_with_content_only
)
"""
),
{
"id": int(row["id"]) if row["id"] is not None else None,
"content": str(row["content"]).strip(),
"raw_content": json.dumps(raw_content_list, ensure_ascii=False) if raw_content_value is not None else None,
"meaning": meaning_value,
"session_id_dict": build_session_id_dict(
row["chat_id"] if "chat_id" in jargon_columns else None,
fallback_count=count,
),
"count": count,
"is_jargon": is_jargon_value,
"is_complete": parse_bool(row["is_complete"] if "is_complete" in jargon_columns else None, default=False),
"is_global": parse_bool(row["is_global"] if "is_global" in jargon_columns else None, default=False),
"last_inference_count": (
int(row["last_inference_count"])
if "last_inference_count" in jargon_columns and row["last_inference_count"] is not None
else 0
),
"inference_with_context": (
str(row["inference_with_context"])
if "inference_with_context" in jargon_columns and row["inference_with_context"] is not None
else None
),
"inference_with_content_only": (
str(row[inference_content_key])
if inference_content_key and row[inference_content_key] is not None
else None
),
},
)
migrated_count += 1
if coerced_meaning_null_count > 0:
print(
f"警告:目标表 jargons.meaning 不允许 NULL已将 {coerced_meaning_null_count} 条旧记录的 NULL meaning 转为空字符串。"
)
return migrated_count
def confirm_target_replacement(target_db_path: Path, clear_target: bool) -> bool:
"""确认是否写入目标数据库。"""
if clear_target:
return prompt_yes_no(f"将清空目标库中的 expressions/jargons 后再迁移,确认继续吗?\n目标库:{target_db_path}")
return prompt_yes_no(f"将写入目标库,若主键冲突会导致迁移失败,确认继续吗?\n目标库:{target_db_path}")
def parse_arguments() -> Namespace:
"""解析参数。"""
return build_argument_parser().parse_args()
def main() -> None:
"""脚本入口。"""
args = parse_arguments()
print("旧版 expression/jargon -> 新版 expressions/jargons 迁移工具")
source_db_path = prompt_path("请输入旧版数据库路径", args.source_db)
target_db_path = prompt_path("请输入新版数据库路径", args.target_db)
clear_target = args.clear_target or prompt_yes_no("迁移前是否清空目标库中的 expressions 和 jargons 表?", False)
if source_db_path == target_db_path:
raise ValueError("旧版数据库路径和新版数据库路径不能相同。")
ensure_sqlite_file(source_db_path, should_exist=True)
ensure_sqlite_file(target_db_path, should_exist=False)
print(f"旧库:{source_db_path}")
print(f"新库:{target_db_path}")
print(f"清空目标表:{'' if clear_target else ''}")
if not confirm_target_replacement(target_db_path, clear_target):
print("已取消迁移。")
return
source_connection = connect_sqlite(source_db_path)
try:
expression_table_name = resolve_source_table_name(source_connection, ["expression", "expressions"])
jargon_table_name = resolve_source_table_name(source_connection, ["jargon", "jargons"])
expression_columns = get_table_columns(source_connection, expression_table_name)
jargon_columns = get_table_columns(source_connection, jargon_table_name)
expression_rows = load_rows(source_connection, expression_table_name)
jargon_rows = load_rows(source_connection, jargon_table_name)
finally:
source_connection.close()
target_engine = create_target_engine(target_db_path)
SQLModel.metadata.create_all(target_engine)
target_sqlite_connection = connect_sqlite(target_db_path)
try:
jargon_nullable_map = get_table_nullable_map(target_sqlite_connection, "jargons")
finally:
target_sqlite_connection.close()
with Session(target_engine) as target_session:
if clear_target:
clear_target_tables(target_session)
target_session.commit()
expression_count = migrate_expressions(expression_rows, target_session, expression_columns)
jargon_count = migrate_jargons(jargon_rows, target_session, jargon_columns, jargon_nullable_map)
target_session.commit()
print("迁移完成。")
print(f"已迁移 expression 记录:{expression_count}")
print(f"已迁移 jargon 记录:{jargon_count}")
if __name__ == "__main__":
main()