chore: import deployable mai-bot source tree
This commit is contained in:
336
scripts/analyze_reply_effect_score_correlation.py
Normal file
336
scripts/analyze_reply_effect_score_correlation.py
Normal file
@@ -0,0 +1,336 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from scipy import stats
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import math
|
||||
|
||||
|
||||
# Default locations for the automatic-score logs and the manual-annotation logs.
DEFAULT_LOG_DIR = Path("logs") / "maisaka_reply_effect"
DEFAULT_MANUAL_DIR = Path("logs") / "maisaka_reply_effect_manual"


# Metrics correlated against the manual score.
# Each entry: (group label for display, dotted path into a record's "scores" dict, display label).
METRIC_SPECS = [
    ("总分", "asi", "ASI 自动总分"),
    ("大项", "behavior_score", "行为满意度 B"),
    ("大项", "relational_score", "感知质量 R"),
    ("大项", "friction_score", "摩擦风险 F"),
    ("大项", "friction_quality_score", "低摩擦质量分"),
    ("行为子项", "behavior_signals.continue_2turns", "继续两轮"),
    ("行为子项", "behavior_signals.next_user_sentiment", "后续情绪"),
    ("行为子项", "behavior_signals.user_expansion", "用户展开"),
    ("行为子项", "behavior_signals.no_correction", "没有纠正"),
    ("行为子项", "behavior_signals.no_abort", "没有放弃"),
    ("rubric 子项", "rubric_scores.social_presence.normalized_score", "社交临场感"),
    ("rubric 子项", "rubric_scores.warmth.normalized_score", "温暖感"),
    ("rubric 子项", "rubric_scores.competence.normalized_score", "能力/有用性"),
    ("rubric 子项", "rubric_scores.appropriateness.normalized_score", "合适程度"),
    ("rubric 子项", "rubric_scores.uncanny_risk.normalized_score", "违和风险 judge"),
    ("摩擦子项", "friction_signals.explicit_negative", "明确负反馈"),
    ("摩擦子项", "friction_signals.repair_loop", "修复循环"),
    ("摩擦子项", "friction_signals.uncanny_risk", "违和风险"),
]
|
||||
|
||||
|
||||
def normalize_name(value: str) -> str:
    """Sanitize *value* into a filesystem-friendly token.

    Alphanumerics plus ``.``, ``_`` and ``-`` are kept; every other
    character becomes ``_``. Leading/trailing dots and underscores are
    stripped; an empty result falls back to ``"unknown"``.
    """
    text = str(value or "").strip()
    safe_chars = []
    for char in text:
        if char.isalnum() or char in "._-":
            safe_chars.append(char)
        else:
            safe_chars.append("_")
    cleaned = "".join(safe_chars).strip("._")
    return cleaned or "unknown"
|
||||
|
||||
|
||||
def load_json_file(file_path: Path) -> dict[str, Any]:
    """Read *file_path* as JSON.

    Returns {} when the file is missing/unreadable, is not valid JSON,
    or does not contain a JSON object at the top level.
    """
    try:
        raw = file_path.read_text(encoding="utf-8")
        payload = json.loads(raw)
    except (OSError, json.JSONDecodeError):
        return {}
    if isinstance(payload, dict):
        return payload
    return {}
|
||||
|
||||
|
||||
def to_float(value: Any) -> float | None:
|
||||
if value in {None, ""}:
|
||||
return None
|
||||
try:
|
||||
number = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if math.isnan(number) or math.isinf(number):
|
||||
return None
|
||||
return number
|
||||
|
||||
|
||||
def get_nested(payload: dict[str, Any], dotted_path: str) -> Any:
    """Walk *payload* along a dotted key path ("a.b.c").

    Returns None as soon as an intermediate value is missing or not a dict.
    """
    node: Any = payload
    for key in dotted_path.split("."):
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return None
    return node
|
||||
|
||||
|
||||
def annotation_path(manual_dir: Path, chat_id: str, effect_id: str) -> Path:
    """Locate the manual-annotation JSON for (chat_id, effect_id) under *manual_dir*."""
    safe_chat = normalize_name(chat_id)
    safe_effect = normalize_name(effect_id)
    return manual_dir / safe_chat / (safe_effect + ".json")
|
||||
|
||||
|
||||
def iter_records(
    log_dir: Path,
    manual_dir: Path,
    *,
    chat_id: str,
    include_pending: bool,
) -> list[dict[str, Any]]:
    """Join automatic effect records with their manual annotations.

    Walks ``log_dir/<chat>/<effect>.json`` automatic records, looks up the
    matching manual annotation under *manual_dir*, and returns one dict per
    record that has a usable manual score. Records without a manual score
    are skipped; non-finalized records are skipped unless *include_pending*.

    Args:
        log_dir: Root of automatic score records, one subdirectory per chat.
        manual_dir: Root of manual annotations (same per-chat layout).
        chat_id: When non-empty, restrict the walk to this chat's subdirectory.
        include_pending: Also accept records whose status is not "finalized".

    Returns:
        Dicts with chat_id, effect_id, manual_score (0-100 scale), scores, etc.
    """
    records: list[dict[str, Any]] = []
    if not log_dir.exists():
        return records

    # Either the single requested chat directory, or every subdirectory of log_dir.
    chat_dirs = [log_dir / normalize_name(chat_id)] if chat_id else [path for path in log_dir.iterdir() if path.is_dir()]
    for chat_dir in sorted(chat_dirs):
        if not chat_dir.exists() or not chat_dir.is_dir():
            continue
        for record_file in sorted(chat_dir.glob("*.json")):
            effect_record = load_json_file(record_file)
            if not effect_record:
                continue
            if not include_pending and effect_record.get("status") != "finalized":
                continue

            effect_id = str(effect_record.get("effect_id") or record_file.stem)
            manual_record = load_json_file(annotation_path(manual_dir, chat_dir.name, effect_id))
            manual_score = to_float(manual_record.get("manual_score"))
            if manual_score is None:
                # Fall back to the 1-5 scale and rescale it linearly onto 0-100.
                manual_score_5 = to_float(manual_record.get("manual_score_5"))
                if manual_score_5 is not None:
                    manual_score = (manual_score_5 - 1) / 4 * 100
            if manual_score is None:
                continue

            raw_scores = effect_record.get("scores") if isinstance(effect_record.get("scores"), dict) else {}
            scores = dict(raw_scores)
            # Derive a "low-friction quality" metric so that, like the others,
            # a higher value is better (friction_score itself is a risk score).
            friction_score = to_float(scores.get("friction_score"))
            if friction_score is not None:
                scores["friction_quality_score"] = 1 - friction_score
            records.append(
                {
                    "chat_id": chat_dir.name,
                    "effect_id": effect_id,
                    "manual_score": manual_score,
                    "manual_score_5": manual_record.get("manual_score_5"),
                    "scores": scores,
                    "status": effect_record.get("status"),
                    "created_at": effect_record.get("created_at"),
                    "record_file": str(record_file),
                }
            )
    return records
|
||||
|
||||
|
||||
def calculate_metric_stats(records: list[dict[str, Any]], metric_path: str, min_n: int) -> dict[str, Any]:
    """Correlate one automatic metric against the manual scores.

    Collects (auto, manual) pairs for the metric at *metric_path* (dotted
    path into each record's "scores"), then computes Pearson, Spearman and
    Kendall statistics with p-values. When fewer than *min_n* pairs exist,
    or either side is constant, the coefficients stay None and "note"
    explains why.
    """
    xs: list[float] = []
    ys: list[float] = []
    for record in records:
        auto_value = to_float(get_nested(record["scores"], metric_path))
        manual_value = to_float(record["manual_score"])
        if auto_value is None or manual_value is None:
            continue
        xs.append(auto_value)
        ys.append(manual_value)

    result: dict[str, Any] = {
        "n": len(xs),
        "pearson_r": None,
        "pearson_p": None,
        "spearman_r": None,
        "spearman_p": None,
        "kendall_tau": None,
        "kendall_p": None,
        "note": "",
    }

    # Correlations are undefined for tiny or constant samples.
    if len(xs) < min_n:
        result["note"] = f"样本数少于 {min_n}"
        return result
    if len(set(xs)) < 2:
        result["note"] = "自动评分没有变化,无法计算相关"
        return result
    if len(set(ys)) < 2:
        result["note"] = "人工评分没有变化,无法计算相关"
        return result

    pearson = stats.pearsonr(xs, ys)
    spearman = stats.spearmanr(xs, ys)
    kendall = stats.kendalltau(xs, ys)
    result["pearson_r"] = round_float(pearson.statistic)
    result["pearson_p"] = round_float(pearson.pvalue)
    result["spearman_r"] = round_float(spearman.statistic)
    result["spearman_p"] = round_float(spearman.pvalue)
    result["kendall_tau"] = round_float(kendall.statistic)
    result["kendall_p"] = round_float(kendall.pvalue)
    return result
|
||||
|
||||
|
||||
def round_float(value: Any) -> float | None:
    """Round a coercible value to six decimal places; None when not numeric."""
    number = to_float(value)
    return None if number is None else round(number, 6)
|
||||
|
||||
|
||||
def significance_label(p_value: float | None) -> str:
|
||||
if p_value is None:
|
||||
return ""
|
||||
if p_value < 0.001:
|
||||
return "***"
|
||||
if p_value < 0.01:
|
||||
return "**"
|
||||
if p_value < 0.05:
|
||||
return "*"
|
||||
if p_value < 0.1:
|
||||
return "."
|
||||
return "ns"
|
||||
|
||||
|
||||
def build_report(records: list[dict[str, Any]], min_n: int) -> list[dict[str, Any]]:
    """Run calculate_metric_stats for every METRIC_SPECS entry.

    Each returned row carries the group/metric/label identifiers, the raw
    statistics, and the significance markers for all three coefficients.
    """
    rows: list[dict[str, Any]] = []
    for group, metric_path, label in METRIC_SPECS:
        metric_stats = calculate_metric_stats(records, metric_path, min_n)
        row: dict[str, Any] = {"group": group, "metric": metric_path, "label": label}
        row.update(metric_stats)
        row["pearson_sig"] = significance_label(metric_stats["pearson_p"])
        row["spearman_sig"] = significance_label(metric_stats["spearman_p"])
        row["kendall_sig"] = significance_label(metric_stats["kendall_p"])
        rows.append(row)
    return rows
|
||||
|
||||
|
||||
def print_report(records: list[dict[str, Any]], report: list[dict[str, Any]]) -> None:
    """Print the correlation report as an aligned console table.

    Shows a summary header (record and chat counts), one row per
    METRIC_SPECS entry, and a closing summary line for the overall "asi"
    metric when present.
    """
    chats = sorted({record["chat_id"] for record in records})
    print("\nMaisaka 回复效果评分相关性分析")
    print("=" * 96)
    print(f"已匹配人工评分记录数: {len(records)}")
    print(f"聊天流数量: {len(chats)}")
    if chats:
        # Show at most 8 chat ids to keep the header readable.
        print(f"聊天流: {', '.join(chats[:8])}{' ...' if len(chats) > 8 else ''}")
    print("人工分使用 manual_score,若只有 manual_score_5,则换算到 0-100 后参与计算。")
    print("显著性: *** p<0.001, ** p<0.01, * p<0.05, . p<0.1, ns 不显著")
    print("-" * 96)

    header = (
        f"{'分组':<14} {'指标':<34} {'n':>4} "
        f"{'Pearson r':>10} {'p':>10} {'sig':>4} "
        f"{'Spearman r':>11} {'p':>10} {'sig':>4} "
        f"{'Kendall':>9} {'p':>10} {'说明'}"
    )
    print(header)
    print("-" * 96)
    for item in report:
        print(
            f"{item['group']:<14} "
            f"{item['label']:<34} "
            f"{item['n']:>4} "
            f"{format_number(item['pearson_r']):>10} "
            f"{format_number(item['pearson_p']):>10} "
            f"{item['pearson_sig']:>4} "
            f"{format_number(item['spearman_r']):>11} "
            f"{format_number(item['spearman_p']):>10} "
            f"{item['spearman_sig']:>4} "
            f"{format_number(item['kendall_tau']):>9} "
            f"{format_number(item['kendall_p']):>10} "
            f"{item['note']}"
        )

    # Highlight the headline number: overall ASI vs. manual correlation.
    total = next((item for item in report if item["metric"] == "asi"), None)
    if total:
        print("-" * 96)
        print(
            "总分 ASI 与人工分的 Pearson 相关: "
            f"r={format_number(total['pearson_r'])}, "
            f"p={format_number(total['pearson_p'])}, "
            f"显著性={total['pearson_sig'] or 'N/A'}"
        )
|
||||
|
||||
|
||||
def format_number(value: Any) -> str:
    """Render a value for the console table.

    "N/A" for non-numeric input, "0" for values indistinguishable from
    zero, otherwise four significant digits.

    Note: to_float(None) is already None, so the separate None pre-check
    of the original was redundant and is folded into the single coercion.
    """
    number = to_float(value)
    if number is None:
        return "N/A"
    if abs(number) < 0.000001:
        return "0"
    return f"{number:.4g}"
|
||||
|
||||
|
||||
def write_csv(file_path: Path, report: list[dict[str, Any]]) -> None:
    """Save the per-metric report to *file_path* as CSV, creating parent dirs."""
    file_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "group",
        "metric",
        "label",
        "n",
        "pearson_r",
        "pearson_p",
        "pearson_sig",
        "spearman_r",
        "spearman_p",
        "spearman_sig",
        "kendall_tau",
        "kendall_p",
        "kendall_sig",
        "note",
    ]
    # utf-8-sig writes a BOM so spreadsheet apps detect the encoding.
    with file_path.open("w", encoding="utf-8-sig", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for row in report:
            writer.writerow(row)
|
||||
|
||||
|
||||
def write_json(file_path: Path, records: list[dict[str, Any]], report: list[dict[str, Any]]) -> None:
    """Save counts plus the full report to *file_path* as pretty-printed JSON."""
    file_path.parent.mkdir(parents=True, exist_ok=True)
    chat_ids = {record["chat_id"] for record in records}
    payload = {
        "matched_record_count": len(records),
        "chat_count": len(chat_ids),
        "report": report,
    }
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    file_path.write_text(serialized, encoding="utf-8")
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: load matched records, print the report, optionally export CSV/JSON."""
    parser = argparse.ArgumentParser(description="分析 Maisaka 回复效果自动评分与人工评分的相关性和显著性。")
    parser.add_argument("--log-dir", type=Path, default=DEFAULT_LOG_DIR, help="自动评分 JSON 目录")
    parser.add_argument("--manual-dir", type=Path, default=DEFAULT_MANUAL_DIR, help="人工评分 JSON 目录")
    parser.add_argument("--chat-id", default="", help="只分析某个 platform_type_id,例如 qq_group_1028699246")
    parser.add_argument("--include-pending", action="store_true", help="包含尚未 finalized 的记录")
    parser.add_argument("--min-n", type=int, default=3, help="计算相关性需要的最小样本数,默认 3")
    parser.add_argument("--csv", type=Path, default=None, help="把统计结果另存为 CSV")
    parser.add_argument("--json", type=Path, default=None, help="把统计结果另存为 JSON")
    args = parser.parse_args()

    records = iter_records(
        args.log_dir,
        args.manual_dir,
        chat_id=args.chat_id,
        include_pending=args.include_pending,
    )
    # Enforce a floor of 2: a correlation needs at least two points.
    report = build_report(records, max(2, args.min_n))
    print_report(records, report)

    if args.csv:
        write_csv(args.csv, report)
        print(f"\nCSV 已保存: {args.csv}")
    if args.json:
        write_json(args.json, records, report)
        print(f"JSON 已保存: {args.json}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
323
scripts/analyze_tool_usage_by_chat.py
Normal file
323
scripts/analyze_tool_usage_by_chat.py
Normal file
@@ -0,0 +1,323 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import DefaultDict
|
||||
|
||||
import csv
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
|
||||
# Project root = parent of the scripts/ directory this file lives in.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Default SQLite database location inside the project tree.
DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "MaiBot.db"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ToolUsageRow:
    """One aggregated (chat_id, tool_name) usage row."""

    chat_id: str  # session identifier ("" when empty ids are included)
    tool_name: str  # tool that was invoked ("" when empty names are included)
    count: int  # invocations of this tool within the chat
    chat_total: int  # total tool invocations within the chat
    percent_in_chat: float  # count / chat_total * 100
    percent_in_all: float  # count / global total * 100
|
||||
|
||||
|
||||
def parse_datetime_filter(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
normalized_value = value.strip()
|
||||
if not normalized_value:
|
||||
return None
|
||||
|
||||
for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"):
|
||||
try:
|
||||
parsed = datetime.strptime(normalized_value, fmt)
|
||||
except ValueError:
|
||||
continue
|
||||
return parsed.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
raise ValueError(f"无法解析时间: {value!r},请使用 YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS")
|
||||
|
||||
|
||||
def parse_recent_filter(value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
normalized_value = value.strip().lower()
|
||||
if not normalized_value:
|
||||
return None
|
||||
|
||||
match = re.fullmatch(r"(\d+(?:\.\d+)?)([mhdw])", normalized_value)
|
||||
if match is None:
|
||||
raise ValueError(f"无法解析最近时间: {value!r},请使用 30m、24h、7d 或 2w")
|
||||
|
||||
amount = float(match.group(1))
|
||||
unit = match.group(2)
|
||||
if amount <= 0:
|
||||
raise ValueError(f"最近时间必须大于 0: {value!r}")
|
||||
|
||||
if unit == "m":
|
||||
delta = timedelta(minutes=amount)
|
||||
elif unit == "h":
|
||||
delta = timedelta(hours=amount)
|
||||
elif unit == "d":
|
||||
delta = timedelta(days=amount)
|
||||
else:
|
||||
delta = timedelta(weeks=amount)
|
||||
|
||||
return (datetime.now() - delta).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
def connect_readonly(db_path: Path) -> sqlite3.Connection:
    """Open *db_path* read-only via an SQLite URI, with Row rows and a 5s busy timeout.

    Raises:
        FileNotFoundError: when the database file does not exist.
    """
    if not db_path.exists():
        raise FileNotFoundError(f"数据库文件不存在: {db_path}")

    # mode=ro guarantees the analysis can never mutate the bot's database.
    uri = f"file:{db_path.as_posix()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA busy_timeout=5000")
    return conn
|
||||
|
||||
|
||||
def fetch_tool_counts(
    db_path: Path,
    since: str | None,
    until: str | None,
    include_empty_chat_id: bool,
    include_empty_tool_name: bool,
) -> list[tuple[str, str, int]]:
    """Aggregate tool_records by (session_id, tool_name) with a read-only query.

    Args:
        db_path: SQLite database file.
        since: Inclusive lower bound on the textual ``timestamp`` column, or None.
        until: Exclusive upper bound on the same column, or None.
        include_empty_chat_id: Keep rows whose session_id is NULL/empty.
        include_empty_tool_name: Keep rows whose tool_name is NULL/empty.

    Returns:
        (chat_id, tool_name, usage_count) tuples; chat_id is the session_id
        column aliased in the query, with NULL collapsed to "".

    NOTE(review): the bounds compare the timestamp column as text — this
    assumes it stores "YYYY-MM-DD HH:MM:SS" strings; confirm against the
    code that writes tool_records.
    """
    where_clauses: list[str] = []
    params: list[str] = []

    if since is not None:
        where_clauses.append("timestamp >= ?")
        params.append(since)
    if until is not None:
        where_clauses.append("timestamp < ?")
        params.append(until)
    if not include_empty_chat_id:
        where_clauses.append("COALESCE(session_id, '') != ''")
    if not include_empty_tool_name:
        where_clauses.append("COALESCE(tool_name, '') != ''")

    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    # Grouping on the COALESCEd expressions keeps NULL and '' in one bucket.
    query = f"""
        SELECT
            COALESCE(session_id, '') AS chat_id,
            COALESCE(tool_name, '') AS tool_name,
            COUNT(*) AS usage_count
        FROM tool_records
        {where_sql}
        GROUP BY COALESCE(session_id, ''), COALESCE(tool_name, '')
        ORDER BY COALESCE(session_id, ''), usage_count DESC, COALESCE(tool_name, '')
    """

    with connect_readonly(db_path) as connection:
        rows = connection.execute(query, params).fetchall()

    return [(str(row["chat_id"]), str(row["tool_name"]), int(row["usage_count"])) for row in rows]
|
||||
|
||||
|
||||
def build_usage_rows(
    counts: list[tuple[str, str, int]],
    min_chat_total: int,
    top_tools_per_chat: int | None,
) -> "list[ToolUsageRow]":
    """Turn raw (chat_id, tool_name, count) aggregates into display rows.

    Chats whose total usage is below *min_chat_total* are dropped; within a
    chat at most *top_tools_per_chat* rows are emitted (None = unlimited),
    taken in descending-count order.
    """
    totals_by_chat: DefaultDict[str, int] = defaultdict(int)
    for chat_id, _tool, usage in counts:
        totals_by_chat[chat_id] += usage

    grand_total = sum(totals_by_chat.values())
    shown_per_chat: DefaultDict[str, int] = defaultdict(int)
    result: "list[ToolUsageRow]" = []

    # Stable ordering: by chat, then highest count first, then tool name.
    ordered = sorted(counts, key=lambda entry: (entry[0], -entry[2], entry[1]))
    for chat_id, tool_name, usage in ordered:
        total_for_chat = totals_by_chat[chat_id]
        if total_for_chat < min_chat_total:
            continue
        if top_tools_per_chat is not None and shown_per_chat[chat_id] >= top_tools_per_chat:
            continue

        shown_per_chat[chat_id] += 1
        pct_chat = usage / total_for_chat * 100 if total_for_chat else 0.0
        pct_all = usage / grand_total * 100 if grand_total else 0.0
        result.append(
            ToolUsageRow(
                chat_id=chat_id,
                tool_name=tool_name,
                count=usage,
                chat_total=total_for_chat,
                percent_in_chat=pct_chat,
                percent_in_all=pct_all,
            )
        )

    return result
|
||||
|
||||
|
||||
def build_overall_rows(counts: list[tuple[str, str, int]], min_chat_total: int) -> list[tuple[str, int, float]]:
    """Aggregate tool usage across all chats whose total meets *min_chat_total*.

    Returns (tool_name, count, global percent) sorted by count descending,
    then tool name.
    """
    per_chat_totals: DefaultDict[str, int] = defaultdict(int)
    for chat_id, _tool, usage in counts:
        per_chat_totals[chat_id] += usage

    # Sum per tool, skipping chats below the activity threshold.
    aggregated: DefaultDict[str, int] = defaultdict(int)
    for chat_id, tool_name, usage in counts:
        if per_chat_totals[chat_id] >= min_chat_total:
            aggregated[tool_name] += usage

    grand_total = sum(aggregated.values())
    ordered = sorted(aggregated.items(), key=lambda entry: (-entry[1], entry[0]))
    rows: list[tuple[str, int, float]] = []
    for tool_name, usage in ordered:
        share = usage / grand_total * 100 if grand_total else 0.0
        rows.append((tool_name, usage, share))
    return rows
|
||||
|
||||
|
||||
def print_overall_block(overall_rows: list[tuple[str, int, float]]) -> None:
    """Print the cross-chat aggregate table: tool name, count, global share."""
    print("全部统计")
    total = sum(count for _tool_name, count, _percent in overall_rows)
    print(f"tool_total: {total}")
    if not overall_rows:
        print(" 无工具调用记录")
        return

    # Size each column to the widest cell, or the header if that is wider.
    tool_width = max(len("tool"), *(len(tool_name) for tool_name, _count, _percent in overall_rows))
    count_width = max(len("count"), *(len(str(count)) for _tool_name, count, _percent in overall_rows))
    percent_width = max(len("全局占比"), *(len(f"{percent:.2f}%") for _tool_name, _count, percent in overall_rows))

    print(f" {'tool':<{tool_width}} {'count':>{count_width}} {'全局占比':>{percent_width}}")
    print(f" {'-' * tool_width} {'-' * count_width} {'-' * percent_width}")
    for tool_name, count, percent in overall_rows:
        # width - 1 leaves room for the trailing '%' so columns stay aligned.
        print(f" {tool_name:<{tool_width}} {count:>{count_width}} {percent:>{percent_width - 1}.2f}%")
|
||||
|
||||
|
||||
def print_markdown(rows: list[ToolUsageRow], overall_rows: list[tuple[str, int, float]]) -> None:
    """Print the terminal report: global aggregate first, then one block per chat_id."""
    print_overall_block(overall_rows)
    if rows:
        print()

    # Regroup the flat row list by chat for the per-chat tables.
    grouped_rows: DefaultDict[str, list[ToolUsageRow]] = defaultdict(list)
    for row in rows:
        grouped_rows[row.chat_id].append(row)

    first_group = True
    for chat_id in sorted(grouped_rows):
        chat_rows = grouped_rows[chat_id]
        if not chat_rows:
            continue

        # Blank line between consecutive chat blocks.
        if not first_group:
            print()
        first_group = False

        # chat_total is identical on every row of a chat; read it from the first.
        chat_total = chat_rows[0].chat_total
        print(f"chat_id: {chat_id}")
        print(f"tool_total: {chat_total}")

        # Size each column to the widest cell, or the header if wider.
        tool_width = max(len("tool"), *(len(row.tool_name) for row in chat_rows))
        count_width = max(len("count"), *(len(str(row.count)) for row in chat_rows))
        chat_percent_width = max(len("chat内占比"), *(len(f"{row.percent_in_chat:.2f}%") for row in chat_rows))
        all_percent_width = max(len("全局占比"), *(len(f"{row.percent_in_all:.2f}%") for row in chat_rows))

        header = (
            f" {'tool':<{tool_width}} "
            f"{'count':>{count_width}} "
            f"{'chat内占比':>{chat_percent_width}} "
            f"{'全局占比':>{all_percent_width}}"
        )
        print(header)
        print(
            f" {'-' * tool_width} "
            f"{'-' * count_width} "
            f"{'-' * chat_percent_width} "
            f"{'-' * all_percent_width}"
        )
        for row in chat_rows:
            # width - 1 leaves room for the trailing '%'.
            print(
                f" {row.tool_name:<{tool_width}} "
                f"{row.count:>{count_width}} "
                f"{row.percent_in_chat:>{chat_percent_width - 1}.2f}% "
                f"{row.percent_in_all:>{all_percent_width - 1}.2f}%"
            )
|
||||
|
||||
|
||||
def print_json(rows: "list[ToolUsageRow]") -> None:
    """Emit the usage rows to stdout as a pretty-printed JSON array."""
    serializable = []
    for row in rows:
        serializable.append(
            {
                "chat_id": row.chat_id,
                "tool_name": row.tool_name,
                "count": row.count,
                "chat_total": row.chat_total,
                "percent_in_chat": round(row.percent_in_chat, 4),
                "percent_in_all": round(row.percent_in_all, 4),
            }
        )
    print(json.dumps(serializable, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def print_csv(rows: "list[ToolUsageRow]") -> None:
    """Emit the usage rows to stdout as CSV with a fixed header."""
    writer = csv.writer(sys.stdout)
    writer.writerow(["chat_id", "tool_name", "count", "chat_total", "percent_in_chat", "percent_in_all"])
    for row in rows:
        record = [
            row.chat_id,
            row.tool_name,
            row.count,
            row.chat_total,
            f"{row.percent_in_chat:.4f}",
            f"{row.percent_in_all:.4f}",
        ]
        writer.writerow(record)
|
||||
|
||||
|
||||
def parse_args() -> Namespace:
    """Define and parse the command-line interface for the usage report."""
    parser = ArgumentParser(description="统计不同 chat_id 的工具使用次数和占比。")
    parser.add_argument("--db", type=Path, default=DEFAULT_DB_PATH, help=f"数据库路径,默认: {DEFAULT_DB_PATH}")
    parser.add_argument("--recent", help="统计最近多久的记录,例如: 30m、24h、7d、2w;如果同时指定 --since,则优先使用 --since")
    parser.add_argument("--since", help="仅统计此时间之后的记录,格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS")
    parser.add_argument("--until", help="仅统计此时间之前的记录,格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS")
    parser.add_argument("--min-chat-total", type=int, default=1, help="只显示工具调用总数不低于该值的 chat_id")
    parser.add_argument("--top-tools", type=int, help="每个 chat_id 最多显示前 N 个工具")
    parser.add_argument("--format", choices=("markdown", "json", "csv"), default="markdown", help="输出格式,markdown 为按 chat_id 分块的终端表")
    parser.add_argument("--include-empty-chat-id", action="store_true", help="包含 chat_id 为空的记录")
    parser.add_argument("--include-empty-tool-name", action="store_true", help="包含 tool_name 为空的记录")
    return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: resolve filters, aggregate tool usage, print in the chosen format."""
    args = parse_args()
    # --since wins over --recent when both are given (as documented in --recent's help).
    since = parse_datetime_filter(args.since) or parse_recent_filter(args.recent)
    until = parse_datetime_filter(args.until)
    min_chat_total = max(1, int(args.min_chat_total))
    # None means "no per-chat limit"; otherwise clamp to at least 1.
    top_tools = args.top_tools if args.top_tools is None else max(1, int(args.top_tools))

    counts = fetch_tool_counts(
        db_path=args.db.resolve(),
        since=since,
        until=until,
        include_empty_chat_id=args.include_empty_chat_id,
        include_empty_tool_name=args.include_empty_tool_name,
    )
    rows = build_usage_rows(
        counts=counts,
        min_chat_total=min_chat_total,
        top_tools_per_chat=top_tools,
    )

    if args.format == "json":
        print_json(rows)
    elif args.format == "csv":
        print_csv(rows)
    else:
        # markdown (default): global aggregate plus per-chat blocks.
        overall_rows = build_overall_rows(counts, min_chat_total=min_chat_total)
        print_markdown(rows, overall_rows)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
380
scripts/build_io_pairs.py
Normal file
380
scripts/build_io_pairs.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import argparse
import json
import os
import random
import re
import sys
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple

# Ensure the script runs from any working directory: put the project root
# (parent of scripts/) on sys.path BEFORE importing project modules.
# Fix: the original inserted the path AFTER the `src.*` imports, so the
# insertion never helped those imports resolve.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.chat.utils.chat_message_builder import build_readable_messages  # noqa: E402
from src.common.data_models.database_data_model import DatabaseMessages  # noqa: E402
from src.common.message_repository import find_messages  # noqa: E402


# Window for merging adjacent messages from the same user.
SECONDS_5_MINUTES = 5 * 60
|
||||
|
||||
|
||||
def clean_output_text(text: str) -> str:
    """Strip sticker markers and quoted-reply prefixes from an output message.

    Removes ``[表情包:...]`` segments and ``[回复...],说:...@...:`` reply
    prefixes, then collapses runs of whitespace into single spaces.
    """
    if not text:
        return text

    without_stickers = re.sub(r"\[表情包:[^\]]*\]", "", text)
    without_replies = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", without_stickers)
    return re.sub(r"\s+", " ", without_replies).strip()
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
    """Parse common datetime formats and return a POSIX timestamp in seconds.

    Supported examples:
    - 2025-09-29
    - 2025-09-29 00:00:00
    - 2025/09/29 00:00
    - 2025-09-29T00:00:00

    Raises:
        ValueError: when no known format matches (message includes the last
            underlying parse error).

    Fix: the original caught broad ``Exception`` (flagged noqa BLE001);
    ``datetime.strptime`` signals a format mismatch with ValueError, so
    catch exactly that.
    """
    value = value.strip()
    fmts = [
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    ]
    last_err: Optional[Exception] = None
    for fmt in fmts:
        try:
            dt = datetime.strptime(value, fmt)
        except ValueError as e:
            last_err = e
            continue
        return dt.timestamp()
    raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def fetch_messages_between(
    start_ts: float,
    end_ts: float,
    platform: Optional[str] = None,
) -> List[DatabaseMessages]:
    """Fetch messages with start_ts < time < end_ts via find_messages.

    Optionally filters on chat_info_platform; results come back sorted
    ascending by time.
    """
    # Mongo-style filter understood by find_messages; both bounds are exclusive.
    filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
    if platform:
        filter_query["chat_info_platform"] = platform
    # limit=0 means "no limit", which is when the sort argument takes effect.
    return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
|
||||
|
||||
|
||||
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
    """Bucket messages by chat_id, each bucket sorted ascending by time."""
    grouped: Dict[str, List[DatabaseMessages]] = {}
    for message in messages:
        grouped.setdefault(message.chat_id, []).append(message)
    # Sort within each chat; a missing timestamp is treated as 0 and sorts first.
    for chat_messages in grouped.values():
        chat_messages.sort(key=lambda item: item.time or 0)
    return grouped
|
||||
|
||||
|
||||
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
    """Collapse a bucket of adjacent same-user messages (within 5 minutes) into one.

    The merged message joins every non-empty processed_plain_text with
    newlines; every other field is copied from the newest (last) message
    in the bucket.

    Raises:
        ValueError: if *bucket* is empty.
    """
    if not bucket:
        raise ValueError("bucket 为空,无法合并")

    latest = bucket[-1]
    # Collect the non-empty texts in chronological order.
    merged_texts: List[str] = []
    for m in bucket:
        text = m.processed_plain_text or ""
        if text:
            merged_texts.append(text)

    merged = DatabaseMessages(
        # All non-text fields come from the newest message in the bucket.
        message_id=latest.message_id,
        time=latest.time,
        chat_id=latest.chat_id,
        reply_to=latest.reply_to,
        is_mentioned=latest.is_mentioned,
        is_at=latest.is_at,
        reply_probability_boost=latest.reply_probability_boost,
        processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
        priority_mode=latest.priority_mode,
        priority_info=latest.priority_info,
        additional_config=latest.additional_config,
        is_emoji=latest.is_emoji,
        is_picid=latest.is_picid,
        is_command=latest.is_command,
        is_notify=latest.is_notify,
        selected_expressions=latest.selected_expressions,
        user_id=latest.user_info.user_id,
        user_nickname=latest.user_info.user_nickname,
        user_cardname=latest.user_info.user_cardname,
        user_platform=latest.user_info.platform,
        chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
        chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
        chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
        chat_info_user_id=latest.chat_info.user_info.user_id,
        chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
        chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
        chat_info_user_platform=latest.chat_info.user_info.platform,
        chat_info_stream_id=latest.chat_info.stream_id,
        chat_info_platform=latest.chat_info.platform,
        chat_info_create_time=latest.chat_info.create_time,
        chat_info_last_active_time=latest.chat_info.last_active_time,
    )
    return merged
|
||||
|
||||
|
||||
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
    """Merge runs of adjacent messages from the same user within a 5-minute gap.

    Input must be sorted ascending by time. The gap is measured between
    consecutive messages (not against the start of the run), and each run is
    collapsed via _merge_bucket_to_message.
    """
    if not messages:
        return []

    merged: List[DatabaseMessages] = []
    bucket: List[DatabaseMessages] = []

    for msg in messages:
        if bucket:
            previous = bucket[-1]
            same_author = msg.user_info.user_id == previous.user_info.user_id
            within_window = (msg.time or 0) - (previous.time or 0) <= SECONDS_5_MINUTES
            if same_author and within_window:
                bucket.append(msg)
                continue
            # Run broken: flush the accumulated bucket, start a new one.
            merged.append(_merge_bucket_to_message(bucket))
        bucket = [msg]

    if bucket:
        merged.append(_merge_bucket_to_message(bucket))
    return merged
|
||||
|
||||
|
||||
def build_pairs_for_chat(
    original_messages: List[DatabaseMessages],
    merged_messages: List[DatabaseMessages],
    min_ctx: int,
    max_ctx: int,
    target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
    """Build (input_str, output_text, message_id) pairs for one chat.

    Each merged message becomes an *output*; its *input* is a window of the
    preceding min_ctx..max_ctx ORIGINAL (unmerged) messages rendered with
    build_readable_messages. When *target_user_id* is set, only that user's
    merged messages produce outputs.
    """
    pairs: List[Tuple[str, str, str]] = []
    n_merged = len(merged_messages)
    n_original = len(original_messages)

    if n_merged == 0 or n_original == 0:
        return pairs

    # Map each merged message to an original-message index by matching
    # timestamps in a single forward scan (both lists are time-sorted).
    # NOTE(review): a merged message carries the timestamp of the LAST
    # message in its bucket, so the anchor (and hence the context window)
    # sits after the earlier messages of that same bucket — confirm this
    # is intended.
    merged_to_original_map = {}
    original_idx = 0

    for merged_idx, merged_msg in enumerate(merged_messages):
        # Advance to the first original message at or after this merged message.
        while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
            original_idx += 1

        # Only anchor on an exact timestamp match; unmatched merged messages
        # are skipped in the pairing loop below.
        if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
            merged_to_original_map[merged_idx] = original_idx

    for merged_idx in range(n_merged):
        merged_msg = merged_messages[merged_idx]

        # Restrict outputs to the target user when one is specified.
        if target_user_id and merged_msg.user_info.user_id != target_user_id:
            continue

        # Skip merged messages that could not be anchored to an original one.
        if merged_idx not in merged_to_original_map:
            continue

        original_idx = merged_to_original_map[merged_idx]

        # Randomly sized context window of preceding original messages.
        window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
        start = max(0, original_idx - window)
        context_msgs = original_messages[start:original_idx]

        # Render the unmerged context messages as the model input.
        input_str = build_readable_messages(
            messages=context_msgs,
            timestamp_mode="normal_no_YMD",
            show_actions=False,
            show_pic=True,
        )

        # The output is the merged message's text with stickers/replies stripped.
        output_text = merged_msg.processed_plain_text or ""
        output_text = clean_output_text(output_text)
        output_id = merged_msg.message_id or ""
        pairs.append((input_str, output_text, output_id))

    return pairs
|
||||
|
||||
|
||||
def build_pairs(
    start_ts: float,
    end_ts: float,
    platform: Optional[str],
    user_id: Optional[str],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Build (input_str, output_str, message_id) triples for every chat in the window.

    Messages are fetched WITHOUT filtering by user_id so that the input
    context can include messages from all users; user_id only restricts
    which merged messages become outputs (enforced in build_pairs_for_chat).
    """
    fetched = fetch_messages_between(start_ts, end_ts, platform)
    collected: List[Tuple[str, str, str]] = []
    for chat_msgs in group_by_chat(fetched).values():
        # Merge adjacent messages from the same user for the output side only;
        # the raw (unmerged) messages are used to build the input context.
        merged_msgs = merge_adjacent_same_user(chat_msgs)
        collected.extend(build_pairs_for_chat(chat_msgs, merged_msgs, min_ctx, max_ctx, user_id))
    return collected
|
||||
|
||||
|
||||
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: parse arguments, build the pairs, then save or print them.

    With no command-line arguments at all, falls back to the interactive
    prompt flow (run_interactive). Returns a process exit code (0 on success).
    """
    # No explicit argv means "use the real command line".
    if argv is None:
        argv = sys.argv[1:]

    # Empty command line -> interactive mode.
    if len(argv) == 0:
        return run_interactive()

    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表,支持按用户ID筛选消息")
    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
    parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数,默认20")
    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数,默认30")
    parser.add_argument(
        "--output",
        default=None,
        help="输出保存路径,支持 .jsonl(每行 {input, output}),若不指定则打印到stdout",
    )

    args = parser.parse_args(argv)

    # Validate the time window and context bounds before doing any work.
    start_ts = parse_datetime_to_timestamp(args.start)
    end_ts = parse_datetime_to_timestamp(args.end)
    if end_ts <= start_ts:
        raise ValueError("结束时间必须大于起始时间")

    if args.max_ctx < args.min_ctx:
        raise ValueError("max_ctx 不能小于 min_ctx")

    pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)

    if args.output:
        # Save as JSONL: one {input, output, message_id} object per line.
        with open(args.output, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {args.output}")
    else:
        # No output path: stream the same JSON objects to stdout.
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))

    return 0
|
||||
|
||||
|
||||
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
    """Prompt on stdin, showing *default* in brackets; empty input returns the default."""
    hint = "" if default in (None, "") else f"[{default}]"
    full_prompt = f"{prompt_text} {hint}: " if hint else f"{prompt_text}: "
    answer = input(full_prompt).strip()
    if not answer and default is not None:
        return default
    return answer
|
||||
|
||||
|
||||
def run_interactive() -> int:
    """Interactive fallback for main(): collect parameters via stdin prompts.

    Mirrors the CLI path (validation, pair building, JSONL save or stdout
    print). Returns 0 on success, 2 on invalid user input.
    """
    print("进入交互模式(直接回车采用默认值)。时间格式例如:2025-09-28 00:00:00 或 2025-09-28")
    start_str = _prompt_with_default("请输入起始时间", None)
    end_str = _prompt_with_default("请输入结束时间", None)
    platform = _prompt_with_default("平台(可留空表示不限)", "")
    user_id = _prompt_with_default("用户ID(可留空表示不限)", "")
    try:
        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
    except Exception:
        # Non-numeric input: fall back to the documented defaults.
        print("上下文条数输入有误,使用默认 20/30")
        min_ctx, max_ctx = 20, 30
    output_path = _prompt_with_default("输出路径(.jsonl,可留空打印到控制台)", "")

    # Start/end times are mandatory; everything else has a default.
    if not start_str or not end_str:
        print("必须提供起始与结束时间。")
        return 2

    try:
        start_ts = parse_datetime_to_timestamp(start_str)
        end_ts = parse_datetime_to_timestamp(end_str)
    except Exception as e:  # noqa: BLE001
        print(f"时间解析失败:{e}")
        return 2

    if end_ts <= start_ts:
        print("结束时间必须大于起始时间。")
        return 2

    if max_ctx < min_ctx:
        print("最多条数不能小于最少条数。")
        return 2

    # Empty strings from the prompts mean "no filter".
    platform_val = platform if platform != "" else None
    user_id_val = user_id if user_id != "" else None
    pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)

    if output_path:
        # Same JSONL format as the CLI path.
        with open(output_path, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {output_path}")
    else:
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
        print(f"总计 {len(pairs)} 条。")

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    sys.exit(main())
|
||||
553
scripts/evaluate_expressions_count_analysis.py
Normal file
553
scripts/evaluate_expressions_count_analysis.py
Normal file
@@ -0,0 +1,553 @@
|
||||
"""
|
||||
表达方式按count分组的LLM评估和统计分析脚本
|
||||
|
||||
功能:
|
||||
1. 随机选择50条表达,至少要有20条count>1的项目,然后进行LLM评估
|
||||
2. 比较不同count之间的LLM评估合格率是否有显著差异
|
||||
- 首先每个count分开比较
|
||||
- 然后比较count为1和count大于1的两种
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from typing import List, Dict, Set, Tuple
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression # noqa: E402
|
||||
from src.common.database.database import db # noqa: E402
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
from src.llm_models.utils_model import LLMRequest # noqa: E402
|
||||
from src.config.config import model_config # noqa: E402
|
||||
|
||||
logger = get_logger("expression_evaluator_count_analysis_llm")

# Persisted-results location: everything lives under scripts/temp so reruns
# can resume from previously saved evaluations.
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.json")
|
||||
|
||||
|
||||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """Load previously saved evaluation results from COUNT_ANALYSIS_FILE.

    Returns:
        A tuple of (existing result dicts, set of already-evaluated
        (situation, style) pairs); both empty when the file is absent or
        unreadable, so the caller can simply start fresh.
    """
    if not os.path.exists(COUNT_ANALYSIS_FILE):
        return [], set()

    try:
        with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as fp:
            stored = json.load(fp)
        loaded = stored.get("evaluation_results", [])
        # A (situation, style) pair uniquely identifies an evaluated item.
        done = {
            (item["situation"], item["style"])
            for item in loaded
            if "situation" in item and "style" in item
        }
        logger.info(f"已加载 {len(loaded)} 条已有评估结果")
        return loaded, done
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()
|
||||
|
||||
|
||||
def save_results(evaluation_results: List[Dict]):
    """Persist *evaluation_results* to COUNT_ANALYSIS_FILE as pretty-printed JSON.

    Failures are logged and reported on stdout instead of raising, so a
    failed save never aborts the surrounding evaluation loop.

    Args:
        evaluation_results: Evaluation result dicts to write out.
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)

        payload = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(evaluation_results),
            "evaluation_results": evaluation_results,
        }

        with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as fp:
            json.dump(payload, fp, ensure_ascii=False, indent=2)
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")
    else:
        logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
|
||||
|
||||
|
||||
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] | None = None) -> List[Expression]:
    """Pick the expressions to evaluate next.

    Takes ALL unevaluated items with count > 1, then randomly samples twice
    that many count == 1 items, and shuffles the combined selection.

    Args:
        evaluated_pairs: (situation, style) pairs already evaluated, used to
            avoid re-evaluating items. None means nothing evaluated yet.

    Returns:
        The selected Expression rows (empty on error or when nothing is left).
    """
    if evaluated_pairs is None:
        evaluated_pairs = set()

    try:
        # Load every expression row; filtering happens in memory below.
        all_expressions = list(Expression.select())

        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []

        # Drop items already evaluated (keyed on (situation, style)).
        unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]

        if not unevaluated:
            logger.warning("所有项目都已评估完成")
            return []

        # Split by usage count: singletons vs repeated expressions.
        count_eq1 = [expr for expr in unevaluated if expr.count == 1]
        count_gt1 = [expr for expr in unevaluated if expr.count > 1]

        logger.info(f"未评估项目中:count=1的有{len(count_eq1)}条,count>1的有{len(count_gt1)}条")

        # Keep every count>1 item (they are the scarce group).
        selected_count_gt1 = count_gt1.copy()

        # Target twice as many count=1 items as count>1 items.
        count_gt1_count = len(selected_count_gt1)
        count_eq1_needed = count_gt1_count * 2

        if len(count_eq1) < count_eq1_needed:
            logger.warning(
                f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条"
            )
            count_eq1_needed = len(count_eq1)

        # Randomly sample the count=1 items.
        selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []

        selected = selected_count_gt1 + selected_count_eq1
        random.shuffle(selected)  # randomize evaluation order

        logger.info(
            f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)"
        )

        return selected

    except Exception as e:
        logger.error(f"选择表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Render the LLM judging prompt for a (situation, style) pair.

    The prompt requests a strict-JSON verdict of the form
    {"suitable": bool, "reason": str}.
    """
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}

请从以下方面进行评估:
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指,需要具有泛用性
4. 一般不涉及具体的人名或名称

请以JSON格式输出评估结果:
{{
    "suitable": true/false,
    "reason": "评估理由(如果不合适,请说明原因)"

}}
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of a (situation, style) pair.

    Args:
        situation: Usage condition / scenario text.
        style: Expression / speaking-style text.
        llm: LLM request instance to call.

    Returns:
        (suitable, reason, error): on any failure suitable is False and
        error carries the exception text; error is None on success.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        # reasoning/model_name are returned by the client but unused here.
        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the JSON verdict; salvage an embedded JSON object if the
        # model wrapped it in extra text.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e

        # Missing fields default to a failing verdict with a stock reason.
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        # Any failure (network, parse, etc.) is reported as "not suitable"
        # with the error text preserved for the caller.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
||||
|
||||
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
    """Evaluate one Expression row with the LLM and package the verdict.

    Returns a dict record ready to append to the persisted results, carrying
    the verdict, reason, any error text, and an evaluation timestamp. Any
    error from the underlying call forces the verdict to False.
    """
    situation = expression.situation
    style = expression.style

    logger.info(
        f"开始评估表达方式: situation={situation}, style={style}, count={expression.count}"
    )

    verdict, reason, error = await _single_llm_evaluation(situation, style, llm)

    # A transport/parse error always means a failing verdict.
    if error:
        verdict = False

    logger.info(f"评估完成: {'通过' if verdict else '不通过'}")

    record: Dict = {
        "situation": situation,
        "style": style,
        "count": expression.count,
        "suitable": verdict,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
        "evaluated_at": datetime.now().isoformat(),
    }
    return record
|
||||
|
||||
|
||||
def perform_statistical_analysis(evaluation_results: List[Dict]):
    """Summarize the evaluation results by count and test count=1 vs count>1.

    Prints per-count pass rates, then a hand-rolled 2x2 Pearson chi-square
    test of independence between the count==1 and count>1 groups (df=1,
    hard-coded critical values), and finally saves the aggregate numbers to
    a JSON file under TEMP_DIR.

    Args:
        evaluation_results: Result dicts; each is read with "count"
            (default 1) and "suitable" (default False).
    """
    if not evaluation_results:
        print("\n没有评估结果可供分析")
        return

    print("\n" + "=" * 60)
    print("统计分析结果")
    print("=" * 60)

    # Tally pass/fail per distinct count value.
    count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})

    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)
        count_groups[count]["total"] += 1
        if suitable:
            count_groups[count]["suitable"] += 1
        else:
            count_groups[count]["unsuitable"] += 1

    # Per-count breakdown.
    print("\n【按count分组统计】")
    print("-" * 60)
    for count in sorted(count_groups.keys()):
        group = count_groups[count]
        total = group["total"]
        suitable = group["suitable"]
        unsuitable = group["unsuitable"]
        pass_rate = (suitable / total * 100) if total > 0 else 0

        print(f"Count = {count}:")
        print(f" 总数: {total}")
        print(f" 通过: {suitable} ({pass_rate:.2f}%)")
        print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
        print()

    # Collapse into two groups: count == 1 vs count > 1.
    count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
    count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}

    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)

        if count == 1:
            count_eq1_group["total"] += 1
            if suitable:
                count_eq1_group["suitable"] += 1
            else:
                count_eq1_group["unsuitable"] += 1
        else:
            count_gt1_group["total"] += 1
            if suitable:
                count_gt1_group["suitable"] += 1
            else:
                count_gt1_group["unsuitable"] += 1

    print("\n【Count=1 vs Count>1 对比】")
    print("-" * 60)

    eq1_total = count_eq1_group["total"]
    eq1_suitable = count_eq1_group["suitable"]
    eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0

    gt1_total = count_gt1_group["total"]
    gt1_suitable = count_gt1_group["suitable"]
    gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0

    print("Count = 1:")
    print(f" 总数: {eq1_total}")
    print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
    print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
    print()
    print("Count > 1:")
    print(f" 总数: {gt1_total}")
    print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
    print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
    print()

    # Pearson chi-square test on the 2x2 contingency table (needs both groups).
    if eq1_total > 0 and gt1_total > 0:
        print("【统计显著性检验】")
        print("-" * 60)

        # Contingency table layout:
        #            pass   fail
        # count=1      a      b
        # count>1      c      d
        a = eq1_suitable
        b = eq1_total - eq1_suitable
        c = gt1_suitable
        d = gt1_total - gt1_suitable

        n = eq1_total + gt1_total
        if n > 0:
            # Expected frequencies under the independence hypothesis.
            e_a = (eq1_total * (a + c)) / n
            e_b = (eq1_total * (b + d)) / n
            e_c = (gt1_total * (a + c)) / n
            e_d = (gt1_total * (b + d)) / n

            # Chi-square validity rule of thumb: every expected frequency >= 5.
            min_expected = min(e_a, e_b, e_c, e_d)
            if min_expected < 5:
                print("警告:期望频数小于5,卡方检验可能不准确")
                print("建议使用Fisher精确检验")

            # Accumulate the statistic, skipping zero expectations to avoid
            # division by zero.
            chi_square = 0
            if e_a > 0:
                chi_square += ((a - e_a) ** 2) / e_a
            if e_b > 0:
                chi_square += ((b - e_b) ** 2) / e_b
            if e_c > 0:
                chi_square += ((c - e_c) ** 2) / e_c
            if e_d > 0:
                chi_square += ((d - e_d) ** 2) / e_d

            # Degrees of freedom = (rows - 1) * (cols - 1) = 1.
            df = 1

            # Hard-coded chi-square critical values for df=1.
            chi_square_critical_005 = 3.841
            chi_square_critical_001 = 6.635

            print(f"卡方统计量: {chi_square:.4f}")
            print(f"自由度: {df}")
            print(f"临界值 (α=0.05): {chi_square_critical_005}")
            print(f"临界值 (α=0.01): {chi_square_critical_001}")

            if chi_square >= chi_square_critical_001:
                print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)")
            elif chi_square >= chi_square_critical_005:
                print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)")
            else:
                print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)")

            # Report effect size (pass-rate gap) independently of significance.
            diff = abs(eq1_pass_rate - gt1_pass_rate)
            print(f"\n合格率差异: {diff:.2f}%")
            if diff > 10:
                print("差异较大(>10%)")
            elif diff > 5:
                print("差异中等(5-10%)")
            else:
                print("差异较小(<5%)")
        else:
            print("数据不足,无法进行统计检验")
    else:
        print("数据不足,无法进行count=1和count>1的对比分析")

    # Persist the aggregate statistics next to the evaluation results.
    analysis_result = {
        "analysis_time": datetime.now().isoformat(),
        "count_groups": {str(k): v for k, v in count_groups.items()},
        "count_eq1": count_eq1_group,
        "count_gt1": count_gt1_group,
        "total_evaluated": len(evaluation_results),
    }

    try:
        analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
        with open(analysis_file, "w", encoding="utf-8") as f:
            json.dump(analysis_result, f, ensure_ascii=False, indent=2)
        print(f"\n✓ 统计分析结果已保存到: {analysis_file}")
    except Exception as e:
        logger.error(f"保存统计分析结果失败: {e}")
|
||||
|
||||
|
||||
async def main():
    """Entry point: evaluate any remaining expressions, then run the statistics.

    Flow: connect DB -> load saved results -> if any count>1 item is still
    unevaluated, run the LLM over a fresh selection and persist -> run the
    statistical analysis over everything evaluated so far -> close DB.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式按count分组的LLM评估和统计分析")
    logger.info("=" * 60)

    # Open the database connection up front; nothing works without it.
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    # Resume from previously saved results so reruns don't redo work.
    existing_results, evaluated_pairs = load_existing_results()
    evaluation_results = existing_results.copy()

    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")

    # More evaluation is needed iff some count>1 item has not been judged yet.
    try:
        all_expressions = list(Expression.select())
        unevaluated_count_gt1 = [
            expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
        ]
        has_unevaluated = len(unevaluated_count_gt1) > 0
    except Exception as e:
        logger.error(f"查询未评估项目失败: {e}")
        has_unevaluated = False

    if has_unevaluated:
        print("\n" + "=" * 60)
        print("开始LLM评估")
        print("=" * 60)
        print("评估结果会自动保存到文件\n")

        # Build the judging LLM instance.
        print("创建LLM实例...")
        try:
            llm = LLMRequest(
                model_set=model_config.model_task_config.tool_use,
                request_type="expression_evaluator_count_analysis_llm",
            )
            print("✓ LLM实例创建成功\n")
        except Exception as e:
            logger.error(f"创建LLM实例失败: {e}")
            import traceback

            logger.error(traceback.format_exc())
            print(f"\n✗ 创建LLM实例失败: {e}")
            db.close()
            return

        # Selection rule: every count>1 item plus twice as many count=1 items.
        expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)

        if not expressions:
            print("\n没有可评估的项目")
        else:
            print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
            print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
            print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")

            batch_results = []
            for i, expression in enumerate(expressions, 1):
                print(f"LLM评估进度: {i}/{len(expressions)}")
                print(f" Situation: {expression.situation}")
                print(f" Style: {expression.style}")
                print(f" Count: {expression.count}")

                llm_result = await llm_evaluate_expression(expression, llm)

                print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
                if llm_result.get("error"):
                    print(f" 错误: {llm_result['error']}")
                print()

                batch_results.append(llm_result)
                # Record the (situation, style) key so reruns skip this item.
                evaluated_pairs.add((llm_result["situation"], llm_result["style"]))

                # Small delay between calls to stay under API rate limits.
                await asyncio.sleep(0.3)

            # Merge this batch into the accumulated results.
            evaluation_results.extend(batch_results)

            # Persist everything evaluated so far.
            save_results(evaluation_results)
    else:
        print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果")

    # Run the statistics over everything evaluated so far.
    if len(evaluation_results) > 0:
        perform_statistical_analysis(evaluation_results)
    else:
        print("\n没有评估结果可供分析")

    # Best-effort cleanup of the database connection.
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Drive the async entry point with a fresh event loop.
    asyncio.run(main())
|
||||
535
scripts/evaluate_expressions_llm_v6.py
Normal file
535
scripts/evaluate_expressions_llm_v6.py
Normal file
@@ -0,0 +1,535 @@
|
||||
"""
|
||||
表达方式LLM评估脚本
|
||||
|
||||
功能:
|
||||
1. 读取已保存的人工评估结果(作为效标)
|
||||
2. 使用LLM对相同项目进行评估
|
||||
3. 对比人工评估和LLM评估的结果,输出分析报告
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
from typing import List, Dict, Set, Tuple
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest # noqa: E402
|
||||
from src.config.config import model_config # noqa: E402
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
|
||||
logger = get_logger("expression_evaluator_llm")

# Directory holding the manual evaluation result JSON files (the criterion
# data this script compares the LLM against).
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
|
||||
|
||||
|
||||
def load_manual_results() -> List[Dict]:
    """Load and merge manual evaluation results from every JSON file in TEMP_DIR.

    Each file is expected to carry a "manual_results" list; entries are
    de-duplicated on the (situation, style) pair, first occurrence wins.
    Entries missing either field are skipped with a warning.

    Returns:
        The merged, de-duplicated list (empty when nothing usable is found).
    """
    if not os.path.exists(TEMP_DIR):
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        print("\n✗ 错误:未找到temp目录")
        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []

    # Collect every JSON file under temp.
    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))

    if not json_files:
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        print("\n✗ 错误:temp目录下未找到JSON文件")
        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []

    logger.info(f"找到 {len(json_files)} 个JSON文件")
    print(f"\n找到 {len(json_files)} 个JSON文件:")
    for json_file in json_files:
        print(f" - {os.path.basename(json_file)}")

    # Read and merge all files, de-duplicating as we go.
    all_results = []
    seen_pairs: Set[Tuple[str, str]] = set()  # (situation, style) keys seen so far

    for json_file in json_files:
        try:
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            results = data.get("manual_results", [])

            # Deduplicate: (situation, style) uniquely identifies an item.
            for result in results:
                if "situation" not in result or "style" not in result:
                    logger.warning(f"跳过无效数据(缺少必要字段): {result}")
                    continue

                pair = (result["situation"], result["style"])
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    all_results.append(result)

            logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果")
        except Exception as e:
            # One bad file must not abort the whole merge.
            logger.error(f"加载文件 {json_file} 失败: {e}")
            print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
            continue

    logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
    print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")

    return all_results
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the judging prompt shown to the LLM for one (situation, style) pair.

    The model is instructed to answer with strict JSON:
    {"suitable": bool, "reason": str}.
    """
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}

请从以下方面进行评估:
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指,需要具有泛用性
4. 一般不涉及具体的人名或名称

请以JSON格式输出评估结果:
{{
    "suitable": true/false,
    "reason": "评估理由(如果不合适,请说明原因)"

}}
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of a (situation, style) pair.

    Args:
        situation: Usage condition / scenario text.
        style: Expression / speaking-style text.
        llm: LLM request instance to call.

    Returns:
        (suitable, reason, error): on any failure suitable is False and
        error carries the exception text; error is None on success.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        # reasoning/model_name are returned by the client but unused here.
        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the JSON verdict; salvage an embedded JSON object if the
        # model wrapped it in extra text.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            import re

            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e

        # Missing fields default to a failing verdict with a stock reason.
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as e:
        # Any failure (network, parse, etc.) is reported as "not suitable"
        # with the error text preserved for the caller.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
|
||||
|
||||
|
||||
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
    """Evaluate one (situation, style) pair with the LLM and wrap the verdict.

    Any error from the underlying call forces ``suitable`` to False; the
    error text is preserved in the returned record.
    """
    logger.info(f"开始评估表达方式: situation={situation}, style={style}")

    verdict, reason, error = await _single_llm_evaluation(situation, style, llm)
    if error:
        verdict = False

    logger.info(f"评估完成: {'通过' if verdict else '不通过'}")

    record: Dict = {
        "situation": situation,
        "style": style,
        "suitable": verdict,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
    return record
|
||||
|
||||
|
||||
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
    """Compare LLM verdicts against the manual (ground-truth) verdicts.

    The manual evaluation is the criterion: a "positive" is an item judged
    suitable. Manual items with no matching LLM result are skipped for the
    confusion counts but still contribute to *total* (and so lower accuracy).

    Args:
        manual_results: Ground-truth records with "situation"/"style"/"suitable".
        llm_results: LLM records keyed by the same (situation, style) pairs.
        method_name: Label stored under the "method" key of the report.

    Returns:
        A dict of agreement metrics (accuracy, precision, recall, F1,
        specificity — all as percentages) plus unsuitable-rate statistics
        describing how much the LLM filter would reduce bad items.
    """
    # Index LLM verdicts by their (situation, style) identity.
    by_pair = {(entry["situation"], entry["style"]): entry for entry in llm_results}

    total = len(manual_results)
    matched = 0
    tp = tn = fp = fn = 0

    for manual_entry in manual_results:
        llm_entry = by_pair.get((manual_entry["situation"], manual_entry["style"]))
        if llm_entry is None:
            continue

        truth = manual_entry["suitable"]
        guess = llm_entry["suitable"]

        if truth == guess:
            matched += 1

        # Confusion-matrix bucket: manual verdict is the ground truth.
        if truth and guess:
            tp += 1
        elif truth:
            fn += 1
        elif guess:
            fp += 1
        else:
            tn += 1

    def pct(numerator: float, denominator: float) -> float:
        # Safe percentage: empty denominators yield 0 rather than raising.
        return numerator / denominator * 100 if denominator > 0 else 0

    accuracy = pct(matched, total)
    precision = pct(tp, tp + fp)
    recall = pct(tp, tp + fn)
    f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
    specificity = pct(tn, tn + fp)

    # Unsuitable rate in the raw manual criterion (over all matched items).
    manual_unsuitable_rate = pct(tn + fp, total)
    # Unsuitable rate among items the LLM would keep (LLM said suitable):
    # kept = TP + FP, of which FP are the ones the human rejected.
    kept_by_llm = tp + fp
    llm_kept_unsuitable_rate = pct(fp, kept_by_llm)
    # Positive difference means the LLM filter lowers the bad-item rate.
    rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate

    # Compare agreement against a 50% coin-flip baseline.
    random_baseline = 50.0
    return {
        "method": method_name,
        "total": total,
        "matched": matched,
        "accuracy": accuracy,
        "accuracy_above_random": accuracy - random_baseline,
        "accuracy_improvement_ratio": (accuracy / random_baseline) if random_baseline > 0 else 0,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity,
        "manual_unsuitable_rate": manual_unsuitable_rate,
        "llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
        "rate_difference": rate_difference,
    }
|
||||
|
||||
|
||||
async def main(count: int | None = None):
    """Run the LLM expression evaluation end to end.

    Loads the manual (criterion) evaluation results, optionally samples
    ``count`` of them, scores each (situation, style) pair with an LLM,
    prints false-positive/false-negative items and comparison metrics,
    and writes everything to a JSON report.

    Args:
        count: number of entries to randomly sample; None uses all data.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式LLM评估")
    logger.info("=" * 60)

    # 1. Load the manual evaluation results (the ground-truth criterion).
    print("\n步骤1: 加载人工评估结果")
    manual_results = load_manual_results()
    if not manual_results:
        return

    print(f"成功加载 {len(manual_results)} 条人工评估结果")

    # If a count was given, randomly pick that many entries.
    if count is not None:
        if count <= 0:
            print(f"\n✗ 错误:指定的数量必须大于0,当前值: {count}")
            return
        if count > len(manual_results):
            print(f"\n⚠ 警告:指定的数量 ({count}) 大于可用数据量 ({len(manual_results)}),将使用全部数据")
        else:
            random.seed()  # seed from system time
            manual_results = random.sample(manual_results, count)
            print(f"随机选取 {len(manual_results)} 条数据进行评估")

    # Validate data integrity: each item needs both "situation" and "style".
    valid_manual_results = []
    for r in manual_results:
        if "situation" in r and "style" in r:
            valid_manual_results.append(r)
        else:
            logger.warning(f"跳过无效数据: {r}")

    if len(valid_manual_results) != len(manual_results):
        print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")

    print(f"有效数据: {len(valid_manual_results)} 条")

    # 2. Create the LLM instance used for evaluation.
    print("\n步骤2: 创建LLM实例")
    try:
        llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
    except Exception as e:
        logger.error(f"创建LLM实例失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return

    # 3. Evaluate every valid item with the LLM (lightly throttled).
    print("\n步骤3: 开始LLM评估")
    llm_results = []
    for i, manual_result in enumerate(valid_manual_results, 1):
        print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
        llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
        await asyncio.sleep(0.3)

    # 5. Print FP and FN items before the summary.
    # NOTE(review): the step numbering in this script skips 4.
    llm_dict = {(r["situation"], r["style"]): r for r in llm_results}

    # 5.1 FP items: rejected by the human but wrongly accepted by the LLM.
    print("\n" + "=" * 60)
    print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)")
    print("=" * 60)

    fp_items = []
    for manual_result in valid_manual_results:
        pair = (manual_result["situation"], manual_result["style"])
        llm_result = llm_dict.get(pair)
        if llm_result is None:
            continue

        # Human rejected, LLM accepted (FP case).
        if not manual_result["suitable"] and llm_result["suitable"]:
            fp_items.append(
                {
                    "situation": manual_result["situation"],
                    "style": manual_result["style"],
                    "manual_suitable": manual_result["suitable"],
                    "llm_suitable": llm_result["suitable"],
                    "llm_reason": llm_result.get("reason", "未提供理由"),
                    "llm_error": llm_result.get("error"),
                }
            )

    if fp_items:
        print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
        for idx, item in enumerate(fp_items, 1):
            print(f"--- [{idx}] ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 不通过 ❌")
            print("LLM评估: 通过 ✅ (误判)")
            if item.get("llm_error"):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)")

    # 5.2 FN items: accepted by the human but wrongly rejected by the LLM.
    print("\n" + "=" * 60)
    print("人工评估通过但LLM误判为不通过的项目(FN - False Negative)")
    print("=" * 60)

    fn_items = []
    for manual_result in valid_manual_results:
        pair = (manual_result["situation"], manual_result["style"])
        llm_result = llm_dict.get(pair)
        if llm_result is None:
            continue

        # Human accepted, LLM rejected (FN case).
        if manual_result["suitable"] and not llm_result["suitable"]:
            fn_items.append(
                {
                    "situation": manual_result["situation"],
                    "style": manual_result["style"],
                    "manual_suitable": manual_result["suitable"],
                    "llm_suitable": llm_result["suitable"],
                    "llm_reason": llm_result.get("reason", "未提供理由"),
                    "llm_error": llm_result.get("error"),
                }
            )

    if fn_items:
        print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
        for idx, item in enumerate(fn_items, 1):
            print(f"--- [{idx}] ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 通过 ✅")
            print("LLM评估: 不通过 ❌ (误删)")
            if item.get("llm_error"):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误删项目(所有人工评估通过的项目都被LLM正确识别为通过)")

    # 6. Compare against the criterion and print the metric report.
    comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")

    print("\n" + "=" * 60)
    print("评估结果(以人工评估为标准)")
    print("=" * 60)

    # Detailed results, core capability metrics first.
    print(f"\n--- {comparison['method']} ---")
    print(f" 总数: {comparison['total']} 条")
    print()
    # print(" 【核心能力指标】")
    print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
    print(
        f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
    )
    print(
        f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个"
    )
    # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
    print()
    print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
    print(
        f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
    )
    print(
        f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个"
    )
    # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
    print()
    print(" 【其他指标】")
    print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
    print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
    print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
    print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
    print()
    print(" 【不合适率分析】")
    print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
    print(
        f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
    )
    print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
    print()
    print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
    print(
        f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
    )
    print(
        f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
    )
    print()
    # print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
    # print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
    # print(f" - 含义: {'LLM删除后剩余项目中的不合适率降低了' if comparison['rate_difference'] > 0 else 'LLM删除后剩余项目中的不合适率反而升高了' if comparison['rate_difference'] < 0 else '两者相等'} ({'✓ LLM删除有效' if comparison['rate_difference'] > 0 else '✗ LLM删除效果不佳' if comparison['rate_difference'] < 0 else '效果相同'})")
    # print()
    print(" 【分类统计】")
    print(f" TP (正确识别为合适): {comparison['true_positives']}")
    print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐")
    print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
    print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")

    # 7. Save everything (inputs, outputs, metrics) to a JSON report file.
    output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(
                {"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
                f,
                ensure_ascii=False,
                indent=2,
            )
        logger.info(f"\n评估结果已保存到: {output_file}")
    except Exception as e:
        logger.warning(f"保存结果到文件失败: {e}")

    print("\n" + "=" * 60)
    print("评估完成")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: optional -n/--count limits the sample size.
    parser = argparse.ArgumentParser(
        description="表达方式LLM评估脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
""",
    )
    parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")

    args = parser.parse_args()
    asyncio.run(main(count=args.count))
|
||||
275
scripts/evaluate_expressions_manual.py
Normal file
275
scripts/evaluate_expressions_manual.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
表达方式人工评估脚本
|
||||
|
||||
功能:
|
||||
1. 不停随机抽取项目(不重复)进行人工评估
|
||||
2. 将结果保存到 temp 文件夹下的 JSON 文件,作为效标(标准答案)
|
||||
3. 支持继续评估(从已有文件中读取已评估的项目,避免重复)
|
||||
"""
|
||||
|
||||
import random
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from typing import List, Dict, Set, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from src.common.database.database_model import Expression # noqa: E402
|
||||
from src.common.database.database import db # noqa: E402
|
||||
from src.common.logger import get_logger # noqa: E402
|
||||
|
||||
logger = get_logger("expression_evaluator_manual")
|
||||
|
||||
# Paths for the manual evaluation result (criterion) file.
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
|
||||
|
||||
|
||||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """Load previously saved manual evaluation results.

    Returns:
        A tuple of (existing result dicts, set of already-evaluated
        (situation, style) pairs). Both are empty when the file is
        missing or unreadable.
    """
    if not os.path.exists(MANUAL_EVAL_FILE):
        return [], set()

    try:
        with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as fp:
            payload = json.load(fp)
        loaded = payload.get("manual_results", [])
        # A (situation, style) pair is the unique identity of an item.
        seen_pairs = {
            (entry["situation"], entry["style"])
            for entry in loaded
            if "situation" in entry and "style" in entry
        }
        logger.info(f"已加载 {len(loaded)} 条已有评估结果")
        return loaded, seen_pairs
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()
|
||||
|
||||
|
||||
def save_results(manual_results: List[Dict]):
    """Persist the evaluation results to the criterion JSON file.

    Args:
        manual_results: list of evaluation result dicts to save.
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)

        payload = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(manual_results),
            "manual_results": manual_results,
        }

        with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as fp:
            json.dump(payload, fp, ensure_ascii=False, indent=2)

        logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")
|
||||
|
||||
|
||||
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """Fetch a batch of expressions that have not been evaluated yet.

    Args:
        evaluated_pairs: set of already-evaluated (situation, style) pairs.
        batch_size: maximum number of items to return.

    Returns:
        Up to ``batch_size`` randomly chosen unevaluated expressions;
        an empty list when nothing remains or the query fails.
    """
    try:
        candidates = list(Expression.select())

        if not candidates:
            logger.warning("数据库中没有表达方式记录")
            return []

        # An item counts as evaluated only when BOTH situation and style match.
        remaining = [e for e in candidates if (e.situation, e.style) not in evaluated_pairs]

        if not remaining:
            logger.info("所有项目都已评估完成")
            return []

        # Fewer items than requested: hand back everything.
        if len(remaining) <= batch_size:
            logger.info(f"剩余 {len(remaining)} 条未评估项目,全部返回")
            return remaining

        # Otherwise pick a random sample of the requested size.
        picked = random.sample(remaining, batch_size)
        logger.info(f"从 {len(remaining)} 条未评估项目中随机选择了 {len(picked)} 条")
        return picked

    except Exception as e:
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
|
||||
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
    """Interactively evaluate a single expression on the console.

    Args:
        expression: the expression record to judge.
        index: 1-based position within the current batch.
        total: total number of items in the batch.

    Returns:
        A result dict on a yes/no answer, the sentinel string "skip" when
        the user skips the item, or None when the user quits.
    """
    print("\n" + "=" * 60)
    print(f"人工评估 [{index}/{total}]")
    print("=" * 60)
    print(f"Situation: {expression.situation}")
    print(f"Style: {expression.style}")
    print("\n请评估该表达方式是否合适:")
    print(" 输入 'y' 或 'yes' 或 '1' 表示合适(通过)")
    print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)")
    print(" 输入 'q' 或 'quit' 退出评估")
    print(" 输入 's' 或 'skip' 跳过当前项目")

    # Loop until the user gives a recognized answer.
    while True:
        user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()

        if user_input in ["q", "quit"]:
            print("退出评估")
            return None

        if user_input in ["s", "skip"]:
            print("跳过当前项目")
            return "skip"

        if user_input in ["y", "yes", "1", "是", "通过"]:
            suitable = True
            break
        elif user_input in ["n", "no", "0", "否", "不通过"]:
            suitable = False
            break
        else:
            print("输入无效,请重新输入 (y/n/q/s)")

    result = {
        "situation": expression.situation,
        "style": expression.style,
        "suitable": suitable,
        "reason": None,
        "evaluator": "manual",
        "evaluated_at": datetime.now().isoformat(),
    }

    print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")

    return result
|
||||
|
||||
|
||||
def main():
    """Drive the interactive manual-evaluation loop.

    Connects to the database, loads any previously saved results, then
    evaluates unevaluated expressions in batches, saving after every
    batch, until the user quits or all items are done.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式人工评估")
    logger.info("=" * 60)

    # Initialize the database connection.
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return

    # Load existing results so finished items are not asked again.
    existing_results, evaluated_pairs = load_existing_results()
    manual_results = existing_results.copy()

    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")

    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
    print("评估结果会自动保存到文件\n")

    batch_size = 10
    batch_count = 0

    while True:
        # Fetch the next batch of unevaluated items.
        expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)

        if not expressions:
            print("\n" + "=" * 60)
            print("所有项目都已评估完成!")
            print("=" * 60)
            break

        batch_count += 1
        print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")

        batch_results = []
        for i, expression in enumerate(expressions, 1):
            manual_result = manual_evaluate_expression(expression, i, len(expressions))

            if manual_result is None:
                # User quit: persist whatever this batch produced so far.
                print("\n评估已中断")
                if batch_results:
                    # Save the results of the current batch.
                    manual_results.extend(batch_results)
                    save_results(manual_results)
                return

            if manual_result == "skip":
                # Skipped items stay unevaluated and may reappear later.
                continue

            batch_results.append(manual_result)
            # (situation, style) serves as the item's unique identity.
            evaluated_pairs.add((manual_result["situation"], manual_result["style"]))

        # Merge the batch into the full result list.
        manual_results.extend(batch_results)

        # Persist after every batch so progress survives interruptions.
        save_results(manual_results)

        print(f"\n当前批次完成,已评估总数: {len(manual_results)} 条")

        # Ask whether to continue with the next batch.
        while True:
            continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
            if continue_input in ["y", "yes", "1", "是", "继续"]:
                break
            elif continue_input in ["n", "no", "0", "否", "退出"]:
                print("\n评估结束")
                return
            else:
                print("输入无效,请重新输入 (y/n)")

    # Close the database connection.
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the interactive evaluation loop.
    main()
|
||||
81
scripts/i18n_extract_candidates.py
Normal file
81
scripts/i18n_extract_candidates.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import ast
|
||||
import re
|
||||
|
||||
# Repository root (this script lives in <root>/scripts/).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Directory names skipped during the source scan (VCS, venv, frontend, docs, translations).
DEFAULT_EXCLUDE_PARTS = {
    ".git",
    ".venv",
    "dashboard",
    "docs",
    "docs-src",
    "locales",
}
# Matches any character in the CJK Unified Ideographs basic block.
HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]")
||||
|
||||
|
||||
def should_skip(path: Path) -> bool:
    """Return True when *path* lies under any excluded directory."""
    for part in path.parts:
        if part in DEFAULT_EXCLUDE_PARTS:
            return True
    return False
|
||||
|
||||
|
||||
def iter_python_files(root: Path) -> list[Path]:
    """Collect all non-excluded ``*.py`` files under *root*, sorted."""
    found = [
        candidate
        for candidate in root.rglob("*.py")
        if candidate.is_file() and not should_skip(candidate.relative_to(root))
    ]
    return sorted(found)
|
||||
|
||||
|
||||
class CandidateExtractor(ast.NodeVisitor):
    """AST visitor that collects string constants containing Han characters.

    Docstrings are excluded; results accumulate in ``self.candidates``
    as ``(lineno, stripped text)`` tuples.
    """

    def __init__(self) -> None:
        self._docstring_nodes: set[ast.AST] = set()
        self.candidates: list[tuple[int, str]] = []

    def visit_Module(self, node: ast.Module) -> None:
        self._mark_docstring_node(node)
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        self._mark_docstring_node(node)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        self._mark_docstring_node(node)
        self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self._mark_docstring_node(node)
        self.generic_visit(node)

    def visit_Constant(self, node: ast.Constant) -> None:
        # Record only non-docstring string constants that contain Han text.
        if node not in self._docstring_nodes:
            text = node.value
            if isinstance(text, str) and HAN_PATTERN.search(text):
                self.candidates.append((node.lineno, text.strip()))
        self.generic_visit(node)

    def _mark_docstring_node(self, node: ast.Module | ast.ClassDef | ast.AsyncFunctionDef | ast.FunctionDef) -> None:
        # The docstring, if present, is the first body statement: an
        # Expr wrapping a string Constant.
        if not node.body:
            return
        head = node.body[0]
        if isinstance(head, ast.Expr) and isinstance(head.value, ast.Constant):
            if isinstance(head.value.value, str):
                self._docstring_nodes.add(head.value)
|
||||
|
||||
|
||||
def extract_candidates(file_path: Path) -> list[tuple[int, str]]:
    """Parse *file_path* and return (lineno, text) pairs of Han-containing strings."""
    code = file_path.read_text(encoding="utf-8")
    visitor = CandidateExtractor()
    visitor.visit(ast.parse(code, filename=str(file_path)))
    return visitor.candidates
|
||||
|
||||
|
||||
def main() -> int:
    """Print every i18n candidate as ``path:lineno: text``; return 0."""
    for source_file in iter_python_files(PROJECT_ROOT):
        relative = source_file.relative_to(PROJECT_ROOT)
        for lineno, text in extract_candidates(source_file):
            print(f"{relative}:{lineno}: {text}")
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point; main() returns the process exit code.
    raise SystemExit(main())
|
||||
411
scripts/i18n_validate.py
Normal file
411
scripts/i18n_validate.py
Normal file
@@ -0,0 +1,411 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.common.i18n.exceptions import ( # noqa: E402
|
||||
DuplicateTranslationKeyError,
|
||||
InvalidTranslationFileError,
|
||||
LocaleNotFoundError,
|
||||
)
|
||||
from src.common.i18n.loaders import ( # noqa: E402
|
||||
DEFAULT_LOCALE,
|
||||
PLURAL_CATEGORIES,
|
||||
TranslationValue,
|
||||
discover_locales,
|
||||
get_locales_root,
|
||||
load_locale_catalog,
|
||||
validate_translation_value,
|
||||
)
|
||||
from src.common.i18n.loaders import extract_placeholders # noqa: E402
|
||||
from src.common.prompt_i18n import ( # noqa: E402
|
||||
discover_prompt_locales,
|
||||
extract_prompt_placeholders,
|
||||
get_prompts_root,
|
||||
iter_prompt_files,
|
||||
)
|
||||
|
||||
# Han characters: Extension A, the basic block, and Compatibility Ideographs.
HAN_CHARACTER_PATTERN = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")
# i18next interpolation: {{name}} or {{name, format}}; group 1 is the name.
I18NEXT_PLACEHOLDER_PATTERN = re.compile(r"\{\{\s*([^\s,}]+)(?:\s*,[^}]*)?\s*\}\}")
# The dashboard's source-of-truth locale (single JSON file per locale).
DASHBOARD_DEFAULT_LOCALE = "zh"
|
||||
|
||||
|
||||
def contains_han_characters(text: str) -> bool:
    """Return True when *text* contains at least one Han character."""
    return bool(HAN_CHARACTER_PATTERN.search(text))
|
||||
|
||||
|
||||
def extract_i18next_placeholders(template: str) -> set[str]:
    """Extract the root placeholder names used in an i18next template.

    Dotted paths and index suffixes are trimmed to the first segment,
    e.g. ``{{user.name, format}}`` yields ``user``.
    """
    names: set[str] = set()
    for found in I18NEXT_PLACEHOLDER_PATTERN.finditer(template):
        raw = found.group(1)
        root = raw.split(".", maxsplit=1)[0].split("[", maxsplit=1)[0]
        names.add(root)
    return names
|
||||
|
||||
|
||||
def iter_translation_strings(value: TranslationValue) -> list[str]:
    """Flatten a translation value into a list of strings.

    A plain string yields itself; a plural mapping yields its values
    in category-sorted order.
    """
    if isinstance(value, str):
        return [value]
    return [value[category] for category in sorted(value)]
|
||||
|
||||
|
||||
def iter_shared_translation_strings(
    source_value: TranslationValue, target_value: TranslationValue
) -> list[tuple[str, str]]:
    """Pair up comparable strings from a source/target translation value.

    Two plain strings pair directly; two plural mappings pair on their
    shared categories (sorted); a string/mapping mismatch yields nothing.
    """
    source_is_str = isinstance(source_value, str)
    target_is_str = isinstance(target_value, str)
    if source_is_str and target_is_str:
        return [(source_value, target_value)]
    if source_is_str or target_is_str:
        return []

    common = set(source_value) & set(target_value)
    return [(source_value[category], target_value[category]) for category in sorted(common)]
|
||||
|
||||
|
||||
def locale_requires_latin_only_validation(locale: str) -> bool:
    """Return True for English locales ("en" or any "en-*"), case-insensitively."""
    lowered = locale.lower()
    return lowered == "en" or lowered.startswith("en-")
|
||||
|
||||
|
||||
def validate_locale_content(
    key: str,
    source_value: TranslationValue,
    target_value: TranslationValue,
    locale: str,
    errors: list[str],
    locale_label: str | None = None,
) -> None:
    """Append content-level violations for one translation key to *errors*.

    Flags (1) target text that keeps Han-containing source text verbatim
    and (2) Han characters remaining in Latin-only (English) locales.
    """
    label = locale_label or locale

    # (1) Did the target keep any Han-containing source string verbatim?
    kept_source_verbatim = False
    for src_text, tgt_text in iter_shared_translation_strings(source_value, target_value):
        if src_text == tgt_text and contains_han_characters(src_text):
            kept_source_verbatim = True
            break
    if kept_source_verbatim:
        errors.append(
            f"[{label}] key '{key}' 直接保留了包含中文字符的 source 文案(仓库级校验策略),请提供目标语言翻译"
        )

    # (2) English locales must be free of Han characters entirely.
    if locale_requires_latin_only_validation(locale):
        if any(contains_han_characters(text) for text in iter_translation_strings(target_value)):
            errors.append(f"[{label}] key '{key}' 仍包含中文字符,请移除源语言残留后再提交")
|
||||
|
||||
|
||||
def validate_translation_pair(
    key: str,
    source_value: TranslationValue,
    target_value: TranslationValue,
    locale: str,
    errors: list[str],
    placeholder_extractor: Callable[[str], set[str]] = extract_placeholders,
    locale_label: str | None = None,
) -> None:
    """Append structural violations for one translation key to *errors*.

    Checks that source and target have the same shape (string vs plural
    mapping), matching plural categories, and identical placeholder sets.

    Args:
        key: translation key being checked.
        source_value: value from the default (source) locale.
        target_value: value from the locale under validation.
        locale: locale code of the target catalog.
        errors: mutable list error messages are appended to.
        placeholder_extractor: callable returning the placeholder name set
            of one template string.
        locale_label: label used in messages; falls back to *locale*.
    """
    resolved_locale_label = locale_label or locale
    if isinstance(source_value, str):
        # String source: target must also be a string with the same
        # placeholder set; plural checks below do not apply.
        if not isinstance(target_value, str):
            errors.append(
                f"[{resolved_locale_label}] key '{key}' 与 source 的类型不一致:source=string, target=plural"
            )
            return
        if placeholder_extractor(source_value) != placeholder_extractor(target_value):
            errors.append(f"[{resolved_locale_label}] key '{key}' 的占位符集合与 source 不一致")
        return

    if not isinstance(target_value, dict):
        errors.append(f"[{resolved_locale_label}] key '{key}' 与 source 的类型不一致:source=plural, target=string")
        return

    # Plural source: categories must match exactly...
    source_categories = set(source_value.keys())
    target_categories = set(target_value.keys())
    if source_categories != target_categories:
        errors.append(
            f"[{resolved_locale_label}] key '{key}' 的 plural category 不一致:"
            f"source={sorted(source_categories)}, target={sorted(target_categories)}"
        )

    # ...and placeholders must agree for every shared category.
    for category in sorted(source_categories & target_categories):
        source_placeholders = placeholder_extractor(source_value[category])
        target_placeholders = placeholder_extractor(target_value[category])
        if source_placeholders != target_placeholders:
            errors.append(
                f"[{resolved_locale_label}] key '{key}' 的 plural category '{category}' 占位符集合与 source 不一致"
            )
|
||||
|
||||
|
||||
def get_dashboard_locales_root(locales_root: Path | None = None) -> Path:
    """Resolve the dashboard locales directory.

    Defaults to ``<project>/dashboard/src/i18n/locales`` when no
    explicit root is provided.
    """
    if locales_root is None:
        return (PROJECT_ROOT / "dashboard" / "src" / "i18n" / "locales").resolve()
    return locales_root.resolve()
|
||||
|
||||
|
||||
def discover_dashboard_locales(locales_root: Path | None = None) -> list[str]:
    """List locale names (JSON file stems) available for the dashboard."""
    root = get_dashboard_locales_root(locales_root)
    if not root.exists():
        return []

    return sorted(entry.stem for entry in root.glob("*.json") if entry.is_file())
|
||||
|
||||
|
||||
def is_plural_translation_node(value: object) -> bool:
    """Return True when *value* is a non-empty mapping of plural categories to strings."""
    if not isinstance(value, dict) or not value:
        return False

    for category, category_value in value.items():
        if not isinstance(category, str) or category not in PLURAL_CATEGORIES:
            return False
        if not isinstance(category_value, str):
            return False
    return True
|
||||
|
||||
|
||||
def flatten_dashboard_translation_mapping(
    value: dict[str, object],
    file_path: Path,
    translations: dict[str, TranslationValue],
    parent_keys: list[str] | None = None,
) -> None:
    """Recursively flatten a nested dashboard locale object into dot-joined keys.

    Leaf strings and plural mappings are stored in *translations*;
    nested objects recurse with their key appended to *parent_keys*.

    Args:
        value: the (possibly nested) JSON object to flatten.
        file_path: source file, used in error messages.
        translations: output mapping of flattened key -> translation value.
        parent_keys: key path accumulated so far (None at the top level).

    Raises:
        InvalidTranslationFileError: empty objects, non-string or empty
            keys, or leaf values that are neither strings nor valid nodes.
        DuplicateTranslationKeyError: the same flattened key appears twice.
    """
    current_parent_keys = parent_keys or []
    if not value:
        if current_parent_keys:
            raise InvalidTranslationFileError(
                f"{file_path} 中的 key '{'.'.join(current_parent_keys)}' 不能为空对象"
            )
        raise InvalidTranslationFileError(f"{file_path} 顶层不能为空对象")

    for raw_key, raw_value in value.items():
        if not isinstance(raw_key, str):
            raise InvalidTranslationFileError(f"{file_path} 中存在非字符串 key")

        normalized_key = raw_key.strip()
        if not normalized_key:
            raise InvalidTranslationFileError(f"{file_path} 中存在空字符串 key")

        current_key_parts = [*current_parent_keys, normalized_key]
        current_key = ".".join(current_key_parts)

        # Plain string leaf.
        if isinstance(raw_value, str):
            if current_key in translations:
                raise DuplicateTranslationKeyError(f"{file_path} 中存在重复 key: '{current_key}'")
            translations[current_key] = raw_value
            continue

        # Plural-mapping leaf (must be checked before the generic dict branch).
        if is_plural_translation_node(raw_value):
            if current_key in translations:
                raise DuplicateTranslationKeyError(f"{file_path} 中存在重复 key: '{current_key}'")
            translations[current_key] = validate_translation_value(current_key, raw_value, file_path)
            continue

        # Nested object: recurse with the extended key path.
        if isinstance(raw_value, dict):
            flatten_dashboard_translation_mapping(raw_value, file_path, translations, current_key_parts)
            continue

        raise InvalidTranslationFileError(f"{file_path} 中的 key '{current_key}' 必须是字符串或对象")
|
||||
|
||||
|
||||
def load_dashboard_translation_file(file_path: Path) -> dict[str, TranslationValue]:
    """Parse one dashboard locale JSON file into a flat key -> value catalog.

    Raises:
        InvalidTranslationFileError: malformed JSON or a non-object top level.
    """
    try:
        payload = json.loads(file_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise InvalidTranslationFileError(f"{file_path} 不是合法 JSON: {exc}") from exc

    if not isinstance(payload, dict):
        raise InvalidTranslationFileError(f"{file_path} 顶层必须是 JSON object")

    catalog: dict[str, TranslationValue] = {}
    flatten_dashboard_translation_mapping(payload, file_path, catalog)
    return catalog
|
||||
|
||||
|
||||
def load_dashboard_locale_catalog(
    locale: str,
    locales_root: Path | None = None,
) -> dict[str, TranslationValue]:
    """Load the flattened catalog for *locale* from the dashboard locales dir.

    Raises:
        LocaleNotFoundError: no ``<locale>.json`` file exists.
    """
    catalog_path = get_dashboard_locales_root(locales_root) / f"{locale}.json"
    if not catalog_path.exists():
        raise LocaleNotFoundError(f"未找到 locale 文件: {catalog_path}")

    return load_dashboard_translation_file(catalog_path)
|
||||
|
||||
|
||||
def validate_dashboard_json_locales(locales_root: Path | None = None) -> list[str]:
    """Validate all dashboard locale JSON catalogs against the default locale.

    Checks key parity, placeholder parity (i18next syntax) and content
    rules, returning a list of human-readable error messages (empty when
    everything is valid).
    """
    resolved_locales_root = get_dashboard_locales_root(locales_root)
    locales = discover_dashboard_locales(resolved_locales_root)
    errors: list[str] = []

    # Without the source-of-truth catalog nothing else can be compared.
    if DASHBOARD_DEFAULT_LOCALE not in locales:
        errors.append(f"[dashboard] 缺少默认 locale 文件: {DASHBOARD_DEFAULT_LOCALE}.json")
        return errors

    catalogs: dict[str, dict[str, TranslationValue]] = {}
    for locale in locales:
        try:
            catalogs[locale] = load_dashboard_locale_catalog(locale, resolved_locales_root)
        except Exception as exc:
            # Record load failures per locale but keep validating the rest.
            errors.append(f"[dashboard:{locale}] 加载失败: {exc}")

    source_catalog = catalogs.get(DASHBOARD_DEFAULT_LOCALE)
    if source_catalog is None:
        return errors

    source_keys = set(source_catalog.keys())
    for locale, catalog in catalogs.items():
        if locale == DASHBOARD_DEFAULT_LOCALE:
            continue

        locale_label = f"dashboard:{locale}"
        locale_keys = set(catalog.keys())
        # Key parity: missing and extra keys relative to the default catalog.
        for key in sorted(source_keys - locale_keys):
            errors.append(f"[{locale_label}] 缺少 key: {key}")
        for key in sorted(locale_keys - source_keys):
            errors.append(f"[{locale_label}] 存在多余 key: {key}")

        for key in sorted(source_keys & locale_keys):
            source_value = source_catalog[key]
            target_value = catalog[key]
            validate_translation_pair(
                key,
                source_value,
                target_value,
                locale,
                errors,
                placeholder_extractor=extract_i18next_placeholders,
                locale_label=locale_label,
            )
            # Content checks only make sense when both sides have the same shape.
            if isinstance(source_value, str) == isinstance(target_value, str):
                validate_locale_content(key, source_value, target_value, locale, errors, locale_label=locale_label)

    return errors
|
||||
|
||||
|
||||
def validate_json_locales(locales_root: Path | None = None) -> list[str]:
    """Validate all backend locale catalogs against the default locale.

    Checks key parity, placeholder parity and content rules, returning a
    list of human-readable error messages (empty when everything is valid).
    """
    resolved_locales_root = get_locales_root(locales_root)
    locales = discover_locales(resolved_locales_root)
    errors: list[str] = []

    # Without the source-of-truth catalog nothing else can be compared.
    if DEFAULT_LOCALE not in locales:
        errors.append(f"缺少默认 locale 目录: {DEFAULT_LOCALE}")
        return errors

    catalogs: dict[str, dict[str, TranslationValue]] = {}
    for locale in locales:
        try:
            catalogs[locale] = load_locale_catalog(locale, resolved_locales_root)
        except Exception as exc:
            # Record load failures per locale but keep validating the rest.
            errors.append(f"[{locale}] 加载失败: {exc}")

    source_catalog = catalogs.get(DEFAULT_LOCALE)
    if source_catalog is None:
        return errors

    source_keys = set(source_catalog.keys())
    for locale, catalog in catalogs.items():
        if locale == DEFAULT_LOCALE:
            continue

        locale_keys = set(catalog.keys())
        # Key parity: missing and extra keys relative to the default catalog.
        for key in sorted(source_keys - locale_keys):
            errors.append(f"[{locale}] 缺少 key: {key}")
        for key in sorted(locale_keys - source_keys):
            errors.append(f"[{locale}] 存在多余 key: {key}")

        for key in sorted(source_keys & locale_keys):
            source_value = source_catalog[key]
            target_value = catalog[key]
            validate_translation_pair(key, source_value, target_value, locale, errors)
            # Content checks only make sense when both sides have the same shape.
            if isinstance(source_value, str) == isinstance(target_value, str):
                validate_locale_content(key, source_value, target_value, locale, errors)

    return errors
|
||||
|
||||
|
||||
def build_prompt_catalog(locale_dir: Path) -> dict[Path, Path]:
    """Map each prompt file's path relative to *locale_dir* to its full path."""
    catalog: dict[Path, Path] = {}
    for prompt_path in iter_prompt_files(locale_dir):
        catalog[prompt_path.relative_to(locale_dir)] = prompt_path
    return catalog
|
||||
|
||||
|
||||
def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[str], list[str]]:
    """Validate prompt template files for every known locale.

    Returns an ``(errors, warnings)`` pair.  Missing locale directories and
    missing/extra templates are warnings (the runtime falls back to the
    default locale); placeholder mismatches against the source templates are
    errors because formatting would break at runtime.
    """
    resolved_prompts_root = get_prompts_root(prompts_root)
    prompt_locales = set(discover_prompt_locales(resolved_prompts_root))
    # Target locales come from the JSON locale directories, minus the default.
    known_locales = [locale for locale in discover_locales(get_locales_root()) if locale != DEFAULT_LOCALE]
    errors: list[str] = []
    warnings: list[str] = []

    # Without the default-locale templates there is nothing to compare against.
    if DEFAULT_LOCALE not in prompt_locales:
        errors.append(f"缺少默认 Prompt locale 目录: {DEFAULT_LOCALE}")
        return errors, warnings

    source_dir = resolved_prompts_root / DEFAULT_LOCALE
    source_files = build_prompt_catalog(source_dir)
    source_relative_paths = set(source_files.keys())

    for locale in known_locales:
        locale_dir = resolved_prompts_root / locale
        if not locale_dir.exists():
            warnings.append(f"[prompt:{locale}] 缺少 locale 目录,运行时将回退到 {DEFAULT_LOCALE}")
            continue

        locale_files = build_prompt_catalog(locale_dir)
        locale_relative_paths = set(locale_files.keys())

        # Coverage differences are warnings, not errors (runtime falls back).
        for relative_path in sorted(source_relative_paths - locale_relative_paths):
            warnings.append(f"[prompt:{locale}] 缺少模板: {relative_path.as_posix()},运行时将回退到 {DEFAULT_LOCALE}")

        for relative_path in sorted(locale_relative_paths - source_relative_paths):
            warnings.append(f"[prompt:{locale}] 存在额外模板: {relative_path.as_posix()}")

        for relative_path in sorted(source_relative_paths & locale_relative_paths):
            source_text = source_files[relative_path].read_text(encoding="utf-8")
            locale_text = locale_files[relative_path].read_text(encoding="utf-8")

            # Placeholder sets must match exactly; a mismatch means runtime
            # formatting would raise or silently drop values.
            source_placeholders = extract_prompt_placeholders(source_text)
            locale_placeholders = extract_prompt_placeholders(locale_text)
            if source_placeholders != locale_placeholders:
                errors.append(
                    "[prompt:{locale}] 模板 '{path}' 的占位符集合与 source 不一致:"
                    "source={source_placeholders}, target={target_placeholders}".format(
                        locale=locale,
                        path=relative_path.as_posix(),
                        source_placeholders=sorted(source_placeholders),
                        target_placeholders=sorted(locale_placeholders),
                    )
                )

            # An identical file usually means the template was copied but
            # never actually translated.
            if source_text == locale_text:
                warnings.append(f"[prompt:{locale}] 模板 '{relative_path.as_posix()}' 与 source 完全相同,可能尚未翻译")

    return errors, warnings
|
||||
|
||||
|
||||
def _print_warnings(warnings: list[str]) -> None:
|
||||
if not warnings:
|
||||
return
|
||||
print(f"warnings ({len(warnings)}):")
|
||||
for warning in warnings[:10]:
|
||||
print(f" - {warning}")
|
||||
if len(warnings) > 10:
|
||||
print(f" - ... 另外还有 {len(warnings) - 10} 条 warning")
|
||||
|
||||
|
||||
def main() -> int:
    """Run every i18n validation pass and return a process exit code.

    Errors (JSON locales, dashboard locales, prompt templates) make the run
    fail with exit code 1; prompt warnings are printed either way.
    """
    errors = validate_json_locales()
    errors.extend(validate_dashboard_json_locales())
    prompt_errors, prompt_warnings = validate_prompt_templates()
    errors.extend(prompt_errors)

    if not errors:
        print("i18n validation passed.")
        _print_warnings(prompt_warnings)
        return 0

    print("i18n validation failed:")
    for error in errors:
        print(f"  - {error}")
    _print_warnings(prompt_warnings)
    return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell.
    raise SystemExit(main())
|
||||
40
scripts/make_scripts/generate_requirements.py
Normal file
40
scripts/make_scripts/generate_requirements.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import tomlkit
|
||||
|
||||
|
||||
def generate_requirements(pyproject_path="pyproject.toml", output_path="requirements.txt"):
    """Regenerate *output_path* from ``[project].dependencies`` in *pyproject_path*.

    Warns about entries that exist only in the current requirements.txt
    (stale/manually-added dependencies), then overwrites the file with the
    pyproject dependency list.  All failures are reported to stdout; the
    function never raises.
    """
    try:
        # Read pyproject.toml.
        with open(pyproject_path, "r", encoding="utf-8") as file:
            pyproject_data = tomlkit.load(file)

        # Fetch the [project].dependencies list.
        pyproject_dependencies = pyproject_data.get("project", {}).get("dependencies", [])
        if not pyproject_dependencies:
            print("未找到 dependencies 部分,无法生成 requirements.txt")
            return

        # Read the existing requirements.txt (if any) so we can flag entries
        # that would be dropped by the rewrite.
        try:
            with open(output_path, "r", encoding="utf-8") as file:
                requirements = {line.strip() for line in file if line.strip()}
        except FileNotFoundError:
            requirements = set()

        if extra_dependencies := requirements - set(pyproject_dependencies):
            print("警告: 以下依赖项存在于 requirements.txt 中,但未在 pyproject.toml 中找到:")
            # Sorted for deterministic output (set iteration order varies).
            for dep in sorted(extra_dependencies):
                print(f"  - {dep}")

        # Write the regenerated requirements.txt.  End with a newline so the
        # file is a well-formed POSIX text file (and diff/cat behave nicely).
        with open(output_path, "w", encoding="utf-8") as file:
            file.write("\n".join(pyproject_dependencies) + "\n")

        print(f"requirements.txt 文件已生成: {output_path}")
    except FileNotFoundError:
        print(f"未找到 {pyproject_path} 文件,请检查路径是否正确。")
    except Exception as e:
        print(f"发生错误: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Regenerate requirements.txt with the default paths.
    generate_requirements()
|
||||
1132
scripts/mmipkg_tool.py
Normal file
1132
scripts/mmipkg_tool.py
Normal file
File diff suppressed because it is too large
Load Diff
2532
scripts/preview_reply_effect_scores.py
Normal file
2532
scripts/preview_reply_effect_scores.py
Normal file
File diff suppressed because it is too large
Load Diff
973
scripts/run.sh
Normal file
973
scripts/run.sh
Normal file
@@ -0,0 +1,973 @@
|
||||
#!/bin/bash

# One-click installer for MaiCore & the NapCat Adapter, by Cookie_987.
# Supports macOS / Arch / Ubuntu 24.10 / Debian 12 / CentOS 9.
# Use any one-click script with care!

INSTALLER_VERSION="0.0.5-refactor"
LANG=C.UTF-8

# GitHub mirror prefix; edit this if github.com is unreachable.
GITHUB_REPO="https://ghfast.top/https://github.com"

# ANSI color codes for terminal output.
GREEN="\e[32m"
RED="\e[31m"
RESET="\e[0m"

# Required base packages (plain strings instead of associative arrays to
# stay compatible with Bash 3, e.g. macOS's system bash).
REQUIRED_PACKAGES_COMMON="git sudo python3 curl gnupg"
REQUIRED_PACKAGES_DEBIAN="python3-venv python3-pip build-essential"
REQUIRED_PACKAGES_UBUNTU="python3-venv python3-pip build-essential"
REQUIRED_PACKAGES_CENTOS="epel-release python3-pip python3-devel gcc gcc-c++ make"
REQUIRED_PACKAGES_ARCH="python-virtualenv python-pip base-devel"
REQUIRED_PACKAGES_MACOS="git gnupg python"

# Service names (systemd units on Linux, launchd label suffixes on macOS).
SERVICE_NAME="maicore"
SERVICE_NAME_WEB="maicore-web"
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"

# Run services as the invoking user (not root) when the script is sudo'ed.
SERVICE_USER="${SUDO_USER:-$USER}"
SERVICE_HOME="$(eval echo "~${SERVICE_USER}" 2>/dev/null)"
# Fall back to $HOME when tilde expansion failed (unknown user).
if [[ -z "$SERVICE_HOME" || "$SERVICE_HOME" == "~${SERVICE_USER}" ]]; then
    SERVICE_HOME="$HOME"
fi

IS_MACOS=false
[[ "$(uname -s)" == "Darwin" ]] && IS_MACOS=true

INSTALL_CONF="/etc/maicore_install.conf"

# Default project directory (per-user paths on macOS, /opt on Linux).
DEFAULT_INSTALL_DIR="/opt/maicore"
if [[ "$IS_MACOS" == true ]]; then
    DEFAULT_INSTALL_DIR="${SERVICE_HOME}/maicore"
    INSTALL_CONF="${SERVICE_HOME}/.config/maicore/maicore_install.conf"
fi

# launchd settings; only populated on macOS below.
LAUNCHD_DOMAIN=""
LAUNCHD_AGENT_DIR=""
LAUNCHD_LABEL_MAIN="com.maicore.${SERVICE_NAME}"
LAUNCHD_LABEL_NBADAPTER="com.maicore.${SERVICE_NAME_NBADAPTER}"
LAUNCHD_PLIST_MAIN=""
LAUNCHD_PLIST_NBADAPTER=""

if [[ "$IS_MACOS" == true ]]; then
    SERVICE_UID="$(id -u "${SERVICE_USER}" 2>/dev/null || id -u)"
    LAUNCHD_DOMAIN="gui/${SERVICE_UID}"
    LAUNCHD_AGENT_DIR="${SERVICE_HOME}/Library/LaunchAgents"
    LAUNCHD_PLIST_MAIN="${LAUNCHD_AGENT_DIR}/${LAUNCHD_LABEL_MAIN}.plist"
    LAUNCHD_PLIST_NBADAPTER="${LAUNCHD_AGENT_DIR}/${LAUNCHD_LABEL_NBADAPTER}.plist"
fi
|
||||
|
||||
get_required_packages() {
    # Echo the space-separated package list required for the given distro id.
    local distro="$1"
    case "$distro" in
        debian)
            echo "${REQUIRED_PACKAGES_COMMON} ${REQUIRED_PACKAGES_DEBIAN}"
            ;;
        ubuntu)
            echo "${REQUIRED_PACKAGES_COMMON} ${REQUIRED_PACKAGES_UBUNTU}"
            ;;
        centos)
            echo "${REQUIRED_PACKAGES_COMMON} ${REQUIRED_PACKAGES_CENTOS}"
            ;;
        arch)
            echo "${REQUIRED_PACKAGES_COMMON} ${REQUIRED_PACKAGES_ARCH}"
            ;;
        macos)
            # macOS uses Homebrew package names and ships curl/sudo already.
            echo "${REQUIRED_PACKAGES_MACOS}"
            ;;
        *)
            echo "${REQUIRED_PACKAGES_COMMON}"
            ;;
    esac
}
|
||||
|
||||
# Flags recording the user's choices during the install wizard.
IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false
|
||||
|
||||
resolve_brew_bin() {
    # Locate the Homebrew executable: PATH first, then the standard
    # Apple-Silicon and Intel prefixes. Prints nothing (and fails) if absent.
    local candidate
    candidate="$(command -v brew)"
    if [[ -z "$candidate" && -x /opt/homebrew/bin/brew ]]; then
        candidate="/opt/homebrew/bin/brew"
    fi
    if [[ -z "$candidate" && -x /usr/local/bin/brew ]]; then
        candidate="/usr/local/bin/brew"
    fi
    [[ -n "$candidate" ]] && echo "$candidate"
}
|
||||
|
||||
run_brew() {
    # Run Homebrew, dropping back to the invoking user when the script runs
    # under sudo (brew refuses to run as root).
    local brew_bin
    brew_bin="$(resolve_brew_bin)"

    [[ -z "$brew_bin" ]] && return 1
    if [[ "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" && "${SUDO_USER}" != "root" ]]; then
        sudo -u "${SUDO_USER}" "${brew_bin}" "$@"
    else
        "${brew_bin}" "$@"
    fi
}
|
||||
|
||||
run_launchctl() {
    # Run launchctl as the invoking user when under sudo, so jobs land in
    # that user's gui domain rather than root's.
    if [[ "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" && "${SUDO_USER}" != "root" ]]; then
        sudo -u "${SUDO_USER}" launchctl "$@"
    else
        launchctl "$@"
    fi
}
|
||||
|
||||
ensure_writable_parent() {
    # Ensure $1's parent directory exists and, on sudo'ed macOS runs, is
    # owned by the invoking user so the config stays writable without root.
    local path="$1"
    local parent
    parent="$(dirname "$path")"
    mkdir -p "$parent"
    if [[ "$IS_MACOS" == true && "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" ]]; then
        chown "${SUDO_USER}" "$parent" 2>/dev/null || true
    fi
}
|
||||
|
||||
save_install_info() {
    # Persist installer state (version, dir, branch) so later runs can load
    # it and go straight to the management menu.
    ensure_writable_parent "$INSTALL_CONF"
    cat > "$INSTALL_CONF" <<EOF
INSTALLER_VERSION=${INSTALLER_VERSION}
INSTALL_DIR=${INSTALL_DIR}
BRANCH=${BRANCH}
EOF
}
|
||||
|
||||
compute_md5() {
    # Print the MD5 digest of $1 using whichever tool is available
    # (md5sum on Linux, md5 on macOS); fail when neither exists.
    local target="$1"

    if command -v md5sum &>/dev/null; then
        md5sum "$target" | awk '{print $1}'
        return
    fi
    if command -v md5 &>/dev/null; then
        md5 -q "$target"
        return
    fi
    return 1
}
|
||||
|
||||
launchd_label_for_service() {
    # Map a service name to its launchd label; fail for unknown services.
    local service="$1"
    case "$service" in
        ${SERVICE_NAME})
            echo "$LAUNCHD_LABEL_MAIN"
            ;;
        ${SERVICE_NAME_NBADAPTER})
            echo "$LAUNCHD_LABEL_NBADAPTER"
            ;;
        *)
            return 1
            ;;
    esac
}
|
||||
|
||||
launchd_plist_for_service() {
    # Map a service name to its LaunchAgent plist path; fail for unknown ones.
    local service="$1"
    case "$service" in
        ${SERVICE_NAME})
            echo "$LAUNCHD_PLIST_MAIN"
            ;;
        ${SERVICE_NAME_NBADAPTER})
            echo "$LAUNCHD_PLIST_NBADAPTER"
            ;;
        *)
            return 1
            ;;
    esac
}
|
||||
|
||||
is_launchd_service_loaded() {
    # Succeed iff the service's launchd job is currently bootstrapped.
    local service="$1"
    local label
    label="$(launchd_label_for_service "$service")" || return 1
    run_launchctl print "${LAUNCHD_DOMAIN}/${label}" &>/dev/null
}
|
||||
|
||||
start_service() {
    # Start (or restart) a service: launchd on macOS, systemd elsewhere.
    local service="$1"
    if [[ "$IS_MACOS" == true ]]; then
        local label
        local plist
        label="$(launchd_label_for_service "$service")" || return 1
        plist="$(launchd_plist_for_service "$service")" || return 1
        # Regenerate missing plists when MaiBot is actually installed.
        if [[ ! -f "$plist" && -d "${INSTALL_DIR}/MaiBot" ]]; then
            create_launchd_services
        fi
        if [[ ! -f "$plist" ]]; then
            echo -e "${RED}未找到服务配置文件:${plist}${RESET}"
            return 1
        fi

        if is_launchd_service_loaded "$service"; then
            # Already loaded: kickstart -k (re)starts the running job.
            run_launchctl kickstart -k "${LAUNCHD_DOMAIN}/${label}"
        else
            run_launchctl bootstrap "${LAUNCHD_DOMAIN}" "$plist"
        fi
    else
        systemctl start "$service"
    fi
}
|
||||
|
||||
stop_service() {
    # Stop a service: boot it out of launchd on macOS, systemctl elsewhere.
    local service="$1"
    if [[ "$IS_MACOS" == true ]]; then
        local label
        label="$(launchd_label_for_service "$service")" || return 1
        if is_launchd_service_loaded "$service"; then
            run_launchctl bootout "${LAUNCHD_DOMAIN}/${label}"
        fi
    else
        systemctl stop "$service"
    fi
}
|
||||
|
||||
restart_service() {
    # Restart a service: stop+start via launchd helpers on macOS,
    # plain `systemctl restart` elsewhere.
    local svc="$1"
    if [[ "$IS_MACOS" != true ]]; then
        systemctl restart "$svc"
        return
    fi
    stop_service "$svc"
    start_service "$svc"
}
|
||||
|
||||
create_launchd_services() {
    # Write LaunchAgent plists for MaiCore and the NapCat adapter.
    # KeepAlive=true makes launchd restart the processes when they exit;
    # stdout/stderr are redirected into ${INSTALL_DIR}/logs.
    mkdir -p "${LAUNCHD_AGENT_DIR}"
    mkdir -p "${INSTALL_DIR}/logs"

    cat > "${LAUNCHD_PLIST_MAIN}" <<EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>${LAUNCHD_LABEL_MAIN}</string>
    <key>ProgramArguments</key>
    <array>
        <string>${INSTALL_DIR}/venv/bin/python3</string>
        <string>bot.py</string>
    </array>
    <key>WorkingDirectory</key>
    <string>${INSTALL_DIR}/MaiBot</string>
    <key>RunAtLoad</key>
    <true/>
    <key>KeepAlive</key>
    <true/>
    <key>StandardOutPath</key>
    <string>${INSTALL_DIR}/logs/${SERVICE_NAME}.log</string>
    <key>StandardErrorPath</key>
    <string>${INSTALL_DIR}/logs/${SERVICE_NAME}.error.log</string>
</dict>
</plist>
EOF

    cat > "${LAUNCHD_PLIST_NBADAPTER}" <<EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>${LAUNCHD_LABEL_NBADAPTER}</string>
    <key>ProgramArguments</key>
    <array>
        <string>${INSTALL_DIR}/venv/bin/python3</string>
        <string>main.py</string>
    </array>
    <key>WorkingDirectory</key>
    <string>${INSTALL_DIR}/MaiBot-Napcat-Adapter</string>
    <key>RunAtLoad</key>
    <true/>
    <key>KeepAlive</key>
    <true/>
    <key>StandardOutPath</key>
    <string>${INSTALL_DIR}/logs/${SERVICE_NAME_NBADAPTER}.log</string>
    <key>StandardErrorPath</key>
    <string>${INSTALL_DIR}/logs/${SERVICE_NAME_NBADAPTER}.error.log</string>
</dict>
</plist>
EOF

    # When run under sudo, hand the plists back to the real user so the
    # gui/<uid> launchd domain can load them.
    if [[ "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" && "${SUDO_USER}" != "root" ]]; then
        chown "${SUDO_USER}" "${LAUNCHD_PLIST_MAIN}" "${LAUNCHD_PLIST_NBADAPTER}" "${LAUNCHD_AGENT_DIR}" 2>/dev/null || true
    fi
}
|
||||
|
||||
# 检查是否已安装
|
||||
check_installed() {
    # Detect an existing installation: the saved config file on macOS,
    # the systemd unit file elsewhere.
    if [[ "$IS_MACOS" == true ]]; then
        [[ -f "$INSTALL_CONF" ]]
        return
    fi
    [[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
}
|
||||
|
||||
# 加载安装信息
|
||||
load_install_info() {
    # Load the persisted INSTALL_DIR/BRANCH, or fall back to defaults.
    if [[ -f "$INSTALL_CONF" ]]; then
        source "$INSTALL_CONF"
    else
        INSTALL_DIR="$DEFAULT_INSTALL_DIR"
        BRANCH="refactor"
    fi
}
|
||||
|
||||
# 显示管理菜单
|
||||
# Interactive management menu shown on runs after installation.
show_menu() {
    while true; do
        # NOTE(review): the menu declares a list height of 7 but offers 9
        # entries; whiptail scrolls, but the height may deserve bumping.
        choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
            "1" "启动MaiCore" \
            "2" "停止MaiCore" \
            "3" "重启MaiCore" \
            "4" "启动NapCat Adapter" \
            "5" "停止NapCat Adapter" \
            "6" "重启NapCat Adapter" \
            "7" "拉取最新MaiCore仓库" \
            "8" "切换分支" \
            "9" "退出" 3>&1 1>&2 2>&3)

        # Esc/Cancel exits the script.
        [[ $? -ne 0 ]] && exit 0

        case "$choice" in
            1)
                start_service "${SERVICE_NAME}"
                whiptail --msgbox "✅MaiCore已启动" 10 60
                ;;
            2)
                stop_service "${SERVICE_NAME}"
                whiptail --msgbox "🛑MaiCore已停止" 10 60
                ;;
            3)
                restart_service "${SERVICE_NAME}"
                whiptail --msgbox "🔄MaiCore已重启" 10 60
                ;;
            4)
                start_service "${SERVICE_NAME_NBADAPTER}"
                whiptail --msgbox "✅NapCat Adapter已启动" 10 60
                ;;
            5)
                stop_service "${SERVICE_NAME_NBADAPTER}"
                whiptail --msgbox "🛑NapCat Adapter已停止" 10 60
                ;;
            6)
                restart_service "${SERVICE_NAME_NBADAPTER}"
                whiptail --msgbox "🔄NapCat Adapter已重启" 10 60
                ;;
            7)
                update_dependencies
                ;;
            8)
                switch_branch
                ;;
            9)
                exit 0
                ;;
            *)
                whiptail --msgbox "无效选项!" 10 60
                ;;
        esac
    done
}
|
||||
|
||||
# 更新依赖
|
||||
# Pull the latest MaiCore code and reinstall its Python dependencies.
update_dependencies() {
    whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
    stop_service "${SERVICE_NAME}"
    cd "${INSTALL_DIR}/MaiBot" || {
        whiptail --msgbox "🚫 无法进入安装目录!" 10 60
        return 1
    }
    if ! git pull origin "${BRANCH}"; then
        whiptail --msgbox "🚫 代码更新失败!" 10 60
        return 1
    fi
    # Install into the project's virtualenv, then leave it again.
    source "${INSTALL_DIR}/venv/bin/activate"
    if ! pip install -r requirements.txt; then
        whiptail --msgbox "🚫 依赖安装失败!" 10 60
        deactivate
        return 1
    fi
    deactivate
    whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
}
|
||||
|
||||
# 切换分支
|
||||
# Check out a different git branch and reinstall dependencies.
switch_branch() {
    new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
    [[ -z "$new_branch" ]] && {
        whiptail --msgbox "🚫 分支名称不能为空!" 10 60
        return 1
    }

    cd "${INSTALL_DIR}/MaiBot" || {
        whiptail --msgbox "🚫 无法进入安装目录!" 10 60
        return 1
    }

    # Verify the branch exists on the remote before touching the checkout.
    if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
        whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
        return 1
    fi

    if ! git checkout "${new_branch}"; then
        whiptail --msgbox "🚫 分支切换失败!" 10 60
        return 1
    fi

    if ! git pull origin "${new_branch}"; then
        whiptail --msgbox "🚫 代码拉取失败!" 10 60
        return 1
    fi
    stop_service "${SERVICE_NAME}"
    source "${INSTALL_DIR}/venv/bin/activate"
    pip install -r requirements.txt
    deactivate

    # Remember the new branch and re-confirm the (possibly changed) EULA.
    BRANCH="${new_branch}"
    save_install_info
    check_eula
    whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} !" 10 60
}
|
||||
|
||||
check_eula() {
    # Require the user to (re-)accept the EULA and privacy policy whenever
    # either document changed since the last recorded confirmation.

    # MD5 of the current EULA.
    current_md5=$(compute_md5 "${INSTALL_DIR}/MaiBot/EULA.md")

    # MD5 of the current privacy policy.
    current_md5_privacy=$(compute_md5 "${INSTALL_DIR}/MaiBot/PRIVACY.md")

    # If either document is missing we cannot verify acceptance. Bail out
    # here: previously the code fell through, and the empty hashes compared
    # equal to empty .confirmed files, silently skipping the consent prompt.
    if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
        whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
        return 1
    fi

    # Previously confirmed EULA hash, if any.
    if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then
        confirmed_md5=$(cat "${INSTALL_DIR}/MaiBot/eula.confirmed")
    else
        confirmed_md5=""
    fi

    # Previously confirmed privacy-policy hash, if any.
    if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then
        confirmed_md5_privacy=$(cat "${INSTALL_DIR}/MaiBot/privacy.confirmed")
    else
        confirmed_md5_privacy=""
    fi

    # Re-prompt when either document changed; record the accepted hashes,
    # or abort the whole script if the user declines.
    if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then
        whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议? \n\n " 12 70
        if [[ $? -eq 0 ]]; then
            echo -n "$current_md5" > "${INSTALL_DIR}/MaiBot/eula.confirmed"
            echo -n "$current_md5_privacy" > "${INSTALL_DIR}/MaiBot/privacy.confirmed"
        else
            exit 1
        fi
    fi

}
|
||||
|
||||
# 测速并选择PyPI源(仅当阿里云更快时使用阿里云)
|
||||
measure_url_latency() {
    # Echo the total request time (seconds) for $1, or 999999 on failure.
    local url="$1"
    local latency

    latency=$(curl -sS -o /dev/null -w "%{time_total}" --connect-timeout 3 --max-time 8 "$url" 2>/dev/null)

    # Only trust the value when curl succeeded and printed a number.
    if [[ $? -eq 0 && "$latency" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
        echo "$latency"
        return 0
    else
        echo "999999"
        return 1
    fi
}
|
||||
|
||||
resolve_default_pypi_index_url() {
    # Determine the PyPI index the environment would use by default:
    # PIP_INDEX_URL, then UV_INDEX_URL, then pip's config, then pypi.org.
    local default_url=""

    if [[ -n "${PIP_INDEX_URL:-}" ]]; then
        default_url="$PIP_INDEX_URL"
    elif [[ -n "${UV_INDEX_URL:-}" ]]; then
        default_url="$UV_INDEX_URL"
    elif command -v pip &>/dev/null; then
        default_url=$(pip config get global.index-url 2>/dev/null | head -n 1)
        if [[ -z "$default_url" ]]; then
            default_url=$(pip config get install.index-url 2>/dev/null | head -n 1)
        fi
    fi

    if [[ -z "$default_url" ]]; then
        default_url="https://pypi.org/simple"
    fi

    echo "$default_url"
}
|
||||
|
||||
select_pypi_index_url() {
    # Benchmark the default PyPI index against the Aliyun mirror and keep
    # whichever responds faster; when only one is reachable, use that one.
    # Results: PYPI_INDEX_URL / PYPI_INDEX_NAME (informational) and
    # UV_PIP_INDEX_OPTION (array of "-i URL" args; empty for the default).
    local default_url
    local aliyun_url="https://mirrors.aliyun.com/pypi/simple"
    local default_latency
    local aliyun_latency
    local default_status
    local aliyun_status

    default_url=$(resolve_default_pypi_index_url)
    default_latency=$(measure_url_latency "$default_url")
    default_status=$?
    aliyun_latency=$(measure_url_latency "$aliyun_url")
    aliyun_status=$?

    # Only Aliyun reachable.
    if [[ $aliyun_status -eq 0 && $default_status -ne 0 ]]; then
        PYPI_INDEX_URL="$aliyun_url"
        PYPI_INDEX_NAME="阿里云镜像(默认源测速失败)"
        UV_PIP_INDEX_OPTION=(-i "$aliyun_url")
        echo -e "${RED}默认源测速失败,已选择${PYPI_INDEX_NAME}:${PYPI_INDEX_URL}${RESET}"
        return
    fi

    # Only the default index reachable.
    if [[ $aliyun_status -ne 0 && $default_status -eq 0 ]]; then
        PYPI_INDEX_URL="$default_url"
        PYPI_INDEX_NAME="默认源(阿里云测速失败)"
        UV_PIP_INDEX_OPTION=()
        echo -e "${RED}阿里云测速失败,已选择${PYPI_INDEX_NAME}:不显式指定 -i 参数${RESET}"
        return
    fi

    # Neither reachable: fall back to the default without an explicit -i.
    if [[ $aliyun_status -ne 0 && $default_status -ne 0 ]]; then
        PYPI_INDEX_URL="$default_url"
        PYPI_INDEX_NAME="默认源(双源测速失败)"
        UV_PIP_INDEX_OPTION=()
        echo -e "${RED}默认源和阿里云测速均失败,回退到${PYPI_INDEX_NAME}:不显式指定 -i 参数${RESET}"
        return
    fi

    # Both reachable: pick the faster one (float comparison via awk).
    if awk "BEGIN {exit !(${aliyun_latency} < ${default_latency})}"; then
        PYPI_INDEX_URL="$aliyun_url"
        PYPI_INDEX_NAME="阿里云镜像"
        UV_PIP_INDEX_OPTION=(-i "$aliyun_url")
    else
        PYPI_INDEX_URL="$default_url"
        PYPI_INDEX_NAME="默认源"
        UV_PIP_INDEX_OPTION=()
    fi

    if [[ ${#UV_PIP_INDEX_OPTION[@]} -gt 0 ]]; then
        echo -e "${GREEN}已选择${PYPI_INDEX_NAME}:${PYPI_INDEX_URL}${RESET}"
    else
        echo -e "${GREEN}已选择${PYPI_INDEX_NAME}:不显式指定 -i 参数${RESET}"
    fi
}
|
||||
|
||||
# ----------- 主安装流程 -----------
|
||||
run_installation() {
|
||||
# 1/6: 检测是否安装 whiptail
|
||||
if ! command -v whiptail &>/dev/null; then
|
||||
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
|
||||
|
||||
if command -v apt-get &>/dev/null; then
|
||||
apt-get update && apt-get install -y whiptail
|
||||
elif command -v pacman &>/dev/null; then
|
||||
pacman -Syu --noconfirm whiptail
|
||||
elif command -v yum &>/dev/null; then
|
||||
yum install -y whiptail
|
||||
elif command -v brew &>/dev/null || [[ -x /opt/homebrew/bin/brew ]] || [[ -x /usr/local/bin/brew ]]; then
|
||||
run_brew install newt
|
||||
|
||||
# 确保当前 shell 能找到 Homebrew 安装的 whiptail。
|
||||
[[ -x /opt/homebrew/bin/whiptail ]] && export PATH="/opt/homebrew/bin:${PATH}"
|
||||
[[ -x /usr/local/bin/whiptail ]] && export PATH="/usr/local/bin:${PATH}"
|
||||
else
|
||||
echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v whiptail &>/dev/null; then
|
||||
echo -e "${RED}[Error] whiptail 安装失败或不可用,请手动安装后重试。${RESET}"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
whiptail --title "ℹ️ 提示" --msgbox "如果您没有特殊需求,请优先使用docker方式部署。" 10 60
|
||||
|
||||
# 协议确认
|
||||
if ! (whiptail --title "ℹ️ [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议?" 12 70); then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 欢迎信息
|
||||
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore,将自动进入安装流程,安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段,代码可能随时更改\n文档未完善,有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险,请自行了解,谨慎使用\n由于持续迭代,可能存在一些已知或未知的bug\n由于开发中,可能消耗较多token\n\n本脚本可能更新不及时,如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
|
||||
|
||||
# 系统检查
|
||||
check_system() {
|
||||
if [[ "$IS_MACOS" == true ]]; then
|
||||
ID="macos"
|
||||
VERSION_ID="$(sw_vers -productVersion 2>/dev/null)"
|
||||
PRETTY_NAME="macOS ${VERSION_ID}"
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ "$(id -u)" -ne 0 ]]; then
|
||||
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -f /etc/os-release ]]; then
|
||||
source /etc/os-release
|
||||
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
|
||||
return
|
||||
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
|
||||
return
|
||||
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
|
||||
return
|
||||
elif [[ "$ID" == "arch" ]]; then
|
||||
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
|
||||
return
|
||||
else
|
||||
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
check_system
|
||||
|
||||
# 设置包管理器
|
||||
case "$ID" in
|
||||
debian|ubuntu)
|
||||
PKG_MANAGER="apt"
|
||||
;;
|
||||
centos)
|
||||
PKG_MANAGER="yum"
|
||||
;;
|
||||
arch)
|
||||
# 添加arch包管理器
|
||||
PKG_MANAGER="pacman"
|
||||
;;
|
||||
macos)
|
||||
PKG_MANAGER="brew"
|
||||
;;
|
||||
esac
|
||||
|
||||
# 检查NapCat
|
||||
check_napcat() {
|
||||
if command -v napcat &>/dev/null; then
|
||||
NAPCAT_INSTALLED=true
|
||||
else
|
||||
NAPCAT_INSTALLED=false
|
||||
fi
|
||||
}
|
||||
check_napcat
|
||||
|
||||
# 安装必要软件包
|
||||
install_packages() {
|
||||
missing_packages=()
|
||||
# 检查 common 及当前系统专属依赖
|
||||
for package in $(get_required_packages "$ID"); do
|
||||
case "$PKG_MANAGER" in
|
||||
apt)
|
||||
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
yum)
|
||||
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
pacman)
|
||||
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
brew)
|
||||
case "$package" in
|
||||
git)
|
||||
command -v git &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
gnupg)
|
||||
command -v gpg &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
python)
|
||||
command -v python3 &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
*)
|
||||
run_brew list --formula "$package" &>/dev/null || missing_packages+=("$package")
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ ${#missing_packages[@]} -gt 0 ]]; then
|
||||
whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装?" 10 60
|
||||
if [[ $? -eq 0 ]]; then
|
||||
IS_INSTALL_DEPENDENCIES=true
|
||||
else
|
||||
whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续?" 10 60 || exit 1
|
||||
fi
|
||||
fi
|
||||
}
|
||||
install_packages
|
||||
|
||||
# 安装NapCat
|
||||
install_napcat() {
|
||||
[[ $NAPCAT_INSTALLED == true ]] && return
|
||||
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat,是否安装?\n如果您想使用远程NapCat,请跳过此步。" 10 60 && {
|
||||
IS_INSTALL_NAPCAT=true
|
||||
}
|
||||
}
|
||||
|
||||
# 仅在 Linux 非 Arch 系统上安装 NapCat,macOS 仅支持远程 NapCat。
|
||||
if [[ "$ID" == "macos" ]]; then
|
||||
whiptail --title "⚠️ NapCat 安装提示" --msgbox "当前为 macOS,暂不支持自动安装 NapCat。\n如需使用 NapCat,请配置远程实例后再连接。 " 10 60
|
||||
elif [[ "$ID" != "arch" ]]; then
|
||||
install_napcat
|
||||
fi
|
||||
|
||||
# Python版本检查
|
||||
check_python() {
|
||||
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then
|
||||
whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
# 如果没安装python则不检查python版本
|
||||
if command -v python3 &>/dev/null; then
|
||||
check_python
|
||||
fi
|
||||
|
||||
|
||||
# 选择分支
|
||||
choose_branch() {
|
||||
BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
|
||||
"main" "稳定版本(推荐)" ON \
|
||||
"dev" "开发版(不知道什么意思就别选)" OFF \
|
||||
"classical" "经典版(0.6.0以前的版本)" OFF \
|
||||
"custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
|
||||
RETVAL=$?
|
||||
if [ $RETVAL -ne 0 ]; then
|
||||
whiptail --msgbox "🚫 操作取消!" 10 60
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ "$BRANCH" == "custom" ]]; then
|
||||
BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
|
||||
RETVAL=$?
|
||||
if [ $RETVAL -ne 0 ]; then
|
||||
whiptail --msgbox "🚫 输入取消!" 10 60
|
||||
exit 1
|
||||
fi
|
||||
if [[ -z "$BRANCH" ]]; then
|
||||
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
}
|
||||
choose_branch
|
||||
|
||||
# 选择安装路径
|
||||
choose_install_dir() {
|
||||
INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录:" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
|
||||
[[ -z "$INSTALL_DIR" ]] && {
|
||||
whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
|
||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
||||
}
|
||||
}
|
||||
choose_install_dir
|
||||
|
||||
# 确认安装
|
||||
confirm_install() {
|
||||
local confirm_msg="请确认以下更改:\n\n"
|
||||
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
|
||||
confirm_msg+="🔀 分支: $BRANCH\n"
|
||||
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
|
||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
|
||||
|
||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
|
||||
confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
|
||||
|
||||
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
|
||||
}
|
||||
confirm_install
|
||||
|
||||
# 开始安装
|
||||
echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
|
||||
|
||||
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
|
||||
case "$PKG_MANAGER" in
|
||||
apt)
|
||||
apt update && apt install -y "${missing_packages[@]}"
|
||||
;;
|
||||
yum)
|
||||
yum install -y "${missing_packages[@]}" --nobest
|
||||
;;
|
||||
pacman)
|
||||
pacman -S --noconfirm "${missing_packages[@]}"
|
||||
;;
|
||||
brew)
|
||||
run_brew update && run_brew install "${missing_packages[@]}"
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
|
||||
if [[ $IS_INSTALL_NAPCAT == true ]]; then
|
||||
echo -e "${GREEN}安装 NapCat...${RESET}"
|
||||
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}创建安装目录...${RESET}"
|
||||
mkdir -p "$INSTALL_DIR"
|
||||
cd "$INSTALL_DIR" || exit 1
|
||||
|
||||
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
|
||||
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
|
||||
echo -e "${RED}克隆MaiCore仓库失败!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
echo -e "${GREEN}A_Memorix 已内置到源码,无需初始化子模块。${RESET}"
|
||||
|
||||
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
|
||||
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
|
||||
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}"
|
||||
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
|
||||
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
|
||||
exit 1
|
||||
}
|
||||
|
||||
|
||||
echo -e "${GREEN}安装Python依赖...${RESET}"
|
||||
select_pypi_index_url
|
||||
pip install -r MaiBot/requirements.txt
|
||||
cd MaiBot
|
||||
pip install uv
|
||||
uv pip install "${UV_PIP_INDEX_OPTION[@]}" -r requirements.txt
|
||||
cd ..
|
||||
|
||||
echo -e "${GREEN}安装maim_message依赖...${RESET}"
|
||||
cd maim_message
|
||||
uv pip install "${UV_PIP_INDEX_OPTION[@]}" -e .
|
||||
cd ..
|
||||
|
||||
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
|
||||
cd MaiBot-Napcat-Adapter
|
||||
uv pip install "${UV_PIP_INDEX_OPTION[@]}" -r requirements.txt
|
||||
cd ..
|
||||
|
||||
echo -e "${GREEN}同意协议...${RESET}"
|
||||
|
||||
# 首先计算当前EULA的MD5值
|
||||
current_md5=$(compute_md5 "MaiBot/EULA.md")
|
||||
|
||||
# 首先计算当前隐私条款文件的哈希值
|
||||
current_md5_privacy=$(compute_md5 "MaiBot/PRIVACY.md")
|
||||
|
||||
echo -n "$current_md5" > MaiBot/eula.confirmed
|
||||
echo -n "$current_md5_privacy" > MaiBot/privacy.confirmed
|
||||
|
||||
if [[ "$IS_MACOS" == true ]]; then
|
||||
echo -e "${GREEN}创建 launchctl 服务...${RESET}"
|
||||
create_launchd_services
|
||||
stop_service "${SERVICE_NAME}" >/dev/null 2>&1 || true
|
||||
stop_service "${SERVICE_NAME_NBADAPTER}" >/dev/null 2>&1 || true
|
||||
else
|
||||
echo -e "${GREEN}创建系统服务...${RESET}"
|
||||
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
|
||||
[Unit]
|
||||
Description=MaiCore
|
||||
After=network.target ${SERVICE_NAME_NBADAPTER}.service
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=${INSTALL_DIR}/MaiBot
|
||||
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
|
||||
Restart=always
|
||||
RestartSec=10s
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
|
||||
# [Unit]
|
||||
# Description=MaiCore WebUI
|
||||
# After=network.target ${SERVICE_NAME}.service
|
||||
|
||||
# [Service]
|
||||
# Type=simple
|
||||
# WorkingDirectory=${INSTALL_DIR}/MaiBot
|
||||
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
|
||||
# Restart=always
|
||||
# RestartSec=10s
|
||||
|
||||
# [Install]
|
||||
# WantedBy=multi-user.target
|
||||
# EOF
|
||||
|
||||
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
|
||||
[Unit]
|
||||
Description=MaiBot Napcat Adapter
|
||||
After=network.target mongod.service ${SERVICE_NAME}.service
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
|
||||
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
|
||||
Restart=always
|
||||
RestartSec=10s
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
EOF
|
||||
|
||||
systemctl daemon-reload
|
||||
fi
|
||||
|
||||
# 保存安装信息
|
||||
save_install_info
|
||||
|
||||
if [[ "$IS_MACOS" == true ]]; then
|
||||
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成!\n已创建 launchctl 服务:${LAUNCHD_LABEL_MAIN}、${LAUNCHD_LABEL_NBADAPTER}\n\n首次加载:launchctl bootstrap ${LAUNCHD_DOMAIN} ${LAUNCHD_PLIST_MAIN}\n重启服务:launchctl kickstart -k ${LAUNCHD_DOMAIN}/${LAUNCHD_LABEL_MAIN}\n查看状态:launchctl print ${LAUNCHD_DOMAIN}/${LAUNCHD_LABEL_MAIN}" 14 100
|
||||
else
|
||||
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成!\n已创建系统服务:${SERVICE_NAME}、${SERVICE_NAME_WEB}、${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务:\n启动服务:systemctl start ${SERVICE_NAME}\n查看状态:systemctl status ${SERVICE_NAME}" 14 60
|
||||
fi
|
||||
}
|
||||
|
||||
# ----------- Main execution flow -----------
# Linux still requires root; macOS uses per-user launchctl (no root needed).
if [[ "$IS_MACOS" == true && $(id -u) -eq 0 ]]; then
    echo -e "${RED}macOS 请勿使用 root/sudo 运行此脚本,请直接以当前登录用户执行。${RESET}"
    exit 1
fi

# On Linux the script installs packages and writes systemd units, so root is required.
if [[ "$IS_MACOS" != true && $(id -u) -ne 0 ]]; then
    echo -e "${RED}请使用root用户运行此脚本!${RESET}"
    exit 1
fi

# Already installed: load the saved install info, re-check the EULA (it may
# have changed since confirmation), then show the management menu.
# Fresh machine: run the installer, then offer to start the service right away.
if check_installed; then
    load_install_info
    check_eula
    show_menu
else
    run_installation
    # Ask whether to start MaiCore immediately after installation.
    if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务?" 10 60; then
        start_service "${SERVICE_NAME}"
        if [[ "$IS_MACOS" == true ]]; then
            whiptail --msgbox "✅ 服务已启动!\n使用 launchctl print ${LAUNCHD_DOMAIN}/${LAUNCHD_LABEL_MAIN} 查看状态" 10 80
        else
            whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
        fi
    fi
fi
|
||||
25
scripts/run_a_memorix_webui_backend.py
Normal file
25
scripts/run_a_memorix_webui_backend.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.A_memorix.host_service import a_memorix_host_service
|
||||
from src.webui.webui_server import get_webui_server
|
||||
|
||||
|
||||
async def main() -> None:
    """Run the WebUI backend together with the A_Memorix host service.

    The host service is brought up first; the WebUI server then runs until it
    returns, after which the host service is always stopped again.
    """
    webui_server = get_webui_server()
    await a_memorix_host_service.start()
    try:
        await webui_server.start()
    finally:
        # Shut the host service down even if the WebUI server raised.
        await a_memorix_host_service.stop()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
51
scripts/run_lpmm.sh
Normal file
51
scripts/run_lpmm.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
# Run the LPMM knowledge pipeline: preprocess raw data, extract information,
# then import the OpenIE results. The whole script aborts on the first failure.

# ==============================================
# Environment Initialization
# ==============================================

# Step 1: Locate project root directory (this script lives in scripts/).
SCRIPTS_DIR="scripts"
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
PROJECT_ROOT=$(cd "$SCRIPT_DIR/.." && pwd)

# Step 2: Verify scripts directory exists
if [ ! -d "$PROJECT_ROOT/$SCRIPTS_DIR" ]; then
    echo "❌ Error: scripts directory not found in project root" >&2
    echo "Current path: $PROJECT_ROOT" >&2
    exit 1
fi

# Step 3: Set up Python environment
# Prepend the project root so the child scripts can import `src.*`.
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
cd "$PROJECT_ROOT" || {
    echo "❌ Failed to cd to project root: $PROJECT_ROOT" >&2
    exit 1
}

# Debug info
echo "============================"
echo "Project Root: $PROJECT_ROOT"
echo "Python Path: $PYTHONPATH"
echo "Working Dir: $(pwd)"
echo "============================"

# ==============================================
# Python Script Execution
# ==============================================

# Run one pipeline stage; exit the whole script on failure so later stages
# never run against incomplete intermediate data.
run_python_script() {
    local script_name=$1
    echo "🔄 Running $script_name"
    if ! python3 "$SCRIPTS_DIR/$script_name"; then
        echo "❌ $script_name failed" >&2
        exit 1
    fi
}

# Execute scripts in order
run_python_script "raw_data_preprocessor.py"
run_python_script "info_extraction.py"
run_python_script "import_openie.py"

echo "✅ All scripts completed successfully"
|
||||
21
scripts/sync_a_memorix_subtree.sh
Normal file
21
scripts/sync_a_memorix_subtree.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
# Sync the vendored A_memorix subtree at src/A_memorix.
# Usage: sync_a_memorix_subtree.sh [add|pull] [remote_url] [branch]

set -euo pipefail

MODE="${1:-pull}"
REMOTE_URL="${2:-https://github.com/A-Dawn/A_memorix.git}"
BRANCH="${3:-MaiBot_branch}"
PREFIX="src/A_memorix"

if [[ "$MODE" == "add" ]]; then
    # First-time vendoring of the subtree.
    git subtree add --prefix="$PREFIX" "$REMOTE_URL" "$BRANCH" --squash
elif [[ "$MODE" == "pull" ]]; then
    # Update an already-vendored subtree.
    git subtree pull --prefix="$PREFIX" "$REMOTE_URL" "$BRANCH" --squash
else
    echo "Usage: $0 [add|pull] [remote_url] [branch]" >&2
    exit 2
fi
|
||||
459
scripts/test_memory_retrieval.py
Normal file
459
scripts/test_memory_retrieval.py
Normal file
@@ -0,0 +1,459 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# Force UTF-8 on stdout/stderr to avoid console encoding errors on Windows.
try:
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(encoding="utf-8")
    if hasattr(sys.stderr, "reconfigure"):
        sys.stderr.reconfigure(encoding="utf-8")
except Exception:
    pass

# Make `src.*` importable when running this script from scripts/.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo

# Prefer the real maim_message types; fall back to a minimal stand-in so the
# script still loads without the package installed.
# NOTE(review): on success this import also re-binds UserInfo/GroupInfo,
# shadowing the data-model imports above — presumably intentional; confirm.
try:
    from maim_message import ChatStream, UserInfo, GroupInfo
except Exception:
    # Structural substitute for ChatStream; UserInfo/GroupInfo keep their
    # data-model bindings on this fallback path.
    @dataclass
    class ChatStream:
        # Unique id of the chat stream.
        stream_id: str
        # Messaging platform identifier.
        platform: str
        # Sender identity attached to the stream.
        user_info: UserInfo
        # Group identity attached to the stream.
        group_info: GroupInfo

logger = get_logger("test_memory_retrieval")
|
||||
|
||||
|
||||
# Import memory_retrieval dynamically via importlib to dodge circular imports.
def _import_memory_retrieval():
    """Dynamically import memory_retrieval while avoiding circular imports.

    Returns:
        Tuple of (init_memory_retrieval_prompt, _react_agent_solve_question,
        _process_single_question) from ``src.memory_system.memory_retrieval``.

    Raises:
        ImportError, AttributeError: when the module cannot be (re)imported.
    """
    try:
        # Import prompt_builder first so we can check whether the retrieval
        # prompts were already registered by some other import path.
        from src.chat.utils.prompt_builder import global_prompt_manager

        # If this prompt exists, the module probably finished initializing
        # elsewhere already.
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts

        module_name = "src.memory_system.memory_retrieval"

        # Fast path: prompts initialized and a fully-loaded module is cached.
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, "init_memory_retrieval_prompt"):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                    existing_module._process_single_question,
                )

        # A cached but only partially-initialized module must be evicted,
        # together with its src.memory_system.* siblings, before re-import.
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, "init_memory_retrieval_prompt"):
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # Collect first, then delete — don't mutate while iterating.
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith("src.memory_system.") and key != "src.memory_system":
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass

        # Pre-load every module that could participate in the circular import
        # so each one finishes initializing before memory_retrieval is loaded.
        try:
            import src.config.config
            import src.chat.utils.prompt_builder

            # These reply generators may import memory_retrieval at module
            # level; load them first, best effort.
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue even if the import fails
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue even if the import fails
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")

        # Now import memory_retrieval itself; a circular-import failure here
        # means some other module still imports it at module level.
        memory_retrieval_module = importlib.import_module(module_name)

        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
            memory_retrieval_module._process_single_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
|
||||
|
||||
|
||||
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
    """Build a throwaway ChatStream populated with fixed test identities."""
    test_user = UserInfo(
        platform="test",
        user_id="test_user",
        user_nickname="测试用户",
    )
    test_group = GroupInfo(
        platform="test",
        group_id="test_group",
        group_name="测试群组",
    )
    return ChatStream(
        stream_id=chat_id,
        platform="test",
        user_info=test_user,
        group_info=test_group,
    )
|
||||
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Aggregate memory-related LLM token usage recorded since *start_time*.

    Args:
        start_time: UNIX timestamp marking the start of the measurement window.

    Returns:
        Dict with overall totals (prompt/completion/total tokens, cost,
        request count) plus a per-model breakdown under ``model_usage``.
        On any query failure a zeroed dict of the same shape is returned
        instead of raising, so the test report still renders.
    """
    try:
        start_datetime = datetime.fromtimestamp(start_time)

        # Select all memory-related usage rows recorded inside the window.
        # NOTE(review): the explicit `memory.*` equality clauses are already
        # matched by LIKE '%memory%' — presumably kept for documentation
        # value; confirm before simplifying.
        records = (
            LLMUsage.select()
            .where(
                (LLMUsage.timestamp >= start_datetime)
                & (
                    (LLMUsage.request_type.like("%memory%"))
                    | (LLMUsage.request_type == "memory.question")
                    | (LLMUsage.request_type == "memory.react")
                    | (LLMUsage.request_type == "memory.react.final")
                )
            )
            .order_by(LLMUsage.timestamp.asc())
        )

        total_prompt_tokens = 0
        total_completion_tokens = 0
        total_tokens = 0
        total_cost = 0.0
        request_count = 0
        model_usage = {}  # per-model aggregation keyed by model name

        for record in records:
            # Columns may be NULL in the DB; coalesce to 0 before summing.
            total_prompt_tokens += record.prompt_tokens or 0
            total_completion_tokens += record.completion_tokens or 0
            total_tokens += record.total_tokens or 0
            total_cost += record.cost or 0.0
            request_count += 1

            # Per-model accumulation.
            model_name = record.model_name or "unknown"
            if model_name not in model_usage:
                model_usage[model_name] = {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost": 0.0,
                    "request_count": 0,
                }
            model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
            model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
            model_usage[model_name]["total_tokens"] += record.total_tokens or 0
            model_usage[model_name]["cost"] += record.cost or 0.0
            model_usage[model_name]["request_count"] += 1

        return {
            "total_prompt_tokens": total_prompt_tokens,
            "total_completion_tokens": total_completion_tokens,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "request_count": request_count,
            "model_usage": model_usage,
        }
    except Exception as e:
        # Degrade gracefully: report zeros rather than break the test run.
        logger.error(f"获取token使用情况失败: {e}")
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }
|
||||
|
||||
|
||||
def format_thinking_steps(thinking_steps: list) -> str:
    """Render ReAct thinking steps as a human-readable multi-line string.

    Each step dict may contain ``iteration``, ``thought``, ``actions``
    (dicts with ``action_type``/``action_params``) and ``observations``.
    Long thoughts and observations are truncated to 200 characters with a
    trailing ellipsis.

    Args:
        thinking_steps: list of per-iteration step dicts (possibly empty/None).

    Returns:
        Formatted report, or a placeholder string when there are no steps.
    """
    if not thinking_steps:
        return "无思考步骤"

    lines = []
    for step in thinking_steps:
        iteration = step.get("iteration", "?")
        thought = step.get("thought", "")
        actions = step.get("actions", [])
        observations = step.get("observations", [])

        lines.append(f"\n--- 迭代 {iteration} ---")
        if thought:
            # BUGFIX: only append "..." when the thought was actually
            # truncated (was unconditional, unlike the observation handling).
            thought_str = thought[:200]
            if len(thought) > 200:
                thought_str += "..."
            lines.append(f"思考: {thought_str}")

        if actions:
            lines.append("行动:")
            for action in actions:
                action_type = action.get("action_type", "unknown")
                action_params = action.get("action_params", {})
                lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")

        if observations:
            lines.append("观察:")
            for obs in observations:
                # Stringify once; truncate with an ellipsis past 200 chars.
                obs_text = str(obs)
                obs_str = obs_text[:200]
                if len(obs_text) > 200:
                    obs_str += "..."
                lines.append(f" - {obs_str}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
async def test_memory_retrieval(
    question: str,
    chat_id: str = "test_memory_retrieval",
    context: str = "",
    max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
    """Run one memory-retrieval probe and report timing/token statistics.

    Args:
        question: question for the retrieval agent to answer.
        chat_id: chat stream id used for the probe.
        context: extra context. NOTE(review): currently unused — the agent is
            always invoked with initial_info=""; confirm whether it should be
            forwarded.
        max_iterations: cap on ReAct iterations; None uses the config default.

    Returns:
        Dict with the question, answer, timeout flag, elapsed time, raw
        thinking steps, iteration count and aggregated token usage.
    """
    print("\n" + "=" * 80)
    print("[测试] 记忆检索测试")
    print(f"[问题] {question}")
    print("=" * 80)

    # Start of the measurement window, used both for elapsed time and for
    # scoping the token-usage DB query below.
    start_time = time.time()

    # Lazily import and initialize the retrieval prompts (this also loads
    # global_config). Must happen inside the function body to avoid circular
    # imports at module load time.
    try:
        init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()

        # Skip re-registration when the prompt already exists.
        from src.chat.utils.prompt_builder import global_prompt_manager

        if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
            init_memory_retrieval_prompt()
        else:
            logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
    except Exception as e:
        logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
        raise

    # global_config is safe to import now that initialization has run.
    from src.config.config import global_config

    # Call _react_agent_solve_question directly so we get the per-iteration
    # thinking steps back, not just the final answer.
    if max_iterations is None:
        max_iterations = global_config.memory.max_agent_iterations

    timeout = global_config.memory.agent_timeout_seconds

    print("\n[配置]")
    print(f" 最大迭代次数: {max_iterations}")
    print(f" 超时时间: {timeout}秒")
    print(f" 聊天ID: {chat_id}")

    # Run the retrieval agent.
    print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")

    found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=timeout,
        initial_info="",
    )

    # Wall-clock duration of the probe.
    end_time = time.time()
    elapsed_time = end_time - start_time

    # Token usage recorded in the DB during the measurement window.
    token_usage = get_token_usage_since(start_time)

    # Assemble the machine-readable result.
    result = {
        "question": question,
        "found_answer": found_answer,
        "answer": answer,
        "is_timeout": is_timeout,
        "elapsed_time": elapsed_time,
        "thinking_steps": thinking_steps,
        "iteration_count": len(thinking_steps),
        "token_usage": token_usage,
    }

    # Human-readable report.
    print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    print("\n[结果]")
    print(f" 是否找到答案: {'是' if found_answer else '否'}")
    if found_answer and answer:
        print(f" 答案: {answer}")
    else:
        print(" 答案: (未找到答案)")
    print(f" 是否超时: {'是' if is_timeout else '否'}")
    print(f" 迭代次数: {len(thinking_steps)}")
    print(f" 总耗时: {elapsed_time:.2f}秒")

    print("\n[Token使用情况]")
    print(f" 总请求数: {token_usage['request_count']}")
    print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
    print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
    print(f" 总Tokens: {token_usage['total_tokens']:,}")
    print(f" 总成本: ${token_usage['total_cost']:.6f}")

    if token_usage["model_usage"]:
        print("\n[按模型统计]")
        for model_name, usage in token_usage["model_usage"].items():
            print(f" {model_name}:")
            print(f" 请求数: {usage['request_count']}")
            print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
            print(f" Completion Tokens: {usage['completion_tokens']:,}")
            print(f" 总Tokens: {usage['total_tokens']:,}")
            print(f" 成本: ${usage['cost']:.6f}")

    print("\n[迭代详情]")
    print(format_thinking_steps(thinking_steps))

    print("\n" + "=" * 80)

    return result
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: interactively collect a question and run the probe.

    Flow: parse args -> initialize logging -> prompt for the question and the
    iteration cap -> connect the DB (needed for token accounting) -> run the
    async probe -> optionally dump the result to a JSON file.
    """
    parser = argparse.ArgumentParser(
        description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。"
    )
    parser.add_argument(
        "--chat-id",
        default="test_memory_retrieval",
        help="测试用的聊天ID(默认: test_memory_retrieval)",
    )
    parser.add_argument(
        "--context",
        default="",
        help="上下文信息(可选)",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件(可选)",
    )

    args = parser.parse_args()

    # Keep log verbosity low for an interactive tool.
    initialize_logging(verbose=False)

    # Interactively read the question to probe.
    print("\n" + "=" * 80)
    print("记忆检索测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return

    # Optional iteration-cap override; fall back to the config default on
    # empty, non-numeric or non-positive input.
    max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
    max_iterations = None
    if max_iterations_input:
        try:
            max_iterations = int(max_iterations_input)
            if max_iterations <= 0:
                print("警告: 迭代次数必须大于0,将使用配置默认值")
                max_iterations = None
        except ValueError:
            print("警告: 无效的迭代次数,将使用配置默认值")
            max_iterations = None

    # Token accounting reads LLMUsage rows, so the DB must be reachable.
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return

    # Run the probe.
    try:
        result = asyncio.run(
            test_memory_retrieval(
                question=question,
                chat_id=args.chat_id,
                context=args.context,
                max_iterations=max_iterations,
            )
        )

        # Optionally persist the result as JSON.
        if args.output:
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                # BUGFIX: thinking_steps may contain objects that are not
                # JSON-serializable; default=str stringifies them instead of
                # crashing the dump (the previous code claimed a conversion
                # that never happened).
                json.dump(output_result, f, ensure_ascii=False, indent=2, default=str)
            print(f"\n[结果已保存] {args.output}")

    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Best-effort DB close; ignore errors during teardown.
        try:
            db.close()
        except Exception:
            pass


if __name__ == "__main__":
    main()
|
||||
845
scripts/test_model_tool_call_params.py
Normal file
845
scripts/test_model_tool_call_params.py
Normal file
@@ -0,0 +1,845 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
||||
from src.config.config import config_manager # noqa: E402
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
||||
from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402
|
||||
from src.services.llm_service import generate # noqa: E402
|
||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
||||
|
||||
|
||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCallCase:
|
||||
"""Tool call 参数测试用例。"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
tool_definition: Dict[str, Any]
|
||||
expected_arguments: Dict[str, Any]
|
||||
|
||||
@property
|
||||
def tool_name(self) -> str:
|
||||
"""返回工具名称。"""
|
||||
if self.tool_definition.get("type") == "function":
|
||||
function_definition = self.tool_definition.get("function", {})
|
||||
return str(function_definition.get("name", "") or "")
|
||||
return str(self.tool_definition.get("name", "") or "")
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> Dict[str, Any]:
|
||||
"""返回参数 Schema。"""
|
||||
if self.tool_definition.get("type") == "function":
|
||||
function_definition = self.tool_definition.get("function", {})
|
||||
parameters = function_definition.get("parameters", {})
|
||||
return parameters if isinstance(parameters, dict) else {}
|
||||
parameters = self.tool_definition.get("parameters", {})
|
||||
return parameters if isinstance(parameters, dict) else {}
|
||||
|
||||
def build_messages(self) -> List[Dict[str, Any]]:
|
||||
"""构造测试消息。"""
|
||||
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
|
||||
system_prompt = (
|
||||
"你正在执行严格的工具调用参数兼容性测试。"
|
||||
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
|
||||
)
|
||||
user_prompt = (
|
||||
f"请立刻调用工具 `{self.tool_name}`。\n"
|
||||
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
|
||||
"不要输出任何解释文本,只返回工具调用。\n"
|
||||
f"{expected_json}"
|
||||
)
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ProbeTarget:
    """A single model target to probe."""

    # Task name the model is (temporarily) pinned to for the probe.
    task_name: str
    # Model name as declared in the model config.
    model_name: str
    # Name of the API provider serving this model.
    provider_name: str
    # Provider client implementation identifier.
    client_type: str
    # Provider's tool-call argument parse mode.
    tool_argument_parse_mode: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ProbeResult:
    """Outcome of one probe attempt for a single case against one model."""

    # Task under which the probe ran.
    task_name: str
    # Model name that was requested for the probe.
    target_model_name: str
    # Model name actually reported by the service (may differ from target).
    actual_model_name: str
    # Provider and client metadata copied from the ProbeTarget.
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
    # Name of the ToolCallCase that was exercised.
    case_name: str
    # 1-based attempt index for this case.
    attempt: int
    # Whether the tool call matched the expected arguments.
    success: bool
    # Wall-clock duration of the attempt in seconds.
    elapsed_seconds: float
    # Validation errors and non-fatal warnings collected during checking.
    errors: List[str]
    warnings: List[str]
    # Raw model output captured for diagnostics.
    response_text: str
    reasoning_text: str
    # Serialized tool calls returned by the model (JSON-friendly dicts).
    tool_calls: List[Dict[str, Any]]
|
||||
|
||||
|
||||
def _ensure_utf8_console() -> None:
|
||||
"""尽量将控制台编码切到 UTF-8。"""
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构造 OpenAI 风格 function tool 定义。"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_default_cases() -> List[ToolCallCase]:
    """Build the default probe cases: flat scalars and nested structures."""
    # Case 1: scalar-only arguments (string / int / bool / enum / float).
    simple_expected_arguments = {
        "request_id": "probe-simple-001",
        "count": 7,
        "enabled": True,
        "mode": "strict",
        "ratio": 2.5,
    }
    simple_parameters = {
        "type": "object",
        "properties": {
            "request_id": {"type": "string", "description": "请求 ID"},
            "count": {"type": "integer", "description": "数量"},
            "enabled": {"type": "boolean", "description": "是否启用"},
            "mode": {
                "type": "string",
                "description": "模式",
                "enum": ["strict", "loose"],
            },
            "ratio": {"type": "number", "description": "比例"},
        },
        "required": ["request_id", "count", "enabled", "mode", "ratio"],
        "additionalProperties": False,
    }

    # Case 2: nested object, string array, and array-of-objects arguments.
    nested_expected_arguments = {
        "request_id": "probe-nested-001",
        "notify": False,
        "profile": {
            "channel": "stable",
            "priority": 2,
        },
        "tags": ["alpha", "beta", "gamma"],
        "items": [
            {"count": 2, "name": "apple"},
            {"count": 5, "name": "banana"},
        ],
    }
    nested_parameters = {
        "type": "object",
        "properties": {
            "request_id": {"type": "string", "description": "请求 ID"},
            "notify": {"type": "boolean", "description": "是否通知"},
            "profile": {
                "type": "object",
                "description": "配置对象",
                "properties": {
                    "channel": {"type": "string", "description": "渠道"},
                    "priority": {"type": "integer", "description": "优先级"},
                },
                "required": ["channel", "priority"],
                "additionalProperties": False,
            },
            "tags": {
                "type": "array",
                "description": "标签列表",
                "items": {"type": "string"},
            },
            "items": {
                "type": "array",
                "description": "条目列表",
                "items": {
                    "type": "object",
                    "properties": {
                        "count": {"type": "integer", "description": "数量"},
                        "name": {"type": "string", "description": "名称"},
                    },
                    "required": ["count", "name"],
                    "additionalProperties": False,
                },
            },
        },
        "required": ["request_id", "notify", "profile", "tags", "items"],
        "additionalProperties": False,
    }

    return [
        ToolCallCase(
            name="simple",
            description="标量参数类型校验",
            tool_definition=_build_function_tool(
                name="record_simple_probe",
                description="记录简单参数探测结果",
                parameters=simple_parameters,
            ),
            expected_arguments=simple_expected_arguments,
        ),
        ToolCallCase(
            name="nested",
            description="嵌套对象与数组参数校验",
            tool_definition=_build_function_tool(
                name="record_nested_probe",
                description="记录嵌套参数探测结果",
                parameters=nested_parameters,
            ),
            expected_arguments=nested_expected_arguments,
        ),
    ]
|
||||
|
||||
|
||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
||||
"""解析命令行中的多值参数。"""
|
||||
parsed_values: List[str] = []
|
||||
for raw_value in raw_values or []:
|
||||
for item in str(raw_value).split(","):
|
||||
normalized_item = item.strip()
|
||||
if normalized_item:
|
||||
parsed_values.append(normalized_item)
|
||||
return parsed_values
|
||||
|
||||
|
||||
def _build_model_map() -> Dict[str, ModelInfo]:
    """Map each configured model's name to its ModelInfo entry."""
    model_map: Dict[str, ModelInfo] = {}
    for model in config_manager.get_model_config().models:
        model_map[model.name] = model
    return model_map
|
||||
|
||||
|
||||
def _build_provider_map() -> Dict[str, APIProvider]:
    """Map each configured API provider's name to its APIProvider entry."""
    provider_map: Dict[str, APIProvider] = {}
    for provider in config_manager.get_model_config().api_providers:
        provider_map[provider.name] = provider
    return provider_map
|
||||
|
||||
|
||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
||||
"""选择默认任务名。"""
|
||||
if "utils" in task_names:
|
||||
return "utils"
|
||||
if not task_names:
|
||||
raise ValueError("当前没有可用的任务配置")
|
||||
return str(task_names[0])
|
||||
|
||||
|
||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
    """Resolve the list of models to probe from the CLI filters.

    Args:
        task_filters: explicit task names; empty means all tasks except
            DEFAULT_SKIP_TASKS.
        model_filters: explicit model names; empty means every model of the
            selected tasks.
        fallback_task: task to pin a model to when it belongs to none of the
            selected tasks.

    Returns:
        Deduplicated list of ProbeTarget, one per model.

    Raises:
        ValueError: on unknown task/model names or an empty selection.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    provider_map = _build_provider_map()

    if not available_tasks:
        raise ValueError("未找到任何可用的模型任务配置")

    # Validate explicit task filters, or default to everything probe-worthy.
    if task_filters:
        selected_task_names = []
        for task_name in task_filters:
            if task_name not in available_tasks:
                raise ValueError(f"未找到任务 `{task_name}`")
            selected_task_names.append(task_name)
    else:
        selected_task_names = [
            task_name
            for task_name in available_tasks
            if task_name not in DEFAULT_SKIP_TASKS
        ]

    if not selected_task_names:
        raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")

    # Task used for models that appear in no selected task's model list.
    default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
    resolved_targets: List[ProbeTarget] = []
    seen_models: set[str] = set()

    # Candidate model list: explicit filters win; otherwise collect every
    # model referenced by the selected tasks, preserving first-seen order.
    if model_filters:
        model_names = list(model_filters)
    else:
        model_names = []
        for task_name in selected_task_names:
            task_config = available_tasks[task_name]
            for model_name in task_config.model_list:
                if model_name not in model_names:
                    model_names.append(model_name)

    for model_name in model_names:
        # Deduplicate while keeping the first occurrence.
        if model_name in seen_models:
            continue
        if model_name not in model_map:
            raise ValueError(f"未找到模型 `{model_name}`")

        # Pin the model to the first selected task that lists it, falling
        # back to the default task otherwise.
        target_task_name = ""
        for task_name in selected_task_names:
            if model_name in available_tasks[task_name].model_list:
                target_task_name = task_name
                break
        if not target_task_name:
            target_task_name = default_task_name

        model_info = model_map[model_name]
        provider_info = provider_map[model_info.api_provider]
        resolved_targets.append(
            ProbeTarget(
                task_name=target_task_name,
                model_name=model_name,
                provider_name=provider_info.name,
                client_type=provider_info.client_type,
                tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
            )
        )
        seen_models.add(model_name)

    return resolved_targets
|
||||
|
||||
|
||||
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
    """Temporarily pin a task to a single model.

    Mutates the shared in-memory model task config so every request issued
    inside the ``with`` block resolves to *model_name*, then restores the
    original model list and selection strategy on exit.

    Raises:
        ValueError: if *task_name* has no ``TaskConfig`` in the loaded config.
    """
    model_task_config = config_manager.get_model_config().model_task_config
    task_config = getattr(model_task_config, task_name, None)
    if not isinstance(task_config, TaskConfig):
        raise ValueError(f"未找到任务 `{task_name}` 对应的配置")

    # Snapshot the mutable fields so they can be restored even on error.
    original_model_list = list(task_config.model_list)
    original_selection_strategy = task_config.selection_strategy
    task_config.model_list = [model_name]
    task_config.selection_strategy = "balance"
    try:
        yield
    finally:
        # Always restore: the config object is shared process-wide.
        task_config.model_list = original_model_list
        task_config.selection_strategy = original_selection_strategy
|
||||
|
||||
|
||||
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
    """Convert ToolCall objects into plain OpenAI-style dicts for reporting."""
    serialized: List[Dict[str, Any]] = []
    for call in tool_calls or []:
        serialized.append(
            {
                "id": call.call_id,
                "function": {
                    "name": call.func_name,
                    "arguments": dict(call.args or {}),
                },
            }
        )
    return serialized
|
||||
|
||||
|
||||
def _is_integer_value(value: Any) -> bool:
|
||||
"""判断是否为整数类型且排除布尔值。"""
|
||||
return isinstance(value, int) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _is_number_value(value: Any) -> bool:
|
||||
"""判断是否为数值类型且排除布尔值。"""
|
||||
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _schema_type(schema: Dict[str, Any]) -> str:
|
||||
"""解析 Schema 的类型。"""
|
||||
schema_type = str(schema.get("type", "") or "").strip()
|
||||
if schema_type:
|
||||
return schema_type
|
||||
if "properties" in schema or "required" in schema:
|
||||
return "object"
|
||||
return ""
|
||||
|
||||
|
||||
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
|
||||
"""按简化 JSON Schema 校验工具参数。"""
|
||||
errors: List[str] = []
|
||||
schema_type = _schema_type(schema)
|
||||
|
||||
if "enum" in schema and actual_value not in schema["enum"]:
|
||||
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
|
||||
|
||||
if schema_type == "string":
|
||||
if not isinstance(actual_value, str):
|
||||
errors.append(f"{path} 类型错误,期望 string,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "integer":
|
||||
if not _is_integer_value(actual_value):
|
||||
errors.append(f"{path} 类型错误,期望 integer,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "number":
|
||||
if not _is_number_value(actual_value):
|
||||
errors.append(f"{path} 类型错误,期望 number,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "boolean":
|
||||
if not isinstance(actual_value, bool):
|
||||
errors.append(f"{path} 类型错误,期望 boolean,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "array":
|
||||
if not isinstance(actual_value, list):
|
||||
errors.append(f"{path} 类型错误,期望 array,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
item_schema = schema.get("items")
|
||||
if isinstance(item_schema, dict):
|
||||
for index, item in enumerate(actual_value):
|
||||
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
|
||||
return errors
|
||||
|
||||
if schema_type == "object":
|
||||
if not isinstance(actual_value, dict):
|
||||
errors.append(f"{path} 类型错误,期望 object,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = [str(item) for item in schema.get("required", [])]
|
||||
for required_field in required_fields:
|
||||
if required_field not in actual_value:
|
||||
errors.append(f"{path}.{required_field} 缺少必填字段")
|
||||
|
||||
for field_name, field_value in actual_value.items():
|
||||
field_path = f"{path}.{field_name}"
|
||||
field_schema = properties.get(field_name)
|
||||
if isinstance(field_schema, dict):
|
||||
errors.extend(_validate_schema(field_schema, field_value, field_path))
|
||||
continue
|
||||
|
||||
additional_properties = schema.get("additionalProperties", True)
|
||||
if additional_properties is False:
|
||||
errors.append(f"{field_path} 是未定义字段")
|
||||
elif isinstance(additional_properties, dict):
|
||||
errors.extend(_validate_schema(additional_properties, field_value, field_path))
|
||||
return errors
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
|
||||
"""递归比较实际值与期望值是否完全一致。"""
|
||||
errors: List[str] = []
|
||||
|
||||
if isinstance(expected_value, dict):
|
||||
if not isinstance(actual_value, dict):
|
||||
return [f"{path} 值不一致,期望 object,实际为 {type(actual_value).__name__}"]
|
||||
|
||||
expected_keys = set(expected_value.keys())
|
||||
actual_keys = set(actual_value.keys())
|
||||
for missing_key in sorted(expected_keys - actual_keys):
|
||||
errors.append(f"{path}.{missing_key} 缺少期望字段")
|
||||
for extra_key in sorted(actual_keys - expected_keys):
|
||||
errors.append(f"{path}.{extra_key} 出现了额外字段")
|
||||
for shared_key in sorted(expected_keys & actual_keys):
|
||||
errors.extend(
|
||||
_compare_expected_values(
|
||||
expected_value[shared_key],
|
||||
actual_value[shared_key],
|
||||
f"{path}.{shared_key}",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
if isinstance(expected_value, list):
|
||||
if not isinstance(actual_value, list):
|
||||
return [f"{path} 值不一致,期望 array,实际为 {type(actual_value).__name__}"]
|
||||
|
||||
if len(expected_value) != len(actual_value):
|
||||
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
|
||||
for index, (expected_item, actual_item) in enumerate(
|
||||
zip(expected_value, actual_value, strict=False)
|
||||
):
|
||||
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
|
||||
return errors
|
||||
|
||||
if isinstance(expected_value, bool):
|
||||
if not isinstance(actual_value, bool) or actual_value is not expected_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
if _is_integer_value(expected_value):
|
||||
if not _is_integer_value(actual_value) or actual_value != expected_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
if isinstance(expected_value, float):
|
||||
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
if expected_value != actual_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
|
||||
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
    """Return the first call matching the expected tool name, else the first call."""
    matching = (candidate for candidate in tool_calls if candidate.func_name == expected_tool_name)
    return next(matching, tool_calls[0])
|
||||
|
||||
|
||||
def _validate_service_result(
    service_result: LLMServiceResult,
    target: ProbeTarget,
    case: ToolCallCase,
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Validate one service-layer result against the probe case.

    Returns:
        ``(errors, warnings, serialized_tool_calls)`` — errors make the probe
        fail, warnings are informational only.
    """
    errors: List[str] = []
    warnings: List[str] = []
    completion = service_result.completion
    serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)

    # A transport/provider failure short-circuits every other check.
    if not service_result.success:
        errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
        return errors, warnings, serialized_tool_calls

    # The pinned task should have routed to exactly the target model.
    if completion.model_name and completion.model_name != target.model_name:
        errors.append(
            f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
        )

    tool_calls = completion.tool_calls or []
    if not tool_calls:
        errors.append("模型未返回 tool_calls")
        if completion.response.strip():
            warnings.append("模型返回了自然语言文本而不是工具调用")
        return errors, warnings, serialized_tool_calls

    # Exactly one call is expected; extras are reported but validation
    # continues on the best-matching call.
    if len(tool_calls) != 1:
        errors.append(f"返回了 {len(tool_calls)} 个 tool_calls,预期为 1 个")

    selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
    if selected_tool_call.func_name != case.tool_name:
        errors.append(
            f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
        )

    actual_arguments = selected_tool_call.args
    if not isinstance(actual_arguments, dict):
        errors.append("工具参数未被解析为对象")
        return errors, warnings, serialized_tool_calls

    # Structural (schema) check first, then exact value comparison.
    errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
    errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))

    if completion.response.strip():
        warnings.append("模型同时返回了自然语言文本")
    return errors, warnings, serialized_tool_calls
|
||||
|
||||
|
||||
async def _run_single_probe(
    target: ProbeTarget,
    case: ToolCallCase,
    attempt: int,
    max_tokens: int,
    temperature: float,
) -> ProbeResult:
    """Run one tool-call parameter probe against one target model.

    Builds the service request from the case, pins the task to the target
    model for the duration of the call, then validates the response.
    """
    request = LLMServiceRequest(
        task_name=target.task_name,
        # request_type encodes case + attempt so logs can be correlated.
        request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
        prompt=case.build_messages(),
        tool_options=[case.tool_definition],
        temperature=temperature,
        max_tokens=max_tokens,
    )

    started_at = time.perf_counter()
    # Pin inside the timing window; the context manager restores the config.
    with _pin_task_to_model(target.task_name, target.model_name):
        service_result = await generate(request)
    elapsed_seconds = time.perf_counter() - started_at

    errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
    completion = service_result.completion
    return ProbeResult(
        task_name=target.task_name,
        target_model_name=target.model_name,
        actual_model_name=completion.model_name,
        provider_name=target.provider_name,
        client_type=target.client_type,
        tool_argument_parse_mode=target.tool_argument_parse_mode,
        case_name=case.name,
        attempt=attempt,
        # A probe passes iff validation produced no errors (warnings are OK).
        success=not errors,
        elapsed_seconds=elapsed_seconds,
        errors=errors,
        warnings=warnings,
        response_text=completion.response,
        reasoning_text=completion.reasoning,
        tool_calls=serialized_tool_calls,
    )
|
||||
|
||||
|
||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
    """Print the resolved probe targets, one numbered line each."""
    print("待测试目标:")
    for position, target in enumerate(targets, start=1):
        line = (
            f"{position}. model={target.model_name} | task={target.task_name} | "
            f"provider={target.provider_name} | client={target.client_type} | "
            f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
        )
        print(line)
|
||||
|
||||
|
||||
def _print_available_targets() -> None:
    """Print the currently configured tasks and models (offline, no requests)."""
    available_tasks = get_available_models()
    model_map = _build_model_map()
    task_names = list(available_tasks.keys())

    print("当前可用任务:")
    for task_name in task_names:
        task_config = available_tasks[task_name]
        print(f"- {task_name}: {list(task_config.model_list)}")

    # Every model referenced by at least one task, for the usage marker below.
    referenced_models = {
        model_name
        for task_config in available_tasks.values()
        for model_name in task_config.model_list
    }

    print("\n当前配置中的模型:")
    for model_name, model_info in model_map.items():
        referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
        print(
            f"- {model_name}: provider={model_info.api_provider}, "
            f"identifier={model_info.model_identifier}, {referenced_mark}"
        )
|
||||
|
||||
|
||||
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
    """Filter the default probe cases by name; an empty filter selects all."""
    cases_by_name = {case.name: case for case in _build_default_cases()}
    if not case_filters:
        return list(cases_by_name.values())

    chosen: List[ToolCallCase] = []
    for requested in case_filters:
        case = cases_by_name.get(requested)
        if case is None:
            raise ValueError(f"未知测试用例 `{requested}`,可选值: {', '.join(sorted(cases_by_name))}")
        chosen.append(case)
    return chosen
|
||||
|
||||
|
||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
    """Print one probe outcome: a status line, then errors, warnings and calls."""
    status_text = "PASS" if result.success else "FAIL"
    header = (
        f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
        f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
    )
    print(header)
    # Iterating possibly-empty lists directly; nothing prints when empty.
    for error in result.errors:
        print(f" ERROR: {error}")
    for warning in result.warnings:
        print(f" WARN: {warning}")
    if result.tool_calls:
        print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
    if show_response and result.response_text.strip():
        print(f" response: {result.response_text}")
|
||||
|
||||
|
||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
    """Aggregate probe results into total/passed/failed counts plus failure details."""
    failures = [result for result in results if not result.success]
    failed_items = [
        {
            "model_name": result.target_model_name,
            "case_name": result.case_name,
            "attempt": result.attempt,
            "errors": list(result.errors),
        }
        for result in failures
    ]
    total_count = len(results)
    return {
        "total": total_count,
        "passed": total_count - len(failures),
        "failed": len(failures),
        "failed_items": failed_items,
    }
|
||||
|
||||
|
||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
    """Write a timestamped JSON report (summary + full results) to *json_out*."""
    report_path = Path(json_out).expanduser().resolve()
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "summary": _build_summary(results),
        "results": [asdict(result) for result in results],
    }
    report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n结果已写入: {report_path}")
|
||||
|
||||
|
||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
    """Run every (target x attempt x case) probe and collect the results.

    Prints progress and per-probe outcomes as it goes; returns all results
    in execution order for the caller to summarize.
    """
    # --task/--model/--case accept repeats and comma-separated values.
    task_filters = _parse_multi_value_args(args.task)
    model_filters = _parse_multi_value_args(args.model)
    case_filters = _parse_multi_value_args(args.case)

    selected_cases = _select_cases(case_filters)
    targets = _resolve_targets(task_filters, model_filters, args.fallback_task)

    _print_targets(targets)
    print("")

    results: List[ProbeResult] = []
    # Probes run sequentially (not gathered) so each model is pinned and
    # its output printed without interleaving.
    for target in targets:
        for attempt in range(1, args.repeat + 1):
            for case in selected_cases:
                print(
                    f"开始测试: model={target.model_name}, task={target.task_name}, "
                    f"case={case.name}, attempt={attempt}"
                )
                result = await _run_single_probe(
                    target=target,
                    case=case,
                    attempt=attempt,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _print_single_result(result, args.show_response)
                print("")
                results.append(result)
    return results
|
||||
|
||||
|
||||
def _build_parser() -> ArgumentParser:
    """Build the CLI argument parser for the tool-call compatibility probe."""
    parser = ArgumentParser(
        description=(
            "测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
            "默认会测试所有非 voice / embedding 任务中引用到的模型。"
        )
    )
    # --task/--model/--case use action="append" and additionally accept
    # comma-separated values; both forms are flattened by
    # _parse_multi_value_args at run time.
    parser.add_argument(
        "--task",
        action="append",
        help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
    )
    parser.add_argument(
        "--model",
        action="append",
        help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
    )
    parser.add_argument(
        "--case",
        action="append",
        help="指定测试用例名,可选 simple、nested;不传则运行全部默认用例",
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="每个模型每个用例重复测试次数,默认 1",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="单次测试的最大输出 token 数,默认 512",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="单次测试温度,默认 0.0 以尽量提高稳定性",
    )
    parser.add_argument(
        "--fallback-task",
        default="utils",
        help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
    )
    parser.add_argument(
        "--json-out",
        help="可选,将结果写入指定 JSON 文件",
    )
    parser.add_argument(
        "--list-targets",
        action="store_true",
        help="仅打印当前任务与模型映射,不发起网络请求",
    )
    parser.add_argument(
        "--show-response",
        action="store_true",
        help="打印模型返回的自然语言文本内容",
    )
    return parser
|
||||
|
||||
|
||||
def main() -> int:
    """Script entry point.

    Returns:
        0 when every probe passed (or only targets were listed), 1 otherwise.
    """
    _ensure_utf8_console()
    parser = _build_parser()
    args = parser.parse_args()

    # Validate numeric flags up front; parser.error() exits with status 2.
    if args.repeat < 1:
        parser.error("--repeat 必须大于等于 1")
    if args.max_tokens < 1:
        parser.error("--max-tokens 必须大于等于 1")

    # --list-targets is an offline mode: print the mapping and stop.
    if args.list_targets:
        _print_available_targets()
        return 0

    results = asyncio.run(_run_probes(args))
    summary = _build_summary(results)

    print("测试摘要:")
    print(
        f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
    )
    if summary["failed_items"]:
        print("失败明细:")
        for failed_item in summary["failed_items"]:
            print(
                f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
                f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
            )

    if args.json_out:
        _write_json_report(args.json_out, results)

    # Non-zero exit lets CI treat any failed probe as a build failure.
    return 0 if summary["failed"] == 0 else 1
|
||||
|
||||
|
||||
# Script entry; exit status 0 = all probes passed, 1 = at least one failed.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
777
scripts/test_tool_call_api_matrix.py
Normal file
777
scripts/test_tool_call_api_matrix.py
Normal file
@@ -0,0 +1,777 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
||||
from src.config.config import config_manager # noqa: E402
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
||||
from src.services.llm_service import generate # noqa: E402
|
||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
||||
|
||||
|
||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ProbeTarget:
    """A single (task, model, provider) combination to probe."""

    # Task whose configuration the request is issued under.
    task_name: str
    # Configured model name (a key in the model config).
    model_name: str
    # API provider name resolved from the model entry.
    provider_name: str
    # Provider client implementation identifier.
    client_type: str
    # Provider setting for how tool-call arguments are parsed.
    tool_argument_parse_mode: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ToolCallScenario:
    """Definition of one tool-call API scenario (message history + tools)."""

    # Unique scenario identifier, used in request_type and reports.
    name: str
    # Human-readable description of what the scenario exercises.
    description: str
    # Chat messages sent verbatim; may include historical tool_calls
    # and tool-result messages to test context replay.
    prompt: List[Dict[str, Any]]
    # Tools offered on the current turn; None means no tools are sent.
    tool_options: List[Dict[str, Any]] | None = None
    # True/False = warn when the model diverges; None = no expectation.
    expect_tool_calls: bool | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ProbeResult:
    """Outcome of a single API probe, serializable via dataclasses.asdict."""

    # Task used to issue the request.
    task_name: str
    # Model the probe intended to hit.
    target_model_name: str
    # Model the service actually reported hitting.
    actual_model_name: str
    # Provider metadata copied from the ProbeTarget.
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
    # Scenario name and 1-based attempt index.
    case_name: str
    attempt: int
    # True iff no validation errors were recorded.
    success: bool
    # Wall-clock duration of the service call in seconds.
    elapsed_seconds: float
    # Fatal problems; a non-empty list marks the probe as failed.
    errors: List[str] = field(default_factory=list)
    # Non-fatal observations (e.g. extra natural-language text).
    warnings: List[str] = field(default_factory=list)
    # Raw visible response text and reasoning text from the completion.
    response_text: str = ""
    reasoning_text: str = ""
    # Tool calls serialized to plain dicts for JSON output.
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _ensure_utf8_console() -> None:
|
||||
"""尽量将控制台编码切换为 UTF-8。"""
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构造 OpenAI 风格 function tool。"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_probe_tools() -> List[Dict[str, Any]]:
    """Build the two generic probe tools shared by every scenario.

    Returns:
        ``[weather_tool, search_tool]`` — a flat-parameter tool and a
        nested-object tool, both strict (all fields required, no extras).
    """
    # Flat schema: string + enum + boolean, all required.
    weather_tool = _build_function_tool(
        name="lookup_weather",
        description="查询指定城市天气。",
        parameters={
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "城市名"},
                "unit": {
                    "type": "string",
                    "description": "温度单位",
                    "enum": ["celsius", "fahrenheit"],
                },
                "include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
            },
            "required": ["city", "unit", "include_forecast"],
            "additionalProperties": False,
        },
    )
    # Nested schema: an object-valued "filters" field with its own
    # required keys, to exercise deep argument parsing.
    search_tool = _build_function_tool(
        name="search_docs",
        description="搜索内部知识库。",
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索关键词"},
                "top_k": {"type": "integer", "description": "返回条数"},
                "filters": {
                    "type": "object",
                    "description": "过滤条件",
                    "properties": {
                        "scope": {"type": "string", "description": "搜索范围"},
                        "tag": {"type": "string", "description": "标签"},
                    },
                    "required": ["scope", "tag"],
                    "additionalProperties": False,
                },
            },
            "required": ["query", "top_k", "filters"],
            "additionalProperties": False,
        },
    )
    return [weather_tool, search_tool]
|
||||
|
||||
|
||||
def _build_default_scenarios() -> List[ToolCallScenario]:
    """Build the default API-compatibility scenarios.

    Covers a fresh tool call plus several history-replay shapes: assistant
    tool_calls with/without text, tool-result follow-ups, multiple parallel
    calls, and historical calls combined with a current-turn tool list.
    """
    tools = _build_probe_tools()
    weather_tool = tools[0]
    search_tool = tools[1]

    # Canned historical tool calls reused across the replay scenarios.
    # Their ids must match the tool-result messages' tool_call_id below.
    history_tool_call = {
        "id": "call_hist_weather_001",
        "type": "function",
        "function": {
            "name": "lookup_weather",
            "arguments": {
                "city": "上海",
                "unit": "celsius",
                "include_forecast": True,
            },
        },
    }
    nested_history_tool_call = {
        "id": "call_hist_search_001",
        "type": "function",
        "function": {
            "name": "search_docs",
            "arguments": {
                "query": "工具调用兼容性",
                "top_k": 3,
                "filters": {
                    "scope": "internal",
                    "tag": "tool-call",
                },
            },
        },
    }

    return [
        # 1. Plain first-turn tool call; the model should call the tool.
        ToolCallScenario(
            name="fresh_tool_call",
            description="首轮普通工具调用请求。",
            prompt=[
                {
                    "role": "system",
                    "content": (
                        "你正在执行工具调用连通性测试。"
                        "如果能调用工具,就优先调用最合适的工具。"
                    ),
                },
                {
                    "role": "user",
                    "content": "请查询上海天气,并使用工具给出参数。",
                },
            ],
            tool_options=[weather_tool],
            expect_tool_calls=True,
        ),
        # 2. History replay: assistant message with BOTH text and tool_calls,
        #    no tools offered this turn.
        ToolCallScenario(
            name="history_assistant_tool_calls_with_content",
            description="历史 assistant 同时包含文本和 tool_calls,当前轮不再提供 tools。",
            prompt=[
                {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
                {"role": "user", "content": "先帮我查一下上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查询天气,再继续回答。",
                    "tool_calls": [history_tool_call],
                },
                {"role": "user", "content": "继续说,别丢掉上下文。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 3. History replay: assistant message with tool_calls only, no text.
        ToolCallScenario(
            name="history_assistant_tool_calls_without_content",
            description="历史 assistant 只有 tool_calls,没有文本内容。",
            prompt=[
                {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
                {"role": "user", "content": "先帮我查一下上海天气。"},
                {
                    "role": "assistant",
                    "tool_calls": [history_tool_call],
                },
                {"role": "user", "content": "继续。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 4. Full round-trip: assistant tool_calls followed by the matching
        #    tool-result message.
        ToolCallScenario(
            name="history_tool_result_followup",
            description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
            prompt=[
                {"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
                {"role": "user", "content": "先查上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查询天气。",
                    "tool_calls": [history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "多云",
                            "temperature_c": 24,
                            "forecast": ["晴", "小雨"],
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "结合上面的查询结果继续总结。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 5. Two parallel tool_calls with two matching tool results.
        ToolCallScenario(
            name="history_multiple_tool_calls_and_results",
            description="历史中包含多个 tool_calls 与多条 tool 结果。",
            prompt=[
                {"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
                {"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
                {
                    "role": "assistant",
                    "content": "我分两步查询。",
                    "tool_calls": [history_tool_call, nested_history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "阴",
                            "temperature_c": 22,
                        },
                        ensure_ascii=False,
                    ),
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_search_001",
                    "content": json.dumps(
                        {
                            "items": [
                                "OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
                                "部分 provider 在历史消息回放时兼容性较弱",
                            ],
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "继续整合上面的两个结果。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 6. Historical tool_calls retained while a tool list is also
        #    offered on the current turn; a new call is expected.
        ToolCallScenario(
            name="history_tool_calls_with_current_tools",
            description="保留历史 tool_calls,同时当前轮仍然提供 tools。",
            prompt=[
                {"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
                {"role": "user", "content": "先查上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查天气。",
                    "tool_calls": [history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "晴",
                            "temperature_c": 26,
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
            ],
            tool_options=[search_tool],
            expect_tool_calls=True,
        ),
    ]
|
||||
|
||||
|
||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
||||
"""解析命令行中的多值参数。"""
|
||||
parsed_values: List[str] = []
|
||||
for raw_value in raw_values or []:
|
||||
for item in str(raw_value).split(","):
|
||||
normalized_item = item.strip()
|
||||
if normalized_item:
|
||||
parsed_values.append(normalized_item)
|
||||
return parsed_values
|
||||
|
||||
|
||||
def _build_model_map() -> Dict[str, ModelInfo]:
    """Map each configured model's name to its config entry."""
    mapping: Dict[str, ModelInfo] = {}
    for model in config_manager.get_model_config().models:
        mapping[model.name] = model
    return mapping
|
||||
|
||||
|
||||
def _build_provider_map() -> Dict[str, APIProvider]:
    """Map each configured API provider's name to its config entry."""
    mapping: Dict[str, APIProvider] = {}
    for provider in config_manager.get_model_config().api_providers:
        mapping[provider.name] = provider
    return mapping
|
||||
|
||||
|
||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
||||
"""选择默认任务名。"""
|
||||
if "utils" in task_names:
|
||||
return "utils"
|
||||
if not task_names:
|
||||
raise ValueError("当前没有可用的任务配置")
|
||||
return str(task_names[0])
|
||||
|
||||
|
||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
    """Resolve the probe targets from CLI filters.

    Selection: explicit ``--task`` names are validated; otherwise all tasks
    minus DEFAULT_SKIP_TASKS are used. Explicit ``--model`` names are taken
    as-is; otherwise every model referenced by the selected tasks is probed
    once. Each model is mounted on the first selected task that references
    it, falling back to *fallback_task* (or the default task) otherwise.

    Raises:
        ValueError: on an unknown task/model name or an empty selection.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    provider_map = _build_provider_map()

    if not available_tasks:
        raise ValueError("未找到任何可用的模型任务配置")

    if task_filters:
        # Explicit task list: every requested task must exist.
        selected_task_names = []
        for task_name in task_filters:
            if task_name not in available_tasks:
                raise ValueError(f"未找到任务 `{task_name}`")
            selected_task_names.append(task_name)
    else:
        # Default: all tasks except the ones with no tool-call path.
        selected_task_names = [
            task_name
            for task_name in available_tasks
            if task_name not in DEFAULT_SKIP_TASKS
        ]

    if not selected_task_names:
        raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")

    # Task used for models not referenced by any selected task.
    default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
    resolved_targets: List[ProbeTarget] = []
    seen_models: set[str] = set()

    if model_filters:
        model_names = list(model_filters)
    else:
        # Collect models in task order, de-duplicated but order-preserving.
        model_names = []
        for task_name in selected_task_names:
            task_config = available_tasks[task_name]
            for model_name in task_config.model_list:
                if model_name not in model_names:
                    model_names.append(model_name)

    for model_name in model_names:
        # Explicit --model values may repeat; probe each model only once.
        if model_name in seen_models:
            continue
        if model_name not in model_map:
            raise ValueError(f"未找到模型 `{model_name}`")

        # Mount the model on the first selected task that references it.
        target_task_name = ""
        for task_name in selected_task_names:
            if model_name in available_tasks[task_name].model_list:
                target_task_name = task_name
                break
        if not target_task_name:
            target_task_name = default_task_name

        model_info = model_map[model_name]
        provider_info = provider_map[model_info.api_provider]
        resolved_targets.append(
            ProbeTarget(
                task_name=target_task_name,
                model_name=model_name,
                provider_name=provider_info.name,
                client_type=provider_info.client_type,
                tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
            )
        )
        seen_models.add(model_name)

    return resolved_targets
|
||||
|
||||
|
||||
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
    """Temporarily pin a task to a single model.

    Mutates the shared in-memory model task config so every request issued
    inside the ``with`` block resolves to *model_name*, then restores the
    original model list and selection strategy on exit.

    Raises:
        ValueError: if *task_name* has no ``TaskConfig`` in the loaded config.
    """
    model_task_config = config_manager.get_model_config().model_task_config
    task_config = getattr(model_task_config, task_name, None)
    if not isinstance(task_config, TaskConfig):
        raise ValueError(f"未找到任务 `{task_name}` 对应的配置")

    # Snapshot the mutable fields so they can be restored even on error.
    original_model_list = list(task_config.model_list)
    original_selection_strategy = task_config.selection_strategy
    task_config.model_list = [model_name]
    task_config.selection_strategy = "balance"
    try:
        yield
    finally:
        # Always restore: the config object is shared process-wide.
        task_config.model_list = original_model_list
        task_config.selection_strategy = original_selection_strategy
|
||||
|
||||
|
||||
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
|
||||
"""序列化返回中的工具调用。"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
serialized_items: List[Dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
serialized_items.append(
|
||||
{
|
||||
"id": getattr(tool_call, "call_id", ""),
|
||||
"function": {
|
||||
"name": getattr(tool_call, "func_name", ""),
|
||||
"arguments": dict(getattr(tool_call, "args", {}) or {}),
|
||||
},
|
||||
**(
|
||||
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
|
||||
if getattr(tool_call, "extra_content", None)
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return serialized_items
|
||||
|
||||
|
||||
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Check a service result against the scenario's expectations.

    Returns ``(errors, warnings, serialized_tool_calls)``. A failed request
    produces exactly one error and skips all expectation checks; a successful
    one only produces warnings (tool-call expectation mismatches, visible text).
    """
    problems: List[str] = []
    notices: List[str] = []
    completion = service_result.completion
    tool_call_dicts = _serialize_tool_calls(completion.tool_calls)

    # Hard failure: report the most specific message available and stop.
    if not service_result.success:
        problems.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
        return problems, notices, tool_call_dicts

    expects = scenario.expect_tool_calls
    if expects is True and not tool_call_dicts:
        notices.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
    if expects is False and tool_call_dicts:
        notices.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
    # Visible text is always surfaced as a warning, regardless of expectation.
    if completion.response.strip():
        notices.append("模型返回了可见文本")
    return problems, notices, tool_call_dicts
|
||||
|
||||
|
||||
async def _run_single_probe(
    target: ProbeTarget,
    scenario: ToolCallScenario,
    attempt: int,
    max_tokens: int,
    temperature: float,
) -> ProbeResult:
    """Execute a single API probe: one model, one scenario, one attempt.

    Builds an LLM request tagged with the scenario name and attempt number,
    pins the target task to the target model for the duration of the call,
    validates the response, and wraps everything into a ProbeResult.
    """
    request = LLMServiceRequest(
        task_name=target.task_name,
        # request_type encodes scenario + attempt so individual calls can be
        # traced back in service-side logs.
        request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
        prompt=scenario.prompt,
        tool_options=scenario.tool_options,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    started_at = time.perf_counter()
    # Pin the task so the service routes this request to the model under test.
    with _pin_task_to_model(target.task_name, target.model_name):
        service_result = await generate(request)
    elapsed_seconds = time.perf_counter() - started_at

    errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
    completion = service_result.completion
    return ProbeResult(
        task_name=target.task_name,
        target_model_name=target.model_name,
        # Model name as reported by the completion; stored separately from
        # target_model_name (presumably they can differ — confirm with service).
        actual_model_name=completion.model_name,
        provider_name=target.provider_name,
        client_type=target.client_type,
        tool_argument_parse_mode=target.tool_argument_parse_mode,
        case_name=scenario.name,
        attempt=attempt,
        # A probe passes iff validation produced no errors; warnings don't fail it.
        success=not errors,
        elapsed_seconds=elapsed_seconds,
        errors=errors,
        warnings=warnings,
        response_text=completion.response,
        reasoning_text=completion.reasoning,
        tool_calls=serialized_tool_calls,
    )
|
||||
|
||||
|
||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
|
||||
"""打印待测试目标。"""
|
||||
print("待测试目标:")
|
||||
for index, target in enumerate(targets, start=1):
|
||||
print(
|
||||
f"{index}. model={target.model_name} | task={target.task_name} | "
|
||||
f"provider={target.provider_name} | client={target.client_type} | "
|
||||
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
|
||||
)
|
||||
|
||||
|
||||
def _print_available_targets() -> None:
    """Print the configured tasks and models without issuing any requests.

    For each model, also reports whether any task currently references it,
    which helps spot unused model configurations.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    task_names = list(available_tasks.keys())

    print("当前可用任务:")
    for task_name in task_names:
        task_config = available_tasks[task_name]
        print(f"- {task_name}: {list(task_config.model_list)}")

    # Every model name referenced by at least one task's model list.
    referenced_models = {
        model_name
        for task_config in available_tasks.values()
        for model_name in task_config.model_list
    }

    print("\n当前配置中的模型:")
    for model_name, model_info in model_map.items():
        referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
        print(
            f"- {model_name}: provider={model_info.api_provider}, "
            f"identifier={model_info.model_identifier}, {referenced_mark}"
        )
|
||||
|
||||
|
||||
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
    """Filter the built-in test scenarios by name.

    With no filters, every scenario is returned. An unknown name raises
    ValueError listing the valid choices.
    """
    scenario_by_name = {scenario.name: scenario for scenario in _build_default_scenarios()}
    if not case_filters:
        return list(scenario_by_name.values())

    chosen: List[ToolCallScenario] = []
    for name in case_filters:
        if name not in scenario_by_name:
            raise ValueError(
                f"未知测试场景 `{name}`,可选值: {', '.join(sorted(scenario_by_name))}"
            )
        chosen.append(scenario_by_name[name])
    return chosen
|
||||
|
||||
|
||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
|
||||
"""打印单次结果。"""
|
||||
status_text = "PASS" if result.success else "FAIL"
|
||||
print(
|
||||
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
|
||||
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
|
||||
)
|
||||
if result.errors:
|
||||
for error in result.errors:
|
||||
print(f" ERROR: {error}")
|
||||
if result.warnings:
|
||||
for warning in result.warnings:
|
||||
print(f" WARN: {warning}")
|
||||
if result.tool_calls:
|
||||
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
|
||||
if show_response and result.response_text.strip():
|
||||
print(f" response: {result.response_text}")
|
||||
|
||||
|
||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
|
||||
"""构造结果摘要。"""
|
||||
total_count = len(results)
|
||||
passed_count = sum(1 for result in results if result.success)
|
||||
failed_count = total_count - passed_count
|
||||
failed_items = [
|
||||
{
|
||||
"model_name": result.target_model_name,
|
||||
"case_name": result.case_name,
|
||||
"attempt": result.attempt,
|
||||
"errors": list(result.errors),
|
||||
}
|
||||
for result in results
|
||||
if not result.success
|
||||
]
|
||||
return {
|
||||
"total": total_count,
|
||||
"passed": passed_count,
|
||||
"failed": failed_count,
|
||||
"failed_items": failed_items,
|
||||
}
|
||||
|
||||
|
||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
    """Write the probe results and their summary to a JSON file.

    Parent directories are created as needed; output is UTF-8 JSON with
    non-ASCII characters preserved (ensure_ascii=False).
    """
    target_path = Path(json_out).expanduser().resolve()
    target_path.parent.mkdir(parents=True, exist_ok=True)
    report = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "summary": _build_summary(results),
        "results": [asdict(item) for item in results],
    }
    target_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"\n结果已写入: {target_path}")
|
||||
|
||||
|
||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
    """Run every selected probe and return the collected results.

    Iterates targets -> attempts -> scenarios, printing each result as it
    completes. Unknown task/model/case names raise ValueError in the
    resolvers below, before any network request is made.
    """
    task_filters = _parse_multi_value_args(args.task)
    model_filters = _parse_multi_value_args(args.model)
    case_filters = _parse_multi_value_args(args.case)

    # Scenario resolution runs first, so an invalid --case fails fast.
    selected_scenarios = _select_scenarios(case_filters)
    targets = _resolve_targets(task_filters, model_filters, args.fallback_task)

    _print_targets(targets)
    print("")

    results: List[ProbeResult] = []
    for target in targets:
        for attempt in range(1, args.repeat + 1):
            for scenario in selected_scenarios:
                print(
                    f"开始测试: model={target.model_name}, task={target.task_name}, "
                    f"case={scenario.name}, attempt={attempt}"
                )
                # Probes run strictly sequentially — one request at a time.
                result = await _run_single_probe(
                    target=target,
                    scenario=scenario,
                    attempt=attempt,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _print_single_result(result, args.show_response)
                print("")
                results.append(result)
    return results
|
||||
|
||||
|
||||
def _build_parser() -> ArgumentParser:
|
||||
"""构造命令行参数解析器。"""
|
||||
parser = ArgumentParser(
|
||||
description=(
|
||||
"测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
|
||||
"重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
action="append",
|
||||
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
action="append",
|
||||
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--case",
|
||||
action="append",
|
||||
help=(
|
||||
"指定测试场景名,可选值包括 "
|
||||
"fresh_tool_call、history_assistant_tool_calls_with_content、"
|
||||
"history_assistant_tool_calls_without_content、history_tool_result_followup、"
|
||||
"history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat",
|
||||
type=int,
|
||||
default=1,
|
||||
help="每个模型每个场景重复测试次数,默认 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="单次测试的最大输出 token 数,默认 512",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="单次测试温度,默认 0.0,以尽量提高稳定性",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fallback-task",
|
||||
default="utils",
|
||||
help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-out",
|
||||
help="可选,将结果写入指定 JSON 文件",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-targets",
|
||||
action="store_true",
|
||||
help="仅打印当前任务与模型映射,不发起网络请求",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-response",
|
||||
action="store_true",
|
||||
help="打印模型返回的可见文本内容",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> int:
    """Script entry point.

    Returns 0 when all probes pass (or when only listing targets),
    1 when any probe fails.
    """
    _ensure_utf8_console()
    config_manager.initialize()
    parser = _build_parser()
    args = parser.parse_args()

    # Validate numeric arguments up front; parser.error exits with code 2.
    if args.repeat < 1:
        parser.error("--repeat 必须大于等于 1")
    if args.max_tokens < 1:
        parser.error("--max-tokens 必须大于等于 1")

    # --list-targets is a dry run: show the task/model mapping, no requests.
    if args.list_targets:
        _print_available_targets()
        return 0

    results = asyncio.run(_run_probes(args))
    summary = _build_summary(results)

    print("测试摘要:")
    print(
        f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
    )
    if summary["failed_items"]:
        print("失败明细:")
        for failed_item in summary["failed_items"]:
            print(
                f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
                f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
            )

    if args.json_out:
        _write_json_report(args.json_out, results)

    return 0 if summary["failed"] == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate main()'s exit code to the shell via SystemExit.
    raise SystemExit(main())
|
||||
83
scripts/verify_a_memorix_webui.sh
Normal file
83
scripts/verify_a_memorix_webui.sh
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env bash
# Verify the a_memorix WebUI end-to-end: start (or reuse) the WebUI backend
# and the dashboard dev server, then drive an Electron instance against the
# dashboard and collect UI snapshots into OUTPUT_DIR.
set -euo pipefail

# Repo-relative paths and endpoints; every value can be overridden via the
# corresponding MAIBOT_* environment variable.
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
DASHBOARD_ROOT="$REPO_ROOT/dashboard"
OUTPUT_DIR="${MAIBOT_UI_SNAPSHOT_DIR:-$REPO_ROOT/tmp/ui-snapshots/a_memorix-electron}"
PYTHON_BIN="${MAIBOT_PYTHON_BIN:-$REPO_ROOT/../../.venv/bin/python}"
# NOTE(review): default Electron path is macOS-specific — set MAIBOT_ELECTRON_BIN elsewhere.
ELECTRON_BIN="${MAIBOT_ELECTRON_BIN:-$DASHBOARD_ROOT/node_modules/electron/dist/Electron.app/Contents/MacOS/Electron}"
DRIVER_SCRIPT="$DASHBOARD_ROOT/scripts/a_memorix_electron_validate.cjs"
BACKEND_SCRIPT="$REPO_ROOT/scripts/run_a_memorix_webui_backend.py"
BACKEND_HOST="${MAIBOT_WEBUI_HOST:-127.0.0.1}"
BACKEND_PORT="${MAIBOT_WEBUI_PORT:-8001}"
DASHBOARD_HOST="${MAIBOT_DASHBOARD_HOST:-127.0.0.1}"
DASHBOARD_PORT="${MAIBOT_DASHBOARD_PORT:-7999}"
BACKEND_URL="http://${BACKEND_HOST}:${BACKEND_PORT}"
DASHBOARD_URL="http://${DASHBOARD_HOST}:${DASHBOARD_PORT}"
# When set to 1, never start services here — assume both are already running.
REUSE_SERVICES="${MAIBOT_UI_REUSE_SERVICES:-0}"

# PIDs of services started by this script (empty = not started by us).
BACKEND_PID=""
DASHBOARD_PID=""

mkdir -p "$OUTPUT_DIR"

# Stop only the services we started ourselves, preserving the exit code of
# whatever terminated the script.
cleanup() {
  local exit_code=$?
  if [[ -n "$DASHBOARD_PID" ]] && kill -0 "$DASHBOARD_PID" >/dev/null 2>&1; then
    kill "$DASHBOARD_PID" >/dev/null 2>&1 || true
    wait "$DASHBOARD_PID" >/dev/null 2>&1 || true
  fi
  if [[ -n "$BACKEND_PID" ]] && kill -0 "$BACKEND_PID" >/dev/null 2>&1; then
    kill "$BACKEND_PID" >/dev/null 2>&1 || true
    wait "$BACKEND_PID" >/dev/null 2>&1 || true
  fi
  exit "$exit_code"
}

trap cleanup EXIT

# Poll a URL (bypassing any HTTP proxy) until it answers or the timeout
# (seconds, default 60) elapses. Args: url, human-readable label, [timeout].
wait_for_url() {
  local url="$1"
  local label="$2"
  local timeout="${3:-60}"
  local started_at
  started_at="$(date +%s)"
  while true; do
    if env -u HTTP_PROXY -u HTTPS_PROXY -u ALL_PROXY NO_PROXY=127.0.0.1,localhost \
      curl -fsS "$url" >/dev/null 2>&1; then
      return 0
    fi
    if (( "$(date +%s)" - started_at >= timeout )); then
      echo "Timed out waiting for ${label}: ${url}" >&2
      return 1
    fi
    sleep 1
  done
}

if [[ "$REUSE_SERVICES" != "1" ]]; then
  # Start the backend only if its health endpoint is not already answering.
  if ! env -u HTTP_PROXY -u HTTPS_PROXY -u ALL_PROXY NO_PROXY=127.0.0.1,localhost \
    curl -fsS "${BACKEND_URL}/api/webui/health" >/dev/null 2>&1; then
    (
      cd "$REPO_ROOT"
      WEBUI_HOST="$BACKEND_HOST" WEBUI_PORT="$BACKEND_PORT" "$PYTHON_BIN" "$BACKEND_SCRIPT"
    ) >"$OUTPUT_DIR/backend.log" 2>&1 &
    BACKEND_PID="$!"
    wait_for_url "${BACKEND_URL}/api/webui/health" "MaiBot WebUI backend" 120
  fi

  # Likewise, start the dashboard dev server only if /auth is unreachable.
  if ! env -u HTTP_PROXY -u HTTPS_PROXY -u ALL_PROXY NO_PROXY=127.0.0.1,localhost \
    curl -fsS "${DASHBOARD_URL}/auth" >/dev/null 2>&1; then
    (
      cd "$DASHBOARD_ROOT"
      npm run dev -- --host "$DASHBOARD_HOST" --port "$DASHBOARD_PORT"
    ) >"$OUTPUT_DIR/dashboard.log" 2>&1 &
    DASHBOARD_PID="$!"
    wait_for_url "${DASHBOARD_URL}/auth" "dashboard dev server" 120
  fi
fi

# Run the Electron validation driver; unset ELECTRON_RUN_AS_NODE so the
# binary launches as a full Electron app rather than plain Node.
env -u ELECTRON_RUN_AS_NODE \
  MAIBOT_DASHBOARD_URL="$DASHBOARD_URL" \
  MAIBOT_UI_SNAPSHOT_DIR="$OUTPUT_DIR" \
  "$ELECTRON_BIN" "$DRIVER_SCRIPT"
|
||||
Reference in New Issue
Block a user