chore: import deployable mai-bot source tree

This commit is contained in:
2026-05-11 00:51:12 +00:00
parent 4813699b3e
commit 7a54015f94
1009 changed files with 312999 additions and 16 deletions

View File

@@ -0,0 +1,336 @@
from pathlib import Path
from typing import Any
from scipy import stats
import argparse
import csv
import json
import math
# Default locations of the automatic-score tree and the manual-annotation tree.
DEFAULT_LOG_DIR = Path("logs") / "maisaka_reply_effect"
DEFAULT_MANUAL_DIR = Path("logs") / "maisaka_reply_effect_manual"
# Metrics correlated against the manual score, as
# (group label, dotted path inside the record's "scores" dict, display label).
METRIC_SPECS = [
    ("总分", "asi", "ASI 自动总分"),
    ("大项", "behavior_score", "行为满意度 B"),
    ("大项", "relational_score", "感知质量 R"),
    ("大项", "friction_score", "摩擦风险 F"),
    ("大项", "friction_quality_score", "低摩擦质量分"),
    ("行为子项", "behavior_signals.continue_2turns", "继续两轮"),
    ("行为子项", "behavior_signals.next_user_sentiment", "后续情绪"),
    ("行为子项", "behavior_signals.user_expansion", "用户展开"),
    ("行为子项", "behavior_signals.no_correction", "没有纠正"),
    ("行为子项", "behavior_signals.no_abort", "没有放弃"),
    ("rubric 子项", "rubric_scores.social_presence.normalized_score", "社交临场感"),
    ("rubric 子项", "rubric_scores.warmth.normalized_score", "温暖感"),
    ("rubric 子项", "rubric_scores.competence.normalized_score", "能力/有用性"),
    ("rubric 子项", "rubric_scores.appropriateness.normalized_score", "合适程度"),
    ("rubric 子项", "rubric_scores.uncanny_risk.normalized_score", "违和风险 judge"),
    ("摩擦子项", "friction_signals.explicit_negative", "明确负反馈"),
    ("摩擦子项", "friction_signals.repair_loop", "修复循环"),
    ("摩擦子项", "friction_signals.uncanny_risk", "违和风险"),
]
def normalize_name(value: str) -> str:
    """Sanitize *value* into a filesystem-safe name.

    Keeps alphanumerics and ``._-``; every other character becomes ``_``.
    Leading/trailing dots and underscores are stripped, and an empty result
    falls back to ``"unknown"``.
    """
    text = str(value or "").strip()
    safe_chars = []
    for char in text:
        safe_chars.append(char if char.isalnum() or char in "._-" else "_")
    cleaned = "".join(safe_chars).strip("._")
    return cleaned if cleaned else "unknown"
def load_json_file(file_path: Path) -> dict[str, Any]:
    """Read a JSON object from *file_path*.

    Any read/parse failure, or a payload that is not a dict, yields {}.
    """
    try:
        raw_text = file_path.read_text(encoding="utf-8")
        parsed = json.loads(raw_text)
    except (OSError, json.JSONDecodeError):
        return {}
    if isinstance(parsed, dict):
        return parsed
    return {}
def to_float(value: Any) -> float | None:
if value in {None, ""}:
return None
try:
number = float(value)
except (TypeError, ValueError):
return None
if math.isnan(number) or math.isinf(number):
return None
return number
def get_nested(payload: dict[str, Any], dotted_path: str) -> Any:
    """Walk *payload* along a dot-separated key path; None when a hop fails."""
    node: Any = payload
    for key in dotted_path.split("."):
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return None
    return node
def annotation_path(manual_dir: Path, chat_id: str, effect_id: str) -> Path:
    """Path of the manual-annotation JSON file for one (chat, effect) pair."""
    chat_part = normalize_name(chat_id)
    file_part = f"{normalize_name(effect_id)}.json"
    return manual_dir / chat_part / file_part
def iter_records(
    log_dir: Path,
    manual_dir: Path,
    *,
    chat_id: str,
    include_pending: bool,
) -> list[dict[str, Any]]:
    """Collect automatic-score records that have a matching manual annotation.

    Scans per-chat subdirectories of *log_dir* for effect-record JSON files,
    joins each with its manual annotation under *manual_dir*, and returns only
    records carrying a usable manual score (0-100 scale).

    Args:
        log_dir: root directory of automatic score records.
        manual_dir: root directory of manual annotation files.
        chat_id: when non-empty, restrict the scan to that chat's directory.
        include_pending: when False, skip records whose status != "finalized".
    """
    records: list[dict[str, Any]] = []
    if not log_dir.exists():
        return records
    # Either one requested chat directory or every subdirectory of log_dir.
    chat_dirs = [log_dir / normalize_name(chat_id)] if chat_id else [path for path in log_dir.iterdir() if path.is_dir()]
    for chat_dir in sorted(chat_dirs):
        if not chat_dir.exists() or not chat_dir.is_dir():
            continue
        for record_file in sorted(chat_dir.glob("*.json")):
            effect_record = load_json_file(record_file)
            if not effect_record:
                continue
            if not include_pending and effect_record.get("status") != "finalized":
                continue
            effect_id = str(effect_record.get("effect_id") or record_file.stem)
            manual_record = load_json_file(annotation_path(manual_dir, chat_dir.name, effect_id))
            manual_score = to_float(manual_record.get("manual_score"))
            if manual_score is None:
                # Fall back to a 1-5 scale annotation, rescaled to 0-100.
                manual_score_5 = to_float(manual_record.get("manual_score_5"))
                if manual_score_5 is not None:
                    manual_score = (manual_score_5 - 1) / 4 * 100
            if manual_score is None:
                # No usable manual score -> record cannot participate.
                continue
            raw_scores = effect_record.get("scores") if isinstance(effect_record.get("scores"), dict) else {}
            scores = dict(raw_scores)
            friction_score = to_float(scores.get("friction_score"))
            if friction_score is not None:
                # Derived metric: low friction risk means high quality.
                scores["friction_quality_score"] = 1 - friction_score
            records.append(
                {
                    "chat_id": chat_dir.name,
                    "effect_id": effect_id,
                    "manual_score": manual_score,
                    "manual_score_5": manual_record.get("manual_score_5"),
                    "scores": scores,
                    "status": effect_record.get("status"),
                    "created_at": effect_record.get("created_at"),
                    "record_file": str(record_file),
                }
            )
    return records
def calculate_metric_stats(records: list[dict[str, Any]], metric_path: str, min_n: int) -> dict[str, Any]:
    """Correlate one automatic metric against the manual score.

    Builds (auto, manual) value pairs over *records*, then computes Pearson,
    Spearman and Kendall correlations via scipy.

    Returns:
        A dict with the pair count ``n``, the six statistic/p-value fields
        (None when not computable) and a human-readable ``note`` explaining
        why computation was skipped, if it was.
    """
    pairs: list[tuple[float, float]] = []
    for record in records:
        x_value = to_float(get_nested(record["scores"], metric_path))
        y_value = to_float(record["manual_score"])
        if x_value is None or y_value is None:
            continue
        pairs.append((x_value, y_value))
    x_values = [pair[0] for pair in pairs]
    y_values = [pair[1] for pair in pairs]
    result: dict[str, Any] = {
        "n": len(pairs),
        "pearson_r": None,
        "pearson_p": None,
        "spearman_r": None,
        "spearman_p": None,
        "kendall_tau": None,
        "kendall_p": None,
        "note": "",
    }
    if len(pairs) < min_n:
        result["note"] = f"样本数少于 {min_n}"
        return result
    # Correlations are undefined when either variable is constant.
    if len(set(x_values)) < 2:
        result["note"] = "自动评分没有变化,无法计算相关"
        return result
    if len(set(y_values)) < 2:
        result["note"] = "人工评分没有变化,无法计算相关"
        return result
    pearson = stats.pearsonr(x_values, y_values)
    spearman = stats.spearmanr(x_values, y_values)
    kendall = stats.kendalltau(x_values, y_values)
    result.update(
        {
            "pearson_r": round_float(pearson.statistic),
            "pearson_p": round_float(pearson.pvalue),
            "spearman_r": round_float(spearman.statistic),
            "spearman_p": round_float(spearman.pvalue),
            "kendall_tau": round_float(kendall.statistic),
            "kendall_p": round_float(kendall.pvalue),
        }
    )
    return result
def round_float(value: Any) -> float | None:
    """Round *value* to 6 decimal places after coercion; None when not finite."""
    coerced = to_float(value)
    return None if coerced is None else round(coerced, 6)
def significance_label(p_value: float | None) -> str:
if p_value is None:
return ""
if p_value < 0.001:
return "***"
if p_value < 0.01:
return "**"
if p_value < 0.05:
return "*"
if p_value < 0.1:
return "."
return "ns"
def build_report(records: list[dict[str, Any]], min_n: int) -> list[dict[str, Any]]:
    """Compute correlation statistics for every metric in METRIC_SPECS.

    Each row combines the metric's identity, its statistics, and the
    significance stars derived from the three p-values.
    """
    rows: list[dict[str, Any]] = []
    for group, metric_path, label in METRIC_SPECS:
        stats_row = calculate_metric_stats(records, metric_path, min_n)
        entry: dict[str, Any] = {"group": group, "metric": metric_path, "label": label}
        entry.update(stats_row)
        entry["pearson_sig"] = significance_label(stats_row["pearson_p"])
        entry["spearman_sig"] = significance_label(stats_row["spearman_p"])
        entry["kendall_sig"] = significance_label(stats_row["kendall_p"])
        rows.append(entry)
    return rows
def print_report(records: list[dict[str, Any]], report: list[dict[str, Any]]) -> None:
    """Pretty-print the correlation report as an aligned console table.

    Prints a summary header (record/chat counts, scoring conventions), one
    row per metric, and finally highlights the overall ASI-vs-manual
    correlation when present.
    """
    chats = sorted({record["chat_id"] for record in records})
    print("\nMaisaka 回复效果评分相关性分析")
    print("=" * 96)
    print(f"已匹配人工评分记录数: {len(records)}")
    print(f"聊天流数量: {len(chats)}")
    if chats:
        # Cap the listing at 8 chat ids to keep the header readable.
        print(f"聊天流: {', '.join(chats[:8])}{' ...' if len(chats) > 8 else ''}")
    print("人工分使用 manual_score若只有 manual_score_5则换算到 0-100 后参与计算。")
    print("显著性: *** p<0.001, ** p<0.01, * p<0.05, . p<0.1, ns 不显著")
    print("-" * 96)
    header = (
        f"{'分组':<14} {'指标':<34} {'n':>4} "
        f"{'Pearson r':>10} {'p':>10} {'sig':>4} "
        f"{'Spearman r':>11} {'p':>10} {'sig':>4} "
        f"{'Kendall':>9} {'p':>10} {'说明'}"
    )
    print(header)
    print("-" * 96)
    for item in report:
        # Column widths mirror the header widths above.
        print(
            f"{item['group']:<14} "
            f"{item['label']:<34} "
            f"{item['n']:>4} "
            f"{format_number(item['pearson_r']):>10} "
            f"{format_number(item['pearson_p']):>10} "
            f"{item['pearson_sig']:>4} "
            f"{format_number(item['spearman_r']):>11} "
            f"{format_number(item['spearman_p']):>10} "
            f"{item['spearman_sig']:>4} "
            f"{format_number(item['kendall_tau']):>9} "
            f"{format_number(item['kendall_p']):>10} "
            f"{item['note']}"
        )
    # Headline number: overall ASI score vs manual score.
    total = next((item for item in report if item["metric"] == "asi"), None)
    if total:
        print("-" * 96)
        print(
            "总分 ASI 与人工分的 Pearson 相关: "
            f"r={format_number(total['pearson_r'])}, "
            f"p={format_number(total['pearson_p'])}, "
            f"显著性={total['pearson_sig'] or 'N/A'}"
        )
def format_number(value: Any) -> str:
    """Render a number to 4 significant digits; 'N/A' for non-numeric input.

    Values within 1e-6 of zero print as a bare "0" to reduce visual noise.
    """
    number = None if value is None else to_float(value)
    if number is None:
        return "N/A"
    return "0" if abs(number) < 0.000001 else f"{number:.4g}"
def write_csv(file_path: Path, report: list[dict[str, Any]]) -> None:
    """Save the report rows to *file_path* as CSV.

    Uses utf-8-sig so spreadsheet applications detect the encoding; parent
    directories are created on demand.
    """
    file_path.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = [
        "group",
        "metric",
        "label",
        "n",
        "pearson_r",
        "pearson_p",
        "pearson_sig",
        "spearman_r",
        "spearman_p",
        "spearman_sig",
        "kendall_tau",
        "kendall_p",
        "kendall_sig",
        "note",
    ]
    with file_path.open("w", encoding="utf-8-sig", newline="") as handle:
        csv_writer = csv.DictWriter(handle, fieldnames=fieldnames)
        csv_writer.writeheader()
        for row in report:
            csv_writer.writerow(row)
def write_json(file_path: Path, records: list[dict[str, Any]], report: list[dict[str, Any]]) -> None:
    """Save record counts and the full report to *file_path* as pretty JSON."""
    file_path.parent.mkdir(parents=True, exist_ok=True)
    unique_chats = {record["chat_id"] for record in records}
    payload = {
        "matched_record_count": len(records),
        "chat_count": len(unique_chats),
        "report": report,
    }
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    file_path.write_text(serialized, encoding="utf-8")
def main() -> None:
    """CLI entry point: load records, build the correlation report, emit output.

    Always prints the report to the console; optionally also saves it as CSV
    and/or JSON when the corresponding flags are given.
    """
    parser = argparse.ArgumentParser(description="分析 Maisaka 回复效果自动评分与人工评分的相关性和显著性。")
    parser.add_argument("--log-dir", type=Path, default=DEFAULT_LOG_DIR, help="自动评分 JSON 目录")
    parser.add_argument("--manual-dir", type=Path, default=DEFAULT_MANUAL_DIR, help="人工评分 JSON 目录")
    parser.add_argument("--chat-id", default="", help="只分析某个 platform_type_id例如 qq_group_1028699246")
    parser.add_argument("--include-pending", action="store_true", help="包含尚未 finalized 的记录")
    parser.add_argument("--min-n", type=int, default=3, help="计算相关性需要的最小样本数,默认 3")
    parser.add_argument("--csv", type=Path, default=None, help="把统计结果另存为 CSV")
    parser.add_argument("--json", type=Path, default=None, help="把统计结果另存为 JSON")
    args = parser.parse_args()
    records = iter_records(
        args.log_dir,
        args.manual_dir,
        chat_id=args.chat_id,
        include_pending=args.include_pending,
    )
    # Correlations need at least 2 points regardless of the user's --min-n.
    report = build_report(records, max(2, args.min_n))
    print_report(records, report)
    if args.csv:
        write_csv(args.csv, report)
        print(f"\nCSV 已保存: {args.csv}")
    if args.json:
        write_json(args.json, records, report)
        print(f"JSON 已保存: {args.json}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,323 @@
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import DefaultDict
import csv
import json
import re
import sqlite3
import sys
PROJECT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "MaiBot.db"
@dataclass(frozen=True)
class ToolUsageRow:
    """One output row: usage of a single tool within a single chat."""

    chat_id: str  # chat/session identifier (COALESCEd in SQL, may be "")
    tool_name: str  # tool identifier (COALESCEd in SQL, may be "")
    count: int  # calls of this tool within this chat
    chat_total: int  # total tool calls within this chat
    percent_in_chat: float  # count / chat_total * 100
    percent_in_all: float  # count / global total * 100
def parse_datetime_filter(value: str | None) -> str | None:
    """Normalize a user-supplied timestamp to 'YYYY-MM-DD HH:MM:SS'.

    None or blank input yields None; input matching none of the accepted
    formats raises ValueError.
    """
    if value is None:
        return None
    text = value.strip()
    if not text:
        return None
    for candidate in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"):
        try:
            moment = datetime.strptime(text, candidate)
        except ValueError:
            continue
        return moment.strftime("%Y-%m-%d %H:%M:%S")
    raise ValueError(f"无法解析时间: {value!r},请使用 YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS")
def parse_recent_filter(value: str | None) -> str | None:
    """Translate a relative window like '30m'/'24h'/'7d'/'2w' into an
    absolute 'YYYY-MM-DD HH:MM:SS' cutoff that far in the past.

    None or blank input yields None; a malformed or non-positive amount
    raises ValueError.
    """
    if value is None:
        return None
    text = value.strip().lower()
    if not text:
        return None
    match = re.fullmatch(r"(\d+(?:\.\d+)?)([mhdw])", text)
    if match is None:
        raise ValueError(f"无法解析最近时间: {value!r},请使用 30m、24h、7d 或 2w")
    amount = float(match.group(1))
    if amount <= 0:
        raise ValueError(f"最近时间必须大于 0: {value!r}")
    unit_deltas = {
        "m": timedelta(minutes=amount),
        "h": timedelta(hours=amount),
        "d": timedelta(days=amount),
        "w": timedelta(weeks=amount),
    }
    cutoff = datetime.now() - unit_deltas[match.group(2)]
    return cutoff.strftime("%Y-%m-%d %H:%M:%S")
def connect_readonly(db_path: Path) -> sqlite3.Connection:
    """Open *db_path* read-only via a SQLite URI.

    The connection gets a Row factory and a 5-second busy timeout.
    Raises FileNotFoundError when the file does not exist.
    """
    if not db_path.exists():
        raise FileNotFoundError(f"数据库文件不存在: {db_path}")
    connection = sqlite3.connect(f"file:{db_path.as_posix()}?mode=ro", uri=True)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA busy_timeout=5000")
    return connection
def fetch_tool_counts(
    db_path: Path,
    since: str | None,
    until: str | None,
    include_empty_chat_id: bool,
    include_empty_tool_name: bool,
) -> list[tuple[str, str, int]]:
    """Query tool_records for per-(chat, tool) usage counts.

    Args:
        db_path: SQLite database file, opened read-only.
        since/until: optional 'YYYY-MM-DD HH:MM:SS' bounds ([since, until)).
        include_empty_chat_id: keep rows whose session_id is NULL/empty.
        include_empty_tool_name: keep rows whose tool_name is NULL/empty.

    Returns:
        (chat_id, tool_name, usage_count) triples ordered by chat_id, then
        count descending, then tool name. The session_id column is exposed
        as chat_id.
    """
    where_clauses: list[str] = []
    params: list[str] = []
    if since is not None:
        where_clauses.append("timestamp >= ?")
        params.append(since)
    if until is not None:
        where_clauses.append("timestamp < ?")
        params.append(until)
    # NULL and '' are treated identically via COALESCE.
    if not include_empty_chat_id:
        where_clauses.append("COALESCE(session_id, '') != ''")
    if not include_empty_tool_name:
        where_clauses.append("COALESCE(tool_name, '') != ''")
    where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
    query = f"""
        SELECT
            COALESCE(session_id, '') AS chat_id,
            COALESCE(tool_name, '') AS tool_name,
            COUNT(*) AS usage_count
        FROM tool_records
        {where_sql}
        GROUP BY COALESCE(session_id, ''), COALESCE(tool_name, '')
        ORDER BY COALESCE(session_id, ''), usage_count DESC, COALESCE(tool_name, '')
    """
    with connect_readonly(db_path) as connection:
        rows = connection.execute(query, params).fetchall()
    return [(str(row["chat_id"]), str(row["tool_name"]), int(row["usage_count"])) for row in rows]
def build_usage_rows(
    counts: list[tuple[str, str, int]],
    min_chat_total: int,
    top_tools_per_chat: int | None,
) -> list[ToolUsageRow]:
    """Turn raw (chat_id, tool_name, count) triples into display rows.

    Chats whose total falls below *min_chat_total* are dropped, and within
    each chat at most *top_tools_per_chat* rows are kept (None = unlimited).
    Rows are ordered by chat_id, then count descending, then tool name.
    """
    per_chat_total: DefaultDict[str, int] = defaultdict(int)
    for chat_id, _tool, usage in counts:
        per_chat_total[chat_id] += usage
    grand_total = sum(per_chat_total.values())

    ordered = sorted(counts, key=lambda triple: (triple[0], -triple[2], triple[1]))
    shown_so_far: DefaultDict[str, int] = defaultdict(int)
    usage_rows: list[ToolUsageRow] = []
    for chat_id, tool_name, usage in ordered:
        total_for_chat = per_chat_total[chat_id]
        if total_for_chat < min_chat_total:
            continue
        if top_tools_per_chat is not None and shown_so_far[chat_id] >= top_tools_per_chat:
            continue
        shown_so_far[chat_id] += 1
        usage_rows.append(
            ToolUsageRow(
                chat_id=chat_id,
                tool_name=tool_name,
                count=usage,
                chat_total=total_for_chat,
                percent_in_chat=usage / total_for_chat * 100 if total_for_chat else 0.0,
                percent_in_all=usage / grand_total * 100 if grand_total else 0.0,
            )
        )
    return usage_rows
def build_overall_rows(counts: list[tuple[str, str, int]], min_chat_total: int) -> list[tuple[str, int, float]]:
    """Aggregate per-tool totals across all chats meeting the total floor.

    Returns (tool_name, count, percent_of_total) tuples sorted by count
    descending, then tool name.
    """
    chat_sums: DefaultDict[str, int] = defaultdict(int)
    for chat_id, _tool, usage in counts:
        chat_sums[chat_id] += usage
    per_tool: DefaultDict[str, int] = defaultdict(int)
    for chat_id, tool_name, usage in counts:
        if chat_sums[chat_id] >= min_chat_total:
            per_tool[tool_name] += usage
    grand_total = sum(per_tool.values())
    ranked = sorted(per_tool.items(), key=lambda pair: (-pair[1], pair[0]))
    result: list[tuple[str, int, float]] = []
    for tool_name, usage in ranked:
        share = usage / grand_total * 100 if grand_total else 0.0
        result.append((tool_name, usage, share))
    return result
def print_overall_block(overall_rows: list[tuple[str, int, float]]) -> None:
    """Print the cross-chat aggregate table (the "全部统计" block)."""
    print("全部统计")
    grand_total = sum(usage for _name, usage, _share in overall_rows)
    print(f"tool_total: {grand_total}")
    if not overall_rows:
        print(" 无工具调用记录")
        return
    # Column widths fit the widest cell or header.
    name_width = max(len("tool"), *(len(name) for name, _usage, _share in overall_rows))
    num_width = max(len("count"), *(len(str(usage)) for _name, usage, _share in overall_rows))
    pct_width = max(len("全局占比"), *(len(f"{share:.2f}%") for _name, _usage, share in overall_rows))
    print(f" {'tool':<{name_width}} {'count':>{num_width}} {'全局占比':>{pct_width}}")
    print(f" {'-' * name_width} {'-' * num_width} {'-' * pct_width}")
    for name, usage, share in overall_rows:
        print(f" {name:<{name_width}} {usage:>{num_width}} {share:>{pct_width - 1}.2f}%")
def print_markdown(rows: list[ToolUsageRow], overall_rows: list[tuple[str, int, float]]) -> None:
    """Print the overall block followed by one aligned table per chat_id.

    Each chat block shows its id, its total tool-call count, and a table of
    tool rows with in-chat and global percentages; blocks are separated by
    blank lines.
    """
    print_overall_block(overall_rows)
    if rows:
        print()
    grouped_rows: DefaultDict[str, list[ToolUsageRow]] = defaultdict(list)
    for row in rows:
        grouped_rows[row.chat_id].append(row)
    first_group = True
    for chat_id in sorted(grouped_rows):
        chat_rows = grouped_rows[chat_id]
        if not chat_rows:
            continue
        if not first_group:
            print()
        first_group = False
        # All rows of a chat share the same chat_total.
        chat_total = chat_rows[0].chat_total
        print(f"chat_id: {chat_id}")
        print(f"tool_total: {chat_total}")
        # Column widths fit the widest cell or header within this chat.
        tool_width = max(len("tool"), *(len(row.tool_name) for row in chat_rows))
        count_width = max(len("count"), *(len(str(row.count)) for row in chat_rows))
        chat_percent_width = max(len("chat内占比"), *(len(f"{row.percent_in_chat:.2f}%") for row in chat_rows))
        all_percent_width = max(len("全局占比"), *(len(f"{row.percent_in_all:.2f}%") for row in chat_rows))
        header = (
            f" {'tool':<{tool_width}} "
            f"{'count':>{count_width}} "
            f"{'chat内占比':>{chat_percent_width}} "
            f"{'全局占比':>{all_percent_width}}"
        )
        print(header)
        print(
            f" {'-' * tool_width} "
            f"{'-' * count_width} "
            f"{'-' * chat_percent_width} "
            f"{'-' * all_percent_width}"
        )
        for row in chat_rows:
            print(
                f" {row.tool_name:<{tool_width}} "
                f"{row.count:>{count_width}} "
                f"{row.percent_in_chat:>{chat_percent_width - 1}.2f}% "
                f"{row.percent_in_all:>{all_percent_width - 1}.2f}%"
            )
def print_json(rows: list[ToolUsageRow]) -> None:
    """Dump rows as a pretty-printed JSON array (percentages to 4 dp)."""
    serializable = []
    for row in rows:
        serializable.append(
            {
                "chat_id": row.chat_id,
                "tool_name": row.tool_name,
                "count": row.count,
                "chat_total": row.chat_total,
                "percent_in_chat": round(row.percent_in_chat, 4),
                "percent_in_all": round(row.percent_in_all, 4),
            }
        )
    print(json.dumps(serializable, ensure_ascii=False, indent=2))
def print_csv(rows: list[ToolUsageRow]) -> None:
    """Stream rows to stdout as CSV with a header line (percentages to 4 dp)."""
    csv_writer = csv.writer(sys.stdout)
    csv_writer.writerow(["chat_id", "tool_name", "count", "chat_total", "percent_in_chat", "percent_in_all"])
    for row in rows:
        record = [
            row.chat_id,
            row.tool_name,
            row.count,
            row.chat_total,
            f"{row.percent_in_chat:.4f}",
            f"{row.percent_in_all:.4f}",
        ]
        csv_writer.writerow(record)
def parse_args() -> Namespace:
    """Build and evaluate the command-line interface for this script."""
    parser = ArgumentParser(description="统计不同 chat_id 的工具使用次数和占比。")
    parser.add_argument("--db", type=Path, default=DEFAULT_DB_PATH, help=f"数据库路径,默认: {DEFAULT_DB_PATH}")
    parser.add_argument("--recent", help="统计最近多久的记录,例如: 30m、24h、7d、2w如果同时指定 --since则优先使用 --since")
    parser.add_argument("--since", help="仅统计此时间之后的记录,格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS")
    parser.add_argument("--until", help="仅统计此时间之前的记录,格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS")
    parser.add_argument("--min-chat-total", type=int, default=1, help="只显示工具调用总数不低于该值的 chat_id")
    parser.add_argument("--top-tools", type=int, help="每个 chat_id 最多显示前 N 个工具")
    parser.add_argument("--format", choices=("markdown", "json", "csv"), default="markdown", help="输出格式markdown 为按 chat_id 分块的终端表")
    parser.add_argument("--include-empty-chat-id", action="store_true", help="包含 chat_id 为空的记录")
    parser.add_argument("--include-empty-tool-name", action="store_true", help="包含 tool_name 为空的记录")
    return parser.parse_args()
def main() -> None:
    """CLI entry point: resolve filters, query counts, print in chosen format."""
    args = parse_args()
    # --since wins over --recent when both are supplied.
    since = parse_datetime_filter(args.since) or parse_recent_filter(args.recent)
    until = parse_datetime_filter(args.until)
    # Clamp user-provided limits to at least 1 (None = no top-tools limit).
    min_chat_total = max(1, int(args.min_chat_total))
    top_tools = args.top_tools if args.top_tools is None else max(1, int(args.top_tools))
    counts = fetch_tool_counts(
        db_path=args.db.resolve(),
        since=since,
        until=until,
        include_empty_chat_id=args.include_empty_chat_id,
        include_empty_tool_name=args.include_empty_tool_name,
    )
    rows = build_usage_rows(
        counts=counts,
        min_chat_total=min_chat_total,
        top_tools_per_chat=top_tools,
    )
    if args.format == "json":
        print_json(rows)
    elif args.format == "csv":
        print_csv(rows)
    else:
        # markdown: overall table plus per-chat blocks.
        overall_rows = build_overall_rows(counts, min_chat_total=min_chat_total)
        print_markdown(rows, overall_rows)


if __name__ == "__main__":
    main()

380
scripts/build_io_pairs.py Normal file
View File

@@ -0,0 +1,380 @@
import argparse
import json
import random
import re
import sys
import os
from datetime import datetime
from typing import Dict, Iterable, List, Optional, Tuple

# Make the project importable from any working directory: put the repository
# root (parent of scripts/) on sys.path BEFORE importing src.* — the original
# ordering imported src.* first, so the path fix never had any effect.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from src.common.data_models.database_data_model import DatabaseMessages  # noqa: E402
from src.common.message_repository import find_messages  # noqa: E402
from src.chat.utils.chat_message_builder import build_readable_messages  # noqa: E402

# Adjacent messages from the same user within this window are merged.
SECONDS_5_MINUTES = 5 * 60
def clean_output_text(text: str) -> str:
    """Strip sticker markers and reply-quote prefixes from an output line.

    Removes ``[表情包:...]`` segments and ``[回复...],说:...@...:`` reply
    prefixes, then collapses whitespace runs into single spaces.
    """
    if not text:
        return text
    without_stickers = re.sub(r"\[表情包:[^\]]*\]", "", text)
    without_replies = re.sub(r"\[回复[^\]]*\],说:[^@]*@[^:]*:", "", without_stickers)
    return re.sub(r"\s+", " ", without_replies).strip()
def parse_datetime_to_timestamp(value: str) -> float:
    """Parse a datetime string in several common formats into a UNIX timestamp.

    Accepted examples: ``2025-09-29``, ``2025-09-29 00:00:00``,
    ``2025/09/29 00:00``, ``2025-09-29T00:00:00``.

    Raises:
        ValueError: when no format matches (includes the last parse error).
    """
    value = value.strip()
    fmts = [
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    ]
    last_err: Optional[Exception] = None
    for fmt in fmts:
        try:
            return datetime.strptime(value, fmt).timestamp()
        except Exception as parse_error:  # noqa: BLE001
            last_err = parse_error
    raise ValueError(f"无法解析时间: {value} ({last_err})")
def fetch_messages_between(
    start_ts: float,
    end_ts: float,
    platform: Optional[str] = None,
) -> List[DatabaseMessages]:
    """Fetch messages with start_ts < time < end_ts via find_messages.

    Optionally filters on chat_info_platform. Results are returned in
    ascending time order.
    """
    filter_query: Dict[str, object] = {"time": {"$gt": start_ts, "$lt": end_ts}}
    if platform:
        filter_query["chat_info_platform"] = platform
    # limit=0 means "no limit", which is when the sort spec takes effect.
    return find_messages(message_filter=filter_query, sort=[("time", 1)], limit=0)
def group_by_chat(messages: Iterable[DatabaseMessages]) -> Dict[str, List[DatabaseMessages]]:
    """Bucket messages by chat_id, each bucket sorted by ascending time."""
    buckets: Dict[str, List[DatabaseMessages]] = {}
    for message in messages:
        buckets.setdefault(message.chat_id, []).append(message)
    # Keep each bucket in chronological order (None times sort first as 0).
    for bucket in buckets.values():
        bucket.sort(key=lambda item: item.time or 0)
    return buckets
def _merge_bucket_to_message(bucket: List[DatabaseMessages]) -> DatabaseMessages:
    """Collapse a bucket of adjacent same-user messages into one message.

    The processed_plain_text of all bucket members is joined with newlines;
    every other field is copied from the newest (last) message.

    Raises:
        ValueError: when *bucket* is empty.
    """
    if not bucket:
        raise ValueError("bucket 为空,无法合并")
    latest = bucket[-1]
    merged_texts: List[str] = []
    for m in bucket:
        text = m.processed_plain_text or ""
        if text:
            merged_texts.append(text)
    merged = DatabaseMessages(
        # All metadata fields come from the newest message in the bucket.
        message_id=latest.message_id,
        time=latest.time,
        chat_id=latest.chat_id,
        reply_to=latest.reply_to,
        is_mentioned=latest.is_mentioned,
        is_at=latest.is_at,
        reply_probability_boost=latest.reply_probability_boost,
        # Only the text body is actually merged.
        processed_plain_text="\n".join(merged_texts) if merged_texts else latest.processed_plain_text,
        priority_mode=latest.priority_mode,
        priority_info=latest.priority_info,
        additional_config=latest.additional_config,
        is_emoji=latest.is_emoji,
        is_picid=latest.is_picid,
        is_command=latest.is_command,
        is_notify=latest.is_notify,
        selected_expressions=latest.selected_expressions,
        user_id=latest.user_info.user_id,
        user_nickname=latest.user_info.user_nickname,
        user_cardname=latest.user_info.user_cardname,
        user_platform=latest.user_info.platform,
        chat_info_group_id=(latest.group_info.group_id if latest.group_info else None),
        chat_info_group_name=(latest.group_info.group_name if latest.group_info else None),
        chat_info_group_platform=(latest.group_info.group_platform if latest.group_info else None),
        chat_info_user_id=latest.chat_info.user_info.user_id,
        chat_info_user_nickname=latest.chat_info.user_info.user_nickname,
        chat_info_user_cardname=latest.chat_info.user_info.user_cardname,
        chat_info_user_platform=latest.chat_info.user_info.platform,
        chat_info_stream_id=latest.chat_info.stream_id,
        chat_info_platform=latest.chat_info.platform,
        chat_info_create_time=latest.chat_info.create_time,
        chat_info_last_active_time=latest.chat_info.last_active_time,
    )
    return merged
def merge_adjacent_same_user(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
    """Merge runs of consecutive same-user messages within a 5-minute gap.

    Input must be sorted by ascending time; each run collapses into a single
    message via _merge_bucket_to_message.
    """
    if not messages:
        return []
    merged: List[DatabaseMessages] = []
    current_run: List[DatabaseMessages] = [messages[0]]
    for msg in messages[1:]:
        previous = current_run[-1]
        same_author = msg.user_info.user_id == previous.user_info.user_id
        within_window = (msg.time or 0) - (previous.time or 0) <= SECONDS_5_MINUTES
        if same_author and within_window:
            current_run.append(msg)
        else:
            merged.append(_merge_bucket_to_message(current_run))
            current_run = [msg]
    merged.append(_merge_bucket_to_message(current_run))
    return merged
def build_pairs_for_chat(
    original_messages: List[DatabaseMessages],
    merged_messages: List[DatabaseMessages],
    min_ctx: int,
    max_ctx: int,
    target_user_id: Optional[str] = None,
) -> List[Tuple[str, str, str]]:
    """Build (input, output, message_id) training pairs for one chat.

    Each merged message becomes an output; its input is rendered from a
    window of min_ctx..max_ctx ORIGINAL (unmerged) messages immediately
    preceding it. When target_user_id is given, only that user's merged
    messages become outputs.
    """
    pairs: List[Tuple[str, str, str]] = []
    n_merged = len(merged_messages)
    n_original = len(original_messages)
    if n_merged == 0 or n_original == 0:
        return pairs
    # Map each merged message to an original-message index by exact timestamp
    # match. NOTE(review): a merged message carries the time of the LAST
    # message in its bucket, and matching is by time equality only — several
    # messages sharing one timestamp could map ambiguously; confirm upstream
    # timestamps are unique enough for this to hold.
    merged_to_original_map = {}
    original_idx = 0
    for merged_idx, merged_msg in enumerate(merged_messages):
        # Advance to the first original message not earlier than this one.
        while original_idx < n_original and original_messages[original_idx].time < merged_msg.time:
            original_idx += 1
        # Record the mapping only on an exact time match.
        if original_idx < n_original and original_messages[original_idx].time == merged_msg.time:
            merged_to_original_map[merged_idx] = original_idx
    for merged_idx in range(n_merged):
        merged_msg = merged_messages[merged_idx]
        # Restrict outputs to the target user when one was requested.
        if target_user_id and merged_msg.user_info.user_id != target_user_id:
            continue
        # Skip merged messages that could not be anchored to an original index.
        if merged_idx not in merged_to_original_map:
            continue
        original_idx = merged_to_original_map[merged_idx]
        # Randomly sized context window between min_ctx and max_ctx messages.
        window = random.randint(min_ctx, max_ctx) if max_ctx > min_ctx else min_ctx
        start = max(0, original_idx - window)
        context_msgs = original_messages[start:original_idx]
        # The input is rendered from the raw, unmerged context messages.
        input_str = build_readable_messages(
            messages=context_msgs,
            timestamp_mode="normal_no_YMD",
            show_actions=False,
            show_pic=True,
        )
        # The output is the merged text, stripped of stickers/reply quotes.
        output_text = merged_msg.processed_plain_text or ""
        output_text = clean_output_text(output_text)
        output_id = merged_msg.message_id or ""
        pairs.append((input_str, output_text, output_id))
    return pairs
def build_pairs(
    start_ts: float,
    end_ts: float,
    platform: Optional[str],
    user_id: Optional[str],
    min_ctx: int,
    max_ctx: int,
) -> List[Tuple[str, str, str]]:
    """Fetch messages in the interval, group per chat, and emit pairs.

    Messages are fetched WITHOUT a user filter so the input context can
    include everyone; only the output side is restricted to *user_id*.
    """
    messages = fetch_messages_between(start_ts, end_ts, platform)
    all_pairs: List[Tuple[str, str, str]] = []
    for chat_messages in group_by_chat(messages).values():
        # Merge adjacent same-user runs for the output side only.
        merged = merge_adjacent_same_user(chat_messages)
        all_pairs.extend(build_pairs_for_chat(chat_messages, merged, min_ctx, max_ctx, user_id))
    return all_pairs
def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point; with no arguments, falls back to interactive prompts.

    Returns:
        Process exit code (0 on success).
    """
    # No CLI arguments -> interactive mode.
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) == 0:
        return run_interactive()
    parser = argparse.ArgumentParser(description="构建 (input_str, output_str, message_id) 列表支持按用户ID筛选消息")
    parser.add_argument("start", help="起始时间,如 2025-09-28 00:00:00")
    parser.add_argument("end", help="结束时间,如 2025-09-29 00:00:00")
    parser.add_argument("--platform", default=None, help="仅选择 chat_info_platform 为该值的消息")
    parser.add_argument("--user_id", default=None, help="仅选择指定 user_id 的消息")
    parser.add_argument("--min_ctx", type=int, default=20, help="输入上下文的最少条数默认20")
    parser.add_argument("--max_ctx", type=int, default=30, help="输入上下文的最多条数默认30")
    parser.add_argument(
        "--output",
        default=None,
        help="输出保存路径,支持 .jsonl每行 {input, output}若不指定则打印到stdout",
    )
    args = parser.parse_args(argv)
    start_ts = parse_datetime_to_timestamp(args.start)
    end_ts = parse_datetime_to_timestamp(args.end)
    if end_ts <= start_ts:
        raise ValueError("结束时间必须大于起始时间")
    if args.max_ctx < args.min_ctx:
        raise ValueError("max_ctx 不能小于 min_ctx")
    pairs = build_pairs(start_ts, end_ts, args.platform, args.user_id, args.min_ctx, args.max_ctx)
    if args.output:
        # Save as JSONL: one {input, output, message_id} object per line.
        with open(args.output, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {args.output}")
    else:
        # No output path: dump each pair to stdout.
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
    return 0
def _prompt_with_default(prompt_text: str, default: Optional[str]) -> str:
    """Prompt on stdin, falling back to *default* when the user hits enter."""
    suffix = f"[{default}]" if default not in (None, "") else ""
    raw = input(f"{prompt_text}{' ' + suffix if suffix else ''}: ").strip()
    return default if raw == "" and default is not None else raw
def run_interactive() -> int:
    """Interactively prompt for all build_pairs parameters, then run.

    Returns:
        0 on success, 2 on invalid input.
    """
    print("进入交互模式直接回车采用默认值。时间格式例如2025-09-28 00:00:00 或 2025-09-28")
    start_str = _prompt_with_default("请输入起始时间", None)
    end_str = _prompt_with_default("请输入结束时间", None)
    platform = _prompt_with_default("平台(可留空表示不限)", "")
    user_id = _prompt_with_default("用户ID可留空表示不限", "")
    try:
        min_ctx = int(_prompt_with_default("输入上下文最少条数", "20"))
        max_ctx = int(_prompt_with_default("输入上下文最多条数", "30"))
    except Exception:
        # Bad numeric input silently falls back to the defaults.
        print("上下文条数输入有误,使用默认 20/30")
        min_ctx, max_ctx = 20, 30
    output_path = _prompt_with_default("输出路径(.jsonl可留空打印到控制台", "")
    if not start_str or not end_str:
        print("必须提供起始与结束时间。")
        return 2
    try:
        start_ts = parse_datetime_to_timestamp(start_str)
        end_ts = parse_datetime_to_timestamp(end_str)
    except Exception as e:  # noqa: BLE001
        print(f"时间解析失败:{e}")
        return 2
    if end_ts <= start_ts:
        print("结束时间必须大于起始时间。")
        return 2
    if max_ctx < min_ctx:
        print("最多条数不能小于最少条数。")
        return 2
    # Empty strings from the prompts mean "no filter".
    platform_val = platform if platform != "" else None
    user_id_val = user_id if user_id != "" else None
    pairs = build_pairs(start_ts, end_ts, platform_val, user_id_val, min_ctx, max_ctx)
    if output_path:
        # Save as JSONL: one {input, output, message_id} object per line.
        with open(output_path, "w", encoding="utf-8") as f:
            for input_str, output_str, message_id in pairs:
                obj = {"input": input_str, "output": output_str, "message_id": message_id}
                f.write(json.dumps(obj, ensure_ascii=False) + "\n")
        print(f"已保存 {len(pairs)} 条到 {output_path}")
    else:
        for input_str, output_str, message_id in pairs:
            print(json.dumps({"input": input_str, "output": output_str, "message_id": message_id}, ensure_ascii=False))
    print(f"总计 {len(pairs)} 条。")
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,553 @@
"""
表达方式按count分组的LLM评估和统计分析脚本
功能:
1. 随机选择50条表达至少要有20条count>1的项目然后进行LLM评估
2. 比较不同count之间的LLM评估合格率是否有显著差异
- 首先每个count分开比较
- 然后比较count为1和count大于1的两种
"""
import asyncio
import random
import json
import sys
import os
import re
from typing import List, Dict, Set, Tuple
from datetime import datetime
from collections import defaultdict
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression # noqa: E402
from src.common.database.database import db # noqa: E402
from src.common.logger import get_logger # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
logger = get_logger("expression_evaluator_count_analysis_llm")
# 评估结果文件路径
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """Load prior evaluation results from COUNT_ANALYSIS_FILE.

    Returns:
        The stored result list plus the set of (situation, style) pairs
        already evaluated, so reruns can skip duplicates. A missing or
        unreadable file yields empty results.
    """
    if not os.path.exists(COUNT_ANALYSIS_FILE):
        return [], set()
    try:
        with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as handle:
            stored = json.load(handle)
        results = stored.get("evaluation_results", [])
        # (situation, style) acts as the unique key for an evaluated item.
        evaluated_pairs = {
            (item["situation"], item["style"])
            for item in results
            if "situation" in item and "style" in item
        }
        logger.info(f"已加载 {len(results)} 条已有评估结果")
        return results, evaluated_pairs
    except Exception as e:
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()
def save_results(evaluation_results: List[Dict]):
    """Persist evaluation results (with metadata) to COUNT_ANALYSIS_FILE.

    Failures are logged and reported to stdout but never raised, so a save
    error cannot abort an evaluation run.
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)
        payload = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(evaluation_results),
            "evaluation_results": evaluation_results,
        }
        with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)
        logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] | None = None) -> List[Expression]:
    """Select expressions for evaluation.

    Strategy: take every not-yet-evaluated item with count > 1, add a random
    sample of count == 1 items twice that size (or all of them when fewer are
    available), then shuffle the combined list.

    Args:
        evaluated_pairs: Already-evaluated (situation, style) pairs to skip;
            defaults to an empty set. (Annotation fixed: the old
            ``Set[...] = None`` default was an implicit Optional.)

    Returns:
        The selected Expression rows; empty on error or when nothing is left.
    """
    if evaluated_pairs is None:
        evaluated_pairs = set()
    try:
        # Query every expression row; filtering happens in Python below.
        all_expressions = list(Expression.select())
        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []
        # Keep only items that have not been evaluated yet.
        unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
        if not unevaluated:
            logger.warning("所有项目都已评估完成")
            return []
        # Partition by usage count.
        count_eq1 = [expr for expr in unevaluated if expr.count == 1]
        count_gt1 = [expr for expr in unevaluated if expr.count > 1]
        logger.info(f"未评估项目中count=1的有{len(count_eq1)}条count>1的有{len(count_gt1)}条")
        # Take every count > 1 item.
        selected_count_gt1 = count_gt1.copy()
        # Target twice as many count == 1 items as count > 1 items.
        count_gt1_count = len(selected_count_gt1)
        count_eq1_needed = count_gt1_count * 2
        if len(count_eq1) < count_eq1_needed:
            logger.warning(
                f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条"
            )
            count_eq1_needed = len(count_eq1)
        # Randomly sample the count == 1 items.
        selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []
        selected = selected_count_gt1 + selected_count_eq1
        random.shuffle(selected)  # Randomize evaluation order.
        logger.info(
            f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}条全部count=1的有{len(selected_count_eq1)}条2倍"
        )
        return selected
    except Exception as e:
        logger.error(f"选择表达方式失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM prompt that judges one (situation, style) pair.

    Args:
        situation: The usage condition / scenario text.
        style: The expression / speaking-style text.

    Returns:
        The full prompt string, asking the model for a strict-JSON verdict
        with "suitable" and "reason" fields.
    """
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}
请从以下方面进行评估:
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指,需要具有泛用性
4. 一般不涉及具体的人名或名称
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因。
请严格按照JSON格式输出不要包含其他内容。"""
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of a (situation, style) pair.

    Args:
        situation: The usage condition / scenario text.
        style: The expression / speaking-style text.
        llm: LLM request instance used to generate the verdict.

    Returns:
        A (suitable, reason, error) tuple. On any failure ``suitable`` is
        False and ``error`` carries the message; otherwise ``error`` is None.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON verdict out of the raw model response.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            # Fallback: extract the first flat JSON object containing a
            # "suitable" key (models sometimes wrap the JSON in prose).
            # NOTE: the pattern does not handle nested braces.
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        # Any failure (request, parsing, extraction) maps to a failing verdict.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
    """Judge a single expression record with the LLM.

    Args:
        expression: The expression row to judge.
        llm: LLM request instance.

    Returns:
        A result dict carrying the verdict, reason, optional error, the
        original record fields and an evaluation timestamp.
    """
    logger.info(
        f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
    )
    verdict, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
    # Any error forces a failing verdict.
    suitable = False if error else verdict
    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
    return {
        "situation": expression.situation,
        "style": expression.style,
        "count": expression.count,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
        "evaluated_at": datetime.now().isoformat(),
    }
def perform_statistical_analysis(evaluation_results: List[Dict]):
    """Run descriptive statistics and a chi-square test over the results.

    Groups results by ``count``, compares pass rates between count == 1 and
    count > 1 with a hand-rolled 2x2 Pearson chi-square test, prints a
    report, and saves the aggregate numbers to a JSON file under TEMP_DIR.

    Args:
        evaluation_results: Evaluation result dicts; each is expected to
            carry "count" and "suitable" keys (missing keys default to
            count=1 / suitable=False).
    """
    if not evaluation_results:
        print("\n没有评估结果可供分析")
        return
    print("\n" + "=" * 60)
    print("统计分析结果")
    print("=" * 60)
    # Aggregate pass/fail tallies per distinct count value.
    count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)
        count_groups[count]["total"] += 1
        if suitable:
            count_groups[count]["suitable"] += 1
        else:
            count_groups[count]["unsuitable"] += 1
    # Per-count breakdown.
    print("\n【按count分组统计】")
    print("-" * 60)
    for count in sorted(count_groups.keys()):
        group = count_groups[count]
        total = group["total"]
        suitable = group["suitable"]
        unsuitable = group["unsuitable"]
        pass_rate = (suitable / total * 100) if total > 0 else 0
        print(f"Count = {count}:")
        print(f" 总数: {total}")
        print(f" 通过: {suitable} ({pass_rate:.2f}%)")
        print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
        print()
    # Compare count == 1 against count > 1.
    count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
    count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
    for result in evaluation_results:
        count = result.get("count", 1)
        suitable = result.get("suitable", False)
        if count == 1:
            count_eq1_group["total"] += 1
            if suitable:
                count_eq1_group["suitable"] += 1
            else:
                count_eq1_group["unsuitable"] += 1
        else:
            count_gt1_group["total"] += 1
            if suitable:
                count_gt1_group["suitable"] += 1
            else:
                count_gt1_group["unsuitable"] += 1
    print("\n【Count=1 vs Count>1 对比】")
    print("-" * 60)
    eq1_total = count_eq1_group["total"]
    eq1_suitable = count_eq1_group["suitable"]
    eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
    gt1_total = count_gt1_group["total"]
    gt1_suitable = count_gt1_group["suitable"]
    gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
    print("Count = 1:")
    print(f" 总数: {eq1_total}")
    print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
    print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
    print()
    print("Count > 1:")
    print(f" 总数: {gt1_total}")
    print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
    print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
    print()
    # Chi-square test of independence on the 2x2 contingency table.
    if eq1_total > 0 and gt1_total > 0:
        print("【统计显著性检验】")
        print("-" * 60)
        # 2x2 contingency table:
        #            pass   fail
        # count=1     a      b
        # count>1     c      d
        a = eq1_suitable
        b = eq1_total - eq1_suitable
        c = gt1_suitable
        d = gt1_total - gt1_suitable
        # Pearson chi-square statistic, computed by hand (no scipy here).
        n = eq1_total + gt1_total
        if n > 0:
            # Expected frequencies under the independence hypothesis.
            e_a = (eq1_total * (a + c)) / n
            e_b = (eq1_total * (b + d)) / n
            e_c = (gt1_total * (a + c)) / n
            e_d = (gt1_total * (b + d)) / n
            # Chi-square validity requires each expected frequency >= 5.
            min_expected = min(e_a, e_b, e_c, e_d)
            if min_expected < 5:
                print("警告期望频数小于5卡方检验可能不准确")
                print("建议使用Fisher精确检验")
            # Sum of (observed - expected)^2 / expected over the four cells.
            chi_square = 0
            if e_a > 0:
                chi_square += ((a - e_a) ** 2) / e_a
            if e_b > 0:
                chi_square += ((b - e_b) ** 2) / e_b
            if e_c > 0:
                chi_square += ((c - e_c) ** 2) / e_c
            if e_d > 0:
                chi_square += ((d - e_d) ** 2) / e_d
            # Degrees of freedom = (rows - 1) * (cols - 1) = 1.
            df = 1
            # Hard-coded chi-square critical values for df = 1.
            chi_square_critical_005 = 3.841
            chi_square_critical_001 = 6.635
            print(f"卡方统计量: {chi_square:.4f}")
            print(f"自由度: {df}")
            print(f"临界值 (α=0.05): {chi_square_critical_005}")
            print(f"临界值 (α=0.01): {chi_square_critical_001}")
            if chi_square >= chi_square_critical_001:
                print("结论: 在α=0.01水平下count=1和count>1的合格率存在显著差异p<0.01")
            elif chi_square >= chi_square_critical_005:
                print("结论: 在α=0.05水平下count=1和count>1的合格率存在显著差异p<0.05")
            else:
                print("结论: 在α=0.05水平下count=1和count>1的合格率不存在显著差异p≥0.05")
            # Effect size reported as the absolute pass-rate gap.
            diff = abs(eq1_pass_rate - gt1_pass_rate)
            print(f"\n合格率差异: {diff:.2f}%")
            if diff > 10:
                print("差异较大(>10%")
            elif diff > 5:
                print("差异中等5-10%")
            else:
                print("差异较小(<5%")
        else:
            print("数据不足,无法进行统计检验")
    else:
        print("数据不足无法进行count=1和count>1的对比分析")
    # Persist the aggregate statistics alongside the evaluation results.
    analysis_result = {
        "analysis_time": datetime.now().isoformat(),
        "count_groups": {str(k): v for k, v in count_groups.items()},
        "count_eq1": count_eq1_group,
        "count_gt1": count_gt1_group,
        "total_evaluated": len(evaluation_results),
    }
    try:
        analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
        with open(analysis_file, "w", encoding="utf-8") as f:
            json.dump(analysis_result, f, ensure_ascii=False, indent=2)
        print(f"\n✓ 统计分析结果已保存到: {analysis_file}")
    except Exception as e:
        logger.error(f"保存统计分析结果失败: {e}")
async def main():
    """Entry point: LLM-evaluate expressions grouped by count, then analyze.

    Workflow: connect to the database, resume from previously saved results,
    evaluate remaining count > 1 items (plus a 2x sample of count == 1
    items), persist, then run the statistical analysis.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式按count分组的LLM评估和统计分析")
    logger.info("=" * 60)
    # Initialize the database connection.
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return
    # Load previously saved results so the run can resume where it left off.
    existing_results, evaluated_pairs = load_existing_results()
    evaluation_results = existing_results.copy()
    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")
    # Determine whether any count > 1 items still need evaluation.
    try:
        all_expressions = list(Expression.select())
        unevaluated_count_gt1 = [
            expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
        ]
        has_unevaluated = len(unevaluated_count_gt1) > 0
    except Exception as e:
        logger.error(f"查询未评估项目失败: {e}")
        has_unevaluated = False
    if has_unevaluated:
        print("\n" + "=" * 60)
        print("开始LLM评估")
        print("=" * 60)
        print("评估结果会自动保存到文件\n")
        # Create the LLM instance.
        print("创建LLM实例...")
        try:
            llm = LLMRequest(
                model_set=model_config.model_task_config.tool_use,
                request_type="expression_evaluator_count_analysis_llm",
            )
            print("✓ LLM实例创建成功\n")
        except Exception as e:
            logger.error(f"创建LLM实例失败: {e}")
            import traceback
            logger.error(traceback.format_exc())
            print(f"\n✗ 创建LLM实例失败: {e}")
            db.close()
            return
        # Select items: every unevaluated count > 1, plus 2x as many count == 1.
        expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
        if not expressions:
            print("\n没有可评估的项目")
        else:
            print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
            print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
            print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")
            batch_results = []
            for i, expression in enumerate(expressions, 1):
                print(f"LLM评估进度: {i}/{len(expressions)}")
                print(f" Situation: {expression.situation}")
                print(f" Style: {expression.style}")
                print(f" Count: {expression.count}")
                llm_result = await llm_evaluate_expression(expression, llm)
                print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
                if llm_result.get("error"):
                    print(f" 错误: {llm_result['error']}")
                print()
                batch_results.append(llm_result)
                # (situation, style) is the unique de-duplication key.
                evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
                # Small delay to avoid API rate limiting.
                await asyncio.sleep(0.3)
            # Merge this batch into the accumulated results.
            evaluation_results.extend(batch_results)
            # Persist everything evaluated so far.
            save_results(evaluation_results)
    else:
        print(f"\n所有count>1的项目都已评估完成已有 {len(evaluation_results)} 条评估结果")
    # Run the statistical analysis over everything evaluated so far.
    if len(evaluation_results) > 0:
        perform_statistical_analysis(evaluation_results)
    else:
        print("\n没有评估结果可供分析")
    # Close the database connection.
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
if __name__ == "__main__":
    # Script entry point: drive the async evaluation workflow.
    asyncio.run(main())

View File

@@ -0,0 +1,535 @@
"""
表达方式LLM评估脚本
功能:
1. 读取已保存的人工评估结果(作为效标)
2. 使用LLM对相同项目进行评估
3. 对比人工评估和LLM评估的结果输出分析报告
"""
import asyncio
import argparse
import json
import random
import sys
import os
import glob
from typing import List, Dict, Set, Tuple
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
from src.common.logger import get_logger # noqa: E402
logger = get_logger("expression_evaluator_llm")
# 评估结果文件路径
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
def load_manual_results() -> List[Dict]:
    """Load manual evaluation results by merging every JSON file in TEMP_DIR.

    Records are de-duplicated on their (situation, style) pair; files that
    fail to load are skipped with a warning. Note: every ``*.json`` file in
    the directory is read, but only its "manual_results" key is used.

    Returns:
        The merged, de-duplicated list of manual evaluation records
        (empty when the directory or files are missing).
    """
    if not os.path.exists(TEMP_DIR):
        logger.error(f"未找到temp目录: {TEMP_DIR}")
        print("\n✗ 错误未找到temp目录")
        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []
    # Collect every JSON file under the temp directory.
    json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
    if not json_files:
        logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
        print("\n✗ 错误temp目录下未找到JSON文件")
        print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
        return []
    logger.info(f"找到 {len(json_files)} 个JSON文件")
    print(f"\n找到 {len(json_files)} 个JSON文件:")
    for json_file in json_files:
        print(f" - {os.path.basename(json_file)}")
    # Read and merge all files, de-duplicating as we go.
    all_results = []
    seen_pairs: Set[Tuple[str, str]] = set()  # pairs already merged
    for json_file in json_files:
        try:
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
            results = data.get("manual_results", [])
            # De-duplicate on the (situation, style) key.
            for result in results:
                if "situation" not in result or "style" not in result:
                    logger.warning(f"跳过无效数据(缺少必要字段): {result}")
                    continue
                pair = (result["situation"], result["style"])
                if pair not in seen_pairs:
                    seen_pairs.add(pair)
                    all_results.append(result)
            logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果")
        except Exception as e:
            # Skip unreadable files but keep merging the rest.
            logger.error(f"加载文件 {json_file} 失败: {e}")
            print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
            continue
    logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
    print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
    return all_results
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM prompt that judges one (situation, style) pair.

    Args:
        situation: The usage condition / scenario text.
        style: The expression / speaking-style text.

    Returns:
        The full prompt string, asking the model for a strict-JSON verdict
        with "suitable" and "reason" fields.
    """
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}
请从以下方面进行评估:
1. 表达方式或言语风格 是否与使用条件或使用情景 匹配
2. 允许部分语法错误或口头化或缺省出现
3. 表达方式不能太过特指,需要具有泛用性
4. 一般不涉及具体的人名或名称
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因。
请严格按照JSON格式输出不要包含其他内容。"""
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of a (situation, style) pair.

    Args:
        situation: The usage condition / scenario text.
        style: The expression / speaking-style text.
        llm: LLM request instance used to generate the verdict.

    Returns:
        A (suitable, reason, error) tuple. On any failure ``suitable`` is
        False and ``error`` carries the message; otherwise ``error`` is None.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        response, (reasoning, model_name, _) = await llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON verdict out of the raw model response.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            import re
            # Fallback: extract the first flat JSON object containing a
            # "suitable" key (models sometimes wrap the JSON in prose).
            # NOTE: the pattern does not handle nested braces.
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        # Any failure (request, parsing, extraction) maps to a failing verdict.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
    """Judge one (situation, style) pair with the LLM.

    Args:
        situation: The usage condition / scenario text.
        style: The expression / speaking-style text.
        llm: LLM request instance.

    Returns:
        A result dict with the verdict, reason, optional error and an
        "evaluator" tag of "llm".
    """
    logger.info(f"开始评估表达方式: situation={situation}, style={style}")
    verdict, reason, error = await _single_llm_evaluation(situation, style, llm)
    # Any error forces a failing verdict.
    suitable = False if error else verdict
    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
    return {
        "situation": situation,
        "style": style,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
    """Compare LLM verdicts against the manual (criterion) verdicts.

    Manual results missing a matching LLM record are counted in ``total``
    but contribute to no other tally.

    Args:
        manual_results: Manual evaluation records (the gold standard).
        llm_results: LLM evaluation records for the same items.
        method_name: Label identifying the evaluation method.

    Returns:
        A dict of agreement counts and derived metrics (accuracy, precision,
        recall, F1, specificity, unsuitable-rate analysis); all rates are
        percentages.
    """
    # Index LLM verdicts by their (situation, style) pair.
    by_pair = {(item["situation"], item["style"]): item for item in llm_results}
    total = len(manual_results)
    matched = tp = tn = fp = fn = 0
    for gold in manual_results:
        predicted = by_pair.get((gold["situation"], gold["style"]))
        if predicted is None:
            continue
        gold_ok = gold["suitable"]
        pred_ok = predicted["suitable"]
        if gold_ok == pred_ok:
            matched += 1
        if gold_ok:
            if pred_ok:
                tp += 1
            else:
                fn += 1
        else:
            if pred_ok:
                fp += 1
            else:
                tn += 1

    def _pct(numer: int, denom: int) -> float:
        # Percentage with a zero-denominator guard (returns 0 like the spec).
        return (numer / denom * 100) if denom > 0 else 0

    accuracy = _pct(matched, total)
    precision = _pct(tp, tp + fp)
    recall = _pct(tp, tp + fn)
    f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
    specificity = _pct(tn, tn + fp)
    # Unsuitable rate within the manual criterion set.
    manual_unsuitable_rate = _pct(tn + fp, total)
    # Unsuitable rate among the items the LLM would keep (judged suitable):
    # kept = TP + FP; of those, FP are unsuitable per the manual criterion.
    llm_kept_unsuitable_rate = _pct(fp, tp + fp)
    rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
    random_baseline = 50.0
    return {
        "method": method_name,
        "total": total,
        "matched": matched,
        "accuracy": accuracy,
        "accuracy_above_random": accuracy - random_baseline,
        "accuracy_improvement_ratio": (accuracy / random_baseline) if random_baseline > 0 else 0,
        "true_positives": tp,
        "true_negatives": tn,
        "false_positives": fp,
        "false_negatives": fn,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_score,
        "specificity": specificity,
        "manual_unsuitable_rate": manual_unsuitable_rate,
        "llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
        "rate_difference": rate_difference,
    }
async def main(count: int | None = None):
    """Entry point: benchmark the LLM evaluator against manual criterion data.

    Args:
        count: Number of manual records to randomly sample; None means use
            all available data.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式LLM评估")
    logger.info("=" * 60)
    # Step 1: load the manual (criterion) evaluation results.
    print("\n步骤1: 加载人工评估结果")
    manual_results = load_manual_results()
    if not manual_results:
        return
    print(f"成功加载 {len(manual_results)} 条人工评估结果")
    # When a sample size was requested, randomly pick that many records.
    if count is not None:
        if count <= 0:
            print(f"\n✗ 错误指定的数量必须大于0当前值: {count}")
            return
        if count > len(manual_results):
            print(f"\n⚠ 警告:指定的数量 ({count}) 大于可用数据量 ({len(manual_results)}),将使用全部数据")
        else:
            random.seed()  # Seed from system time so each run samples freshly.
            manual_results = random.sample(manual_results, count)
            print(f"随机选取 {len(manual_results)} 条数据进行评估")
    # Validate record completeness (each needs "situation" and "style").
    valid_manual_results = []
    for r in manual_results:
        if "situation" in r and "style" in r:
            valid_manual_results.append(r)
        else:
            logger.warning(f"跳过无效数据: {r}")
    if len(valid_manual_results) != len(manual_results):
        print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
    print(f"有效数据: {len(valid_manual_results)} 条")
    # Step 2: create the LLM instance.
    print("\n步骤2: 创建LLM实例")
    try:
        llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
    except Exception as e:
        logger.error(f"创建LLM实例失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return
    print("\n步骤3: 开始LLM评估")
    llm_results = []
    for i, manual_result in enumerate(valid_manual_results, 1):
        print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
        llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
        # Small delay to avoid API rate limiting.
        await asyncio.sleep(0.3)
    # Step 5: report FP and FN items (before the metric summary).
    llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
    # 5.1 FP items: manual says unsuitable but the LLM passed them.
    print("\n" + "=" * 60)
    print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
    print("=" * 60)
    fp_items = []
    for manual_result in valid_manual_results:
        pair = (manual_result["situation"], manual_result["style"])
        llm_result = llm_dict.get(pair)
        if llm_result is None:
            continue
        # Manual rejected, LLM accepted: a false positive.
        if not manual_result["suitable"] and llm_result["suitable"]:
            fp_items.append(
                {
                    "situation": manual_result["situation"],
                    "style": manual_result["style"],
                    "manual_suitable": manual_result["suitable"],
                    "llm_suitable": llm_result["suitable"],
                    "llm_reason": llm_result.get("reason", "未提供理由"),
                    "llm_error": llm_result.get("error"),
                }
            )
    if fp_items:
        print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
        for idx, item in enumerate(fp_items, 1):
            print(f"--- [{idx}] ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 不通过 ❌")
            print("LLM评估: 通过 ✅ (误判)")
            if item.get("llm_error"):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
    # 5.2 FN items: manual says suitable but the LLM rejected them.
    print("\n" + "=" * 60)
    print("人工评估通过但LLM误判为不通过的项目FN - False Negative")
    print("=" * 60)
    fn_items = []
    for manual_result in valid_manual_results:
        pair = (manual_result["situation"], manual_result["style"])
        llm_result = llm_dict.get(pair)
        if llm_result is None:
            continue
        # Manual accepted, LLM rejected: a false negative.
        if manual_result["suitable"] and not llm_result["suitable"]:
            fn_items.append(
                {
                    "situation": manual_result["situation"],
                    "style": manual_result["style"],
                    "manual_suitable": manual_result["suitable"],
                    "llm_suitable": llm_result["suitable"],
                    "llm_reason": llm_result.get("reason", "未提供理由"),
                    "llm_error": llm_result.get("error"),
                }
            )
    if fn_items:
        print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
        for idx, item in enumerate(fn_items, 1):
            print(f"--- [{idx}] ---")
            print(f"Situation: {item['situation']}")
            print(f"Style: {item['style']}")
            print("人工评估: 通过 ✅")
            print("LLM评估: 不通过 ❌ (误删)")
            if item.get("llm_error"):
                print(f"LLM错误: {item['llm_error']}")
            print(f"LLM理由: {item['llm_reason']}")
            print()
    else:
        print("\n✓ 没有误删项目所有人工评估通过的项目都被LLM正确识别为通过")
    # Step 6: compute agreement metrics and print the report.
    comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
    print("\n" + "=" * 60)
    print("评估结果(以人工评估为标准)")
    print("=" * 60)
    # Detailed metrics (core capability indicators first).
    print(f"\n--- {comparison['method']} ---")
    print(f" 总数: {comparison['total']}")
    print()
    # print(" 【核心能力指标】")
    print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
    print(
        f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
    )
    print(
        f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个"
    )
    # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
    print()
    print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
    print(
        f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
    )
    print(
        f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个"
    )
    # print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
    print()
    print(" 【其他指标】")
    print(f" 准确率: {comparison['accuracy']:.2f}% (整体判断正确率)")
    print(f" 精确率: {comparison['precision']:.2f}% (判断为合适的项目中,实际合适的比例)")
    print(f" F1分数: {comparison['f1_score']:.2f} (精确率和召回率的调和平均)")
    print(f" 匹配数: {comparison['matched']}/{comparison['total']}")
    print()
    print(" 【不合适率分析】")
    print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
    print(
        f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
    )
    print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
    print()
    print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
    print(
        f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
    )
    print(
        f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
    )
    print()
    # print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
    # print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
    # print(f" - 含义: {'LLM删除后剩余项目中的不合适率降低了' if comparison['rate_difference'] > 0 else 'LLM删除后剩余项目中的不合适率反而升高了' if comparison['rate_difference'] < 0 else '两者相等'} ({'✓ LLM删除有效' if comparison['rate_difference'] > 0 else '✗ LLM删除效果不佳' if comparison['rate_difference'] < 0 else '效果相同'})")
    # print()
    print(" 【分类统计】")
    print(f" TP (正确识别为合适): {comparison['true_positives']}")
    print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
    print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
    print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
    # Step 7: save all inputs, outputs and metrics to one JSON artifact.
    output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
    try:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(
                {"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
                f,
                ensure_ascii=False,
                indent=2,
            )
        logger.info(f"\n评估结果已保存到: {output_file}")
    except Exception as e:
        logger.warning(f"保存结果到文件失败: {e}")
    print("\n" + "=" * 60)
    print("评估完成")
    print("=" * 60)
if __name__ == "__main__":
    # CLI entry point: the optional -n/--count flag limits how many manual
    # records are sampled for the benchmark run.
    parser = argparse.ArgumentParser(
        description="表达方式LLM评估脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
""",
    )
    parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
    args = parser.parse_args()
    asyncio.run(main(count=args.count))

View File

@@ -0,0 +1,275 @@
"""
表达方式人工评估脚本
功能:
1. 不停随机抽取项目(不重复)进行人工评估
2. 将结果保存到 temp 文件夹下的 JSON 文件,作为效标(标准答案)
3. 支持继续评估(从已有文件中读取已评估的项目,避免重复)
"""
import random
import json
import sys
import os
from typing import List, Dict, Set, Tuple
from datetime import datetime
# 添加项目根目录到路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, project_root)
from src.common.database.database_model import Expression # noqa: E402
from src.common.database.database import db # noqa: E402
from src.common.logger import get_logger # noqa: E402
logger = get_logger("expression_evaluator_manual")
# 评估结果文件路径
TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
    """Load previously saved manual evaluation results.

    Returns:
        A tuple of (existing result dicts, set of already-evaluated
        (situation, style) pairs). Both are empty when the file is
        missing or unreadable.
    """
    if not os.path.exists(MANUAL_EVAL_FILE):
        return [], set()
    try:
        with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
        results = data.get("manual_results", [])
        # (situation, style) is the unique key for a record.
        evaluated_pairs = {(r["situation"], r["style"]) for r in results if "situation" in r and "style" in r}
        logger.info(f"已加载 {len(results)} 条已有评估结果")
        return results, evaluated_pairs
    except Exception as e:
        # Best effort: a corrupt file just means we start from scratch.
        logger.error(f"加载已有评估结果失败: {e}")
        return [], set()
def save_results(manual_results: List[Dict]):
    """Persist manual evaluation results to MANUAL_EVAL_FILE as pretty JSON.

    Args:
        manual_results: The full list of manual evaluation records.
    """
    try:
        os.makedirs(TEMP_DIR, exist_ok=True)
        data = {
            "last_updated": datetime.now().isoformat(),
            "total_count": len(manual_results),
            "manual_results": manual_results,
        }
        with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
        print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
    except Exception as e:
        logger.error(f"保存评估结果失败: {e}")
        print(f"\n✗ 保存评估结果失败: {e}")
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """Pick a random batch of expressions that have not been evaluated yet.

    Args:
        evaluated_pairs: Already-evaluated (situation, style) pairs to skip.
        batch_size: Maximum number of items to return.

    Returns:
        Up to ``batch_size`` randomly chosen unevaluated expressions;
        empty when the table is empty, everything is done, or on error.
    """
    try:
        all_expressions = list(Expression.select())
        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []
        # An item is pending when its (situation, style) pair is unseen.
        pending = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
        if not pending:
            logger.info("所有项目都已评估完成")
            return []
        if len(pending) <= batch_size:
            # Fewer pending items than requested: return them all.
            logger.info(f"剩余 {len(pending)} 条未评估项目,全部返回")
            return pending
        batch = random.sample(pending, batch_size)
        logger.info(f"{len(pending)} 条未评估项目中随机选择了 {len(batch)} 条")
        return batch
    except Exception as e:
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
    """Interactively collect a manual verdict for one expression.

    Args:
        expression: The expression record to judge.
        index: 1-based position within the current batch.
        total: Batch size.

    Returns:
        A result dict on a normal verdict, the string "skip" when the user
        skips the item, or None when the user quits. (Callers must check
        for the sentinel values before treating the result as a dict.)
    """
    print("\n" + "=" * 60)
    print(f"人工评估 [{index}/{total}]")
    print("=" * 60)
    print(f"Situation: {expression.situation}")
    print(f"Style: {expression.style}")
    print("\n请评估该表达方式是否合适:")
    print(" 输入 'y''yes''1' 表示合适(通过)")
    print(" 输入 'n''no''0' 表示不合适(不通过)")
    print(" 输入 'q''quit' 退出评估")
    print(" 输入 's''skip' 跳过当前项目")
    # Loop until we get a recognizable answer (quit/skip return immediately).
    while True:
        user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
        if user_input in ["q", "quit"]:
            print("退出评估")
            return None
        if user_input in ["s", "skip"]:
            print("跳过当前项目")
            return "skip"
        if user_input in ["y", "yes", "1", "", "通过"]:
            suitable = True
            break
        elif user_input in ["n", "no", "0", "", "不通过"]:
            suitable = False
            break
        else:
            print("输入无效,请重新输入 (y/n/q/s)")
    result = {
        "situation": expression.situation,
        "style": expression.style,
        "suitable": suitable,
        "reason": None,  # manual verdicts carry no free-text reason
        "evaluator": "manual",
        "evaluated_at": datetime.now().isoformat(),
    }
    print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
    return result
def main():
    """Entry point: interactively evaluate random batches of expressions.

    Connects to the database, resumes from any previously saved results,
    then loops over random batches of 10 unevaluated items, persisting
    after each batch until the user quits or everything is evaluated.
    """
    logger.info("=" * 60)
    logger.info("开始表达方式人工评估")
    logger.info("=" * 60)
    # Initialize the database connection.
    try:
        db.connect(reuse_if_open=True)
        logger.info("数据库连接成功")
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        return
    # Load previously saved results so evaluation can resume.
    existing_results, evaluated_pairs = load_existing_results()
    manual_results = existing_results.copy()
    if evaluated_pairs:
        print(f"\n已加载 {len(existing_results)} 条已有评估结果")
        print(f"已评估项目数: {len(evaluated_pairs)}")
    print("\n" + "=" * 60)
    print("开始人工评估")
    print("=" * 60)
    print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
    print("评估结果会自动保存到文件\n")
    batch_size = 10
    batch_count = 0
    while True:
        # Fetch the next random batch of unevaluated items.
        expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
        if not expressions:
            print("\n" + "=" * 60)
            print("所有项目都已评估完成!")
            print("=" * 60)
            break
        batch_count += 1
        print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
        batch_results = []
        for i, expression in enumerate(expressions, 1):
            manual_result = manual_evaluate_expression(expression, i, len(expressions))
            if manual_result is None:
                # User quit: save whatever was completed in this batch.
                print("\n评估已中断")
                if batch_results:
                    manual_results.extend(batch_results)
                    save_results(manual_results)
                return
            if manual_result == "skip":
                # User skipped this item; it stays unevaluated.
                continue
            batch_results.append(manual_result)
            # (situation, style) is the unique de-duplication key.
            evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
        # Merge this batch into the accumulated results and persist.
        manual_results.extend(batch_results)
        save_results(manual_results)
        print(f"\n当前批次完成,已评估总数: {len(manual_results)}")
        # Ask whether to continue with the next batch.
        while True:
            continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
            if continue_input in ["y", "yes", "1", "", "继续"]:
                break
            elif continue_input in ["n", "no", "0", "", "退出"]:
                print("\n评估结束")
                return
            else:
                print("输入无效,请重新输入 (y/n)")
    # Close the database connection.
    try:
        db.close()
        logger.info("数据库连接已关闭")
    except Exception as e:
        logger.warning(f"关闭数据库连接时出错: {e}")
if __name__ == "__main__":
    # Script entry point: run the interactive evaluation loop.
    main()

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from pathlib import Path
import ast
import re
PROJECT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_EXCLUDE_PARTS = {
".git",
".venv",
"dashboard",
"docs",
"docs-src",
"locales",
}
HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]")
def should_skip(path: Path) -> bool:
    """Return True when any component of *path* is in the exclusion set."""
    # Equivalent to any(part in DEFAULT_EXCLUDE_PARTS for part in path.parts).
    return not DEFAULT_EXCLUDE_PARTS.isdisjoint(path.parts)
def iter_python_files(root: Path) -> list[Path]:
    """Return every non-excluded ``*.py`` file under *root*, sorted."""
    regular_files = (candidate for candidate in root.rglob("*.py") if candidate.is_file())
    return sorted(candidate for candidate in regular_files if not should_skip(candidate.relative_to(root)))
class CandidateExtractor(ast.NodeVisitor):
    """AST visitor that collects non-docstring string constants containing
    Han (CJK) characters, as (lineno, stripped text) pairs."""

    def __init__(self) -> None:
        # Constant nodes that serve as docstrings; excluded from candidates.
        self._docstring_nodes: set[ast.AST] = set()
        self.candidates: list[tuple[int, str]] = []

    def _visit_scope(self, node) -> None:
        # Shared handler for every node kind that can carry a docstring.
        self._mark_docstring_node(node)
        self.generic_visit(node)

    # NodeVisitor dispatches on method name, so plain aliases suffice.
    visit_Module = _visit_scope
    visit_ClassDef = _visit_scope
    visit_AsyncFunctionDef = _visit_scope
    visit_FunctionDef = _visit_scope

    def visit_Constant(self, node: ast.Constant) -> None:
        if node in self._docstring_nodes:
            return
        if isinstance(node.value, str) and HAN_PATTERN.search(node.value):
            self.candidates.append((node.lineno, node.value.strip()))
        self.generic_visit(node)

    def _mark_docstring_node(self, node) -> None:
        # Remember the leading string-constant expression (the docstring),
        # if the scope has one, so visit_Constant can skip it.
        if not node.body:
            return
        leading = node.body[0]
        if isinstance(leading, ast.Expr) and isinstance(leading.value, ast.Constant):
            if isinstance(leading.value.value, str):
                self._docstring_nodes.add(leading.value)
def extract_candidates(file_path: Path) -> list[tuple[int, str]]:
    """Parse *file_path* and return (lineno, text) pairs of Han-bearing strings."""
    source_text = file_path.read_text(encoding="utf-8")
    visitor = CandidateExtractor()
    visitor.visit(ast.parse(source_text, filename=str(file_path)))
    return visitor.candidates
def main() -> int:
    """Print ``path:line: text`` for every candidate string under PROJECT_ROOT."""
    for source_file in iter_python_files(PROJECT_ROOT):
        relative = source_file.relative_to(PROJECT_ROOT)
        for lineno, text in extract_candidates(source_file):
            print(f"{relative}:{lineno}: {text}")
    return 0
# Exit with main()'s status code when run as a script.
if __name__ == "__main__":
    raise SystemExit(main())

411
scripts/i18n_validate.py Normal file
View File

@@ -0,0 +1,411 @@
from __future__ import annotations
from pathlib import Path
from typing import Callable
import json
import re
import sys
# Repository root; this script lives one directory below it (scripts/).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Make the repo root importable so `src.*` resolves when run as a script.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from src.common.i18n.exceptions import ( # noqa: E402
    DuplicateTranslationKeyError,
    InvalidTranslationFileError,
    LocaleNotFoundError,
)
from src.common.i18n.loaders import ( # noqa: E402
    DEFAULT_LOCALE,
    PLURAL_CATEGORIES,
    TranslationValue,
    discover_locales,
    get_locales_root,
    load_locale_catalog,
    validate_translation_value,
)
from src.common.i18n.loaders import extract_placeholders # noqa: E402
from src.common.prompt_i18n import ( # noqa: E402
    discover_prompt_locales,
    extract_prompt_placeholders,
    get_prompts_root,
    iter_prompt_files,
)
# Matches any Han character (CJK Ext-A, CJK Unified, CJK Compatibility blocks).
HAN_CHARACTER_PATTERN = re.compile(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]")
# Matches i18next placeholders like `{{name}}` or `{{count, number}}`;
# group 1 captures the raw expression before any format options.
I18NEXT_PLACEHOLDER_PATTERN = re.compile(r"\{\{\s*([^\s,}]+)(?:\s*,[^}]*)?\s*\}\}")
# The dashboard frontend uses Chinese as its source locale.
DASHBOARD_DEFAULT_LOCALE = "zh"
def contains_han_characters(text: str) -> bool:
    """Return True when *text* contains at least one Han (CJK) character."""
    # Same character classes as HAN_CHARACTER_PATTERN: Ext-A, Unified, Compatibility.
    return bool(re.search(r"[\u3400-\u4DBF\u4E00-\u9FFF\uF900-\uFAFF]", text))
def extract_i18next_placeholders(template: str) -> set[str]:
    """Extract root placeholder names from an i18next template string.

    ``{{user.name}}`` and ``{{items[0]}}`` both reduce to their root
    identifier (``user`` / ``items``); format options after a comma are ignored.
    """
    pattern = re.compile(r"\{\{\s*([^\s,}]+)(?:\s*,[^}]*)?\s*\}\}")
    return {
        match.group(1).split(".", maxsplit=1)[0].split("[", maxsplit=1)[0]
        for match in pattern.finditer(template)
    }
def iter_translation_strings(value: TranslationValue) -> list[str]:
    """Return the translation's strings: the value itself, or plural texts by category."""
    if isinstance(value, str):
        return [value]
    # Plural mapping: emit texts in deterministic (sorted-category) order.
    return [text for _, text in sorted(value.items())]
def iter_shared_translation_strings(
    source_value: TranslationValue, target_value: TranslationValue
) -> list[tuple[str, str]]:
    """Pair up comparable strings from two translation values.

    Two plain strings pair directly; two plural mappings pair over their
    shared categories (sorted); mixed shapes yield nothing.
    """
    if isinstance(source_value, str) or isinstance(target_value, str):
        if isinstance(source_value, str) and isinstance(target_value, str):
            return [(source_value, target_value)]
        return []
    common = set(source_value) & set(target_value)
    return [(source_value[category], target_value[category]) for category in sorted(common)]
def locale_requires_latin_only_validation(locale: str) -> bool:
    """Return True for English locales (``en`` or any ``en-*`` variant)."""
    lowered = locale.lower()
    if lowered == "en":
        return True
    return lowered.startswith("en-")
def validate_locale_content(
    key: str,
    source_value: TranslationValue,
    target_value: TranslationValue,
    locale: str,
    errors: list[str],
    locale_label: str | None = None,
) -> None:
    """Content-level checks for one key; appends findings to *errors*.

    Flags targets that copied Han-bearing source text verbatim, and — for
    English locales — any remaining Han characters in the target.
    """
    label = locale_label or locale
    kept_source_verbatim = False
    for source_text, target_text in iter_shared_translation_strings(source_value, target_value):
        if source_text == target_text and contains_han_characters(source_text):
            kept_source_verbatim = True
            break
    if kept_source_verbatim:
        errors.append(
            f"[{label}] key '{key}' 直接保留了包含中文字符的 source 文案(仓库级校验策略),请提供目标语言翻译"
        )
    if locale_requires_latin_only_validation(locale):
        if any(contains_han_characters(text) for text in iter_translation_strings(target_value)):
            errors.append(f"[{label}] key '{key}' 仍包含中文字符,请移除源语言残留后再提交")
def validate_translation_pair(
    key: str,
    source_value: TranslationValue,
    target_value: TranslationValue,
    locale: str,
    errors: list[str],
    placeholder_extractor: Callable[[str], set[str]] = extract_placeholders,
    locale_label: str | None = None,
) -> None:
    """Check that *target_value* structurally matches *source_value*.

    Appends messages to *errors* when the two values disagree on type
    (plain string vs plural mapping), on plural categories, or on the
    placeholder set of any compared string.

    Args:
        key: Fully qualified translation key (used in messages only).
        source_value: Value from the default (source) locale.
        target_value: Value from the locale under validation.
        locale: Locale code being validated.
        errors: Mutable list that error messages are appended to.
        placeholder_extractor: Extracts placeholder names from one template
            string; defaults to the backend-format extractor.
        locale_label: Optional label shown in messages instead of *locale*.
    """
    resolved_locale_label = locale_label or locale
    if isinstance(source_value, str):
        if not isinstance(target_value, str):
            errors.append(
                f"[{resolved_locale_label}] key '{key}' 与 source 的类型不一致source=string, target=plural"
            )
            return
        if placeholder_extractor(source_value) != placeholder_extractor(target_value):
            errors.append(f"[{resolved_locale_label}] key '{key}' 的占位符集合与 source 不一致")
        return
    if not isinstance(target_value, dict):
        errors.append(f"[{resolved_locale_label}] key '{key}' 与 source 的类型不一致source=plural, target=string")
        return
    source_categories = set(source_value.keys())
    target_categories = set(target_value.keys())
    if source_categories != target_categories:
        errors.append(
            f"[{resolved_locale_label}] key '{key}' 的 plural category 不一致:"
            f"source={sorted(source_categories)}, target={sorted(target_categories)}"
        )
    # Placeholder comparison still runs for whatever categories both sides share.
    for category in sorted(source_categories & target_categories):
        source_placeholders = placeholder_extractor(source_value[category])
        target_placeholders = placeholder_extractor(target_value[category])
        if source_placeholders != target_placeholders:
            errors.append(
                f"[{resolved_locale_label}] key '{key}' 的 plural category '{category}' 占位符集合与 source 不一致"
            )
def get_dashboard_locales_root(locales_root: Path | None = None) -> Path:
    """Resolve the dashboard locales directory, defaulting to the in-repo path."""
    if locales_root is None:
        return (PROJECT_ROOT / "dashboard" / "src" / "i18n" / "locales").resolve()
    return locales_root.resolve()
def discover_dashboard_locales(locales_root: Path | None = None) -> list[str]:
    """List locale codes (file stems of ``*.json``) found under the dashboard root."""
    root = get_dashboard_locales_root(locales_root)
    if not root.exists():
        return []
    return sorted(path.stem for path in root.glob("*.json") if path.is_file())
def is_plural_translation_node(value: object) -> bool:
    """True when *value* is a non-empty dict mapping known plural categories to strings."""
    if not isinstance(value, dict) or not value:
        return False
    for category, text in value.items():
        if not isinstance(category, str) or category not in PLURAL_CATEGORIES:
            return False
        if not isinstance(text, str):
            return False
    return True
def flatten_dashboard_translation_mapping(
    value: dict[str, object],
    file_path: Path,
    translations: dict[str, TranslationValue],
    parent_keys: list[str] | None = None,
) -> None:
    """Recursively flatten a nested dashboard JSON mapping into dot-joined keys.

    Leaf values may be plain strings or plural mappings; nested objects
    recurse with their key appended to the dotted path. Results are written
    into *translations* in place.

    Raises:
        InvalidTranslationFileError: On empty objects, non-string keys,
            empty keys, or leaves that are neither strings nor valid objects.
        DuplicateTranslationKeyError: When the same dotted key appears twice.
    """
    current_parent_keys = parent_keys or []
    # Empty objects are rejected so a locale cannot silently drop a whole subtree.
    if not value:
        if current_parent_keys:
            raise InvalidTranslationFileError(
                f"{file_path} 中的 key '{'.'.join(current_parent_keys)}' 不能为空对象"
            )
        raise InvalidTranslationFileError(f"{file_path} 顶层不能为空对象")
    for raw_key, raw_value in value.items():
        if not isinstance(raw_key, str):
            raise InvalidTranslationFileError(f"{file_path} 中存在非字符串 key")
        normalized_key = raw_key.strip()
        if not normalized_key:
            raise InvalidTranslationFileError(f"{file_path} 中存在空字符串 key")
        current_key_parts = [*current_parent_keys, normalized_key]
        current_key = ".".join(current_key_parts)
        if isinstance(raw_value, str):
            if current_key in translations:
                raise DuplicateTranslationKeyError(f"{file_path} 中存在重复 key: '{current_key}'")
            translations[current_key] = raw_value
            continue
        # Plural nodes are dicts too, so this check must come before the generic dict case.
        if is_plural_translation_node(raw_value):
            if current_key in translations:
                raise DuplicateTranslationKeyError(f"{file_path} 中存在重复 key: '{current_key}'")
            translations[current_key] = validate_translation_value(current_key, raw_value, file_path)
            continue
        if isinstance(raw_value, dict):
            flatten_dashboard_translation_mapping(raw_value, file_path, translations, current_key_parts)
            continue
        raise InvalidTranslationFileError(f"{file_path} 中的 key '{current_key}' 必须是字符串或对象")
def load_dashboard_translation_file(file_path: Path) -> dict[str, TranslationValue]:
    """Load one dashboard locale JSON file and flatten it to dotted keys."""
    try:
        parsed = json.loads(file_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as exc:
        raise InvalidTranslationFileError(f"{file_path} 不是合法 JSON: {exc}") from exc
    if not isinstance(parsed, dict):
        raise InvalidTranslationFileError(f"{file_path} 顶层必须是 JSON object")
    flattened: dict[str, TranslationValue] = {}
    flatten_dashboard_translation_mapping(parsed, file_path, flattened)
    return flattened
def load_dashboard_locale_catalog(
    locale: str,
    locales_root: Path | None = None,
) -> dict[str, TranslationValue]:
    """Load and flatten the dashboard catalog for *locale*; raise when missing."""
    catalog_path = get_dashboard_locales_root(locales_root) / f"{locale}.json"
    if catalog_path.exists():
        return load_dashboard_translation_file(catalog_path)
    raise LocaleNotFoundError(f"未找到 locale 文件: {catalog_path}")
def validate_dashboard_json_locales(locales_root: Path | None = None) -> list[str]:
    """Validate the dashboard (frontend) JSON locales against the zh source.

    Checks that the default locale file exists, that every file loads and
    flattens, that key sets match, and that each shared key passes the
    structural and content validators.

    Returns:
        A list of human-readable error messages; empty when validation passes.
    """
    resolved_locales_root = get_dashboard_locales_root(locales_root)
    locales = discover_dashboard_locales(resolved_locales_root)
    errors: list[str] = []
    if DASHBOARD_DEFAULT_LOCALE not in locales:
        errors.append(f"[dashboard] 缺少默认 locale 文件: {DASHBOARD_DEFAULT_LOCALE}.json")
        return errors
    catalogs: dict[str, dict[str, TranslationValue]] = {}
    for locale in locales:
        try:
            catalogs[locale] = load_dashboard_locale_catalog(locale, resolved_locales_root)
        except Exception as exc:
            errors.append(f"[dashboard:{locale}] 加载失败: {exc}")
    source_catalog = catalogs.get(DASHBOARD_DEFAULT_LOCALE)
    # If even the source locale failed to load there is nothing to compare against.
    if source_catalog is None:
        return errors
    source_keys = set(source_catalog.keys())
    for locale, catalog in catalogs.items():
        if locale == DASHBOARD_DEFAULT_LOCALE:
            continue
        locale_label = f"dashboard:{locale}"
        locale_keys = set(catalog.keys())
        for key in sorted(source_keys - locale_keys):
            errors.append(f"[{locale_label}] 缺少 key: {key}")
        for key in sorted(locale_keys - source_keys):
            errors.append(f"[{locale_label}] 存在多余 key: {key}")
        for key in sorted(source_keys & locale_keys):
            source_value = source_catalog[key]
            target_value = catalog[key]
            validate_translation_pair(
                key,
                source_value,
                target_value,
                locale,
                errors,
                placeholder_extractor=extract_i18next_placeholders,
                locale_label=locale_label,
            )
            # Content checks only make sense when both sides have the same shape.
            if isinstance(source_value, str) == isinstance(target_value, str):
                validate_locale_content(key, source_value, target_value, locale, errors, locale_label=locale_label)
    return errors
def validate_json_locales(locales_root: Path | None = None) -> list[str]:
    """Validate the backend JSON locale catalogs against the default locale.

    Mirrors :func:`validate_dashboard_json_locales` but uses the backend
    loader and the backend placeholder extractor (the default of
    ``validate_translation_pair``).

    Returns:
        A list of human-readable error messages; empty when validation passes.
    """
    resolved_locales_root = get_locales_root(locales_root)
    locales = discover_locales(resolved_locales_root)
    errors: list[str] = []
    if DEFAULT_LOCALE not in locales:
        errors.append(f"缺少默认 locale 目录: {DEFAULT_LOCALE}")
        return errors
    catalogs: dict[str, dict[str, TranslationValue]] = {}
    for locale in locales:
        try:
            catalogs[locale] = load_locale_catalog(locale, resolved_locales_root)
        except Exception as exc:
            errors.append(f"[{locale}] 加载失败: {exc}")
    source_catalog = catalogs.get(DEFAULT_LOCALE)
    # No source catalog means every comparison below would be meaningless.
    if source_catalog is None:
        return errors
    source_keys = set(source_catalog.keys())
    for locale, catalog in catalogs.items():
        if locale == DEFAULT_LOCALE:
            continue
        locale_keys = set(catalog.keys())
        for key in sorted(source_keys - locale_keys):
            errors.append(f"[{locale}] 缺少 key: {key}")
        for key in sorted(locale_keys - source_keys):
            errors.append(f"[{locale}] 存在多余 key: {key}")
        for key in sorted(source_keys & locale_keys):
            source_value = source_catalog[key]
            target_value = catalog[key]
            validate_translation_pair(key, source_value, target_value, locale, errors)
            # Content checks only make sense when both sides have the same shape.
            if isinstance(source_value, str) == isinstance(target_value, str):
                validate_locale_content(key, source_value, target_value, locale, errors)
    return errors
def build_prompt_catalog(locale_dir: Path) -> dict[Path, Path]:
    """Map each prompt file's path relative to *locale_dir* to its absolute path."""
    catalog: dict[Path, Path] = {}
    for prompt_path in iter_prompt_files(locale_dir):
        catalog[prompt_path.relative_to(locale_dir)] = prompt_path
    return catalog
def validate_prompt_templates(prompts_root: Path | None = None) -> tuple[list[str], list[str]]:
    """Validate per-locale prompt template trees against the default locale.

    Missing locale directories or templates only produce warnings (runtime
    falls back to the default locale); placeholder mismatches are hard errors.

    Returns:
        A ``(errors, warnings)`` tuple of message lists.
    """
    resolved_prompts_root = get_prompts_root(prompts_root)
    prompt_locales = set(discover_prompt_locales(resolved_prompts_root))
    # Only locales that also exist in the backend JSON catalog are checked.
    known_locales = [locale for locale in discover_locales(get_locales_root()) if locale != DEFAULT_LOCALE]
    errors: list[str] = []
    warnings: list[str] = []
    if DEFAULT_LOCALE not in prompt_locales:
        errors.append(f"缺少默认 Prompt locale 目录: {DEFAULT_LOCALE}")
        return errors, warnings
    source_dir = resolved_prompts_root / DEFAULT_LOCALE
    source_files = build_prompt_catalog(source_dir)
    source_relative_paths = set(source_files.keys())
    for locale in known_locales:
        locale_dir = resolved_prompts_root / locale
        if not locale_dir.exists():
            warnings.append(f"[prompt:{locale}] 缺少 locale 目录,运行时将回退到 {DEFAULT_LOCALE}")
            continue
        locale_files = build_prompt_catalog(locale_dir)
        locale_relative_paths = set(locale_files.keys())
        for relative_path in sorted(source_relative_paths - locale_relative_paths):
            warnings.append(f"[prompt:{locale}] 缺少模板: {relative_path.as_posix()},运行时将回退到 {DEFAULT_LOCALE}")
        for relative_path in sorted(locale_relative_paths - source_relative_paths):
            warnings.append(f"[prompt:{locale}] 存在额外模板: {relative_path.as_posix()}")
        for relative_path in sorted(source_relative_paths & locale_relative_paths):
            source_text = source_files[relative_path].read_text(encoding="utf-8")
            locale_text = locale_files[relative_path].read_text(encoding="utf-8")
            source_placeholders = extract_prompt_placeholders(source_text)
            locale_placeholders = extract_prompt_placeholders(locale_text)
            if source_placeholders != locale_placeholders:
                errors.append(
                    "[prompt:{locale}] 模板 '{path}' 的占位符集合与 source 不一致:"
                    "source={source_placeholders}, target={target_placeholders}".format(
                        locale=locale,
                        path=relative_path.as_posix(),
                        source_placeholders=sorted(source_placeholders),
                        target_placeholders=sorted(locale_placeholders),
                    )
                )
            # A byte-identical template is a hint (not proof) that translation is pending.
            if source_text == locale_text:
                warnings.append(f"[prompt:{locale}] 模板 '{relative_path.as_posix()}' 与 source 完全相同,可能尚未翻译")
    return errors, warnings
def _print_warnings(warnings: list[str]) -> None:
if not warnings:
return
print(f"warnings ({len(warnings)}):")
for warning in warnings[:10]:
print(f" - {warning}")
if len(warnings) > 10:
print(f" - ... 另外还有 {len(warnings) - 10} 条 warning")
def main() -> int:
    """Run all i18n validators; return 1 when any hard error was found."""
    errors = validate_json_locales()
    errors.extend(validate_dashboard_json_locales())
    prompt_errors, prompt_warnings = validate_prompt_templates()
    errors.extend(prompt_errors)
    if not errors:
        print("i18n validation passed.")
        _print_warnings(prompt_warnings)
        return 0
    print("i18n validation failed:")
    for error in errors:
        print(f" - {error}")
    _print_warnings(prompt_warnings)
    return 1
# Exit with main()'s status code when run as a script.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,40 @@
import tomlkit
def generate_requirements(pyproject_path="pyproject.toml", output_path="requirements.txt"):
    """Regenerate requirements.txt from the [project].dependencies table of pyproject.toml.

    Dependencies present in the current requirements.txt but absent from
    pyproject.toml are reported as warnings before the file is overwritten.
    All outcomes (including errors) are reported via print; nothing is raised.
    """
    try:
        # Read the dependency list declared in pyproject.toml.
        with open(pyproject_path, "r", encoding="utf-8") as file:
            pyproject_data = tomlkit.load(file)
        pyproject_dependencies = pyproject_data.get("project", {}).get("dependencies", [])
        if not pyproject_dependencies:
            print("未找到 dependencies 部分,无法生成 requirements.txt")
            return
        # Read any existing requirements.txt to warn about soon-to-be-dropped entries.
        try:
            with open(output_path, "r", encoding="utf-8") as file:
                requirements = {line.strip() for line in file if line.strip()}
        except FileNotFoundError:
            requirements = set()
        if extra_dependencies := requirements - set(pyproject_dependencies):
            print("警告: 以下依赖项存在于 requirements.txt 中,但未在 pyproject.toml 中找到:")
            for dep in extra_dependencies:
                print(f" - {dep}")
        # Overwrite requirements.txt with the pyproject dependency list.
        with open(output_path, "w", encoding="utf-8") as file:
            # Bug fix: terminate the file with a newline so it is a valid POSIX
            # text file and behaves well with cat/diff/append operations.
            file.write("\n".join(pyproject_dependencies) + "\n")
        print(f"requirements.txt 文件已生成: {output_path}")
    except FileNotFoundError:
        print(f"未找到 {pyproject_path} 文件,请检查路径是否正确。")
    except Exception as e:
        # Script-level catch-all: report instead of crashing the generator.
        print(f"发生错误: {e}")
# Regenerate requirements.txt with the default paths when run directly.
if __name__ == "__main__":
    generate_requirements()

1132
scripts/mmipkg_tool.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

973
scripts/run.sh Normal file
View File

@@ -0,0 +1,973 @@
#!/bin/bash
# MaiCore & NapCat Adapter one-click installer, by Cookie_987
# Targets macOS / Arch / Ubuntu 24.10 / Debian 12 / CentOS 9
# Use any one-click script with care!
INSTALLER_VERSION="0.0.5-refactor"
LANG=C.UTF-8
# GitHub mirror prefix; change it here when github.com is unreachable.
GITHUB_REPO="https://ghfast.top/https://github.com"
# ANSI color codes for terminal output
GREEN="\e[32m"
RED="\e[31m"
RESET="\e[0m"
# Required base packages (plain strings instead of associative arrays for Bash 3 compatibility)
REQUIRED_PACKAGES_COMMON="git sudo python3 curl gnupg"
REQUIRED_PACKAGES_DEBIAN="python3-venv python3-pip build-essential"
REQUIRED_PACKAGES_UBUNTU="python3-venv python3-pip build-essential"
REQUIRED_PACKAGES_CENTOS="epel-release python3-pip python3-devel gcc gcc-c++ make"
REQUIRED_PACKAGES_ARCH="python-virtualenv python-pip base-devel"
REQUIRED_PACKAGES_MACOS="git gnupg python"
# Service names
SERVICE_NAME="maicore"
SERVICE_NAME_WEB="maicore-web"
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
# Run services as the invoking user (not root) and resolve that user's home.
SERVICE_USER="${SUDO_USER:-$USER}"
SERVICE_HOME="$(eval echo "~${SERVICE_USER}" 2>/dev/null)"
if [[ -z "$SERVICE_HOME" || "$SERVICE_HOME" == "~${SERVICE_USER}" ]]; then
SERVICE_HOME="$HOME"
fi
IS_MACOS=false
[[ "$(uname -s)" == "Darwin" ]] && IS_MACOS=true
INSTALL_CONF="/etc/maicore_install.conf"
# Default install locations; macOS installs per-user instead of system-wide.
DEFAULT_INSTALL_DIR="/opt/maicore"
if [[ "$IS_MACOS" == true ]]; then
DEFAULT_INSTALL_DIR="${SERVICE_HOME}/maicore"
INSTALL_CONF="${SERVICE_HOME}/.config/maicore/maicore_install.conf"
fi
# launchd (macOS service manager) identifiers; left empty on Linux where systemd is used.
LAUNCHD_DOMAIN=""
LAUNCHD_AGENT_DIR=""
LAUNCHD_LABEL_MAIN="com.maicore.${SERVICE_NAME}"
LAUNCHD_LABEL_NBADAPTER="com.maicore.${SERVICE_NAME_NBADAPTER}"
LAUNCHD_PLIST_MAIN=""
LAUNCHD_PLIST_NBADAPTER=""
if [[ "$IS_MACOS" == true ]]; then
SERVICE_UID="$(id -u "${SERVICE_USER}" 2>/dev/null || id -u)"
LAUNCHD_DOMAIN="gui/${SERVICE_UID}"
LAUNCHD_AGENT_DIR="${SERVICE_HOME}/Library/LaunchAgents"
LAUNCHD_PLIST_MAIN="${LAUNCHD_AGENT_DIR}/${LAUNCHD_LABEL_MAIN}.plist"
LAUNCHD_PLIST_NBADAPTER="${LAUNCHD_AGENT_DIR}/${LAUNCHD_LABEL_NBADAPTER}.plist"
fi
# Echo the package list required for the given distro id ($1).
# Unknown distros get only the common packages; macOS has its own list
# because Homebrew formula names differ from Linux package names.
get_required_packages() {
local distro="$1"
local extra=""
case "$distro" in
debian)
extra="${REQUIRED_PACKAGES_DEBIAN}"
;;
ubuntu)
extra="${REQUIRED_PACKAGES_UBUNTU}"
;;
centos)
extra="${REQUIRED_PACKAGES_CENTOS}"
;;
arch)
extra="${REQUIRED_PACKAGES_ARCH}"
;;
macos)
echo "${REQUIRED_PACKAGES_MACOS}"
return
;;
esac
if [[ -n "$extra" ]]; then
echo "${REQUIRED_PACKAGES_COMMON} ${extra}"
else
echo "${REQUIRED_PACKAGES_COMMON}"
fi
}
# Decisions collected during the dialog flow and acted on later in the install.
IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false
# Locate the Homebrew binary, falling back to the standard Apple Silicon and
# Intel install paths when `brew` is not on PATH. Prints nothing when absent.
resolve_brew_bin() {
local brew_bin
brew_bin="$(command -v brew)"
[[ -z "$brew_bin" && -x /opt/homebrew/bin/brew ]] && brew_bin="/opt/homebrew/bin/brew"
[[ -z "$brew_bin" && -x /usr/local/bin/brew ]] && brew_bin="/usr/local/bin/brew"
[[ -n "$brew_bin" ]] && echo "$brew_bin"
}
# Run brew as the original (non-root) user when the script runs under sudo,
# since Homebrew refuses to run as root. Returns 1 when brew is missing.
run_brew() {
local brew_bin
brew_bin="$(resolve_brew_bin)"
[[ -z "$brew_bin" ]] && return 1
if [[ "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" && "${SUDO_USER}" != "root" ]]; then
sudo -u "${SUDO_USER}" "${brew_bin}" "$@"
else
"${brew_bin}" "$@"
fi
}
# Same de-escalation wrapper for launchctl: gui/<uid> domains belong to the
# logged-in user, not to root.
run_launchctl() {
if [[ "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" && "${SUDO_USER}" != "root" ]]; then
sudo -u "${SUDO_USER}" launchctl "$@"
else
launchctl "$@"
fi
}
# Create the parent directory of $1 and, on macOS under sudo, hand ownership
# back to the invoking user so later unprivileged writes succeed.
ensure_writable_parent() {
local path="$1"
local parent
parent="$(dirname "$path")"
mkdir -p "$parent"
if [[ "$IS_MACOS" == true && "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" ]]; then
chown "${SUDO_USER}" "$parent" 2>/dev/null || true
fi
}
# Persist installer state (version, install dir, branch) so later runs can
# find the existing installation. The heredoc body must stay comment-free
# because it is written verbatim to the config file.
save_install_info() {
ensure_writable_parent "$INSTALL_CONF"
cat > "$INSTALL_CONF" <<EOF
INSTALLER_VERSION=${INSTALLER_VERSION}
INSTALL_DIR=${INSTALL_DIR}
BRANCH=${BRANCH}
EOF
}
# Print the MD5 digest of a file, preferring md5sum (Linux) over md5 (macOS).
# Fails with a non-zero status when neither tool is available.
compute_md5() {
local file_path="$1"
if command -v md5sum &>/dev/null; then
md5sum "$file_path" | cut -d' ' -f1
return
fi
if command -v md5 &>/dev/null; then
md5 -q "$file_path"
return
fi
return 1
}
# Map a known service name to its launchd label; fails for unknown services.
launchd_label_for_service() {
local service="$1"
case "$service" in
${SERVICE_NAME})
echo "$LAUNCHD_LABEL_MAIN"
;;
${SERVICE_NAME_NBADAPTER})
echo "$LAUNCHD_LABEL_NBADAPTER"
;;
*)
return 1
;;
esac
}
# Map a known service name to its launchd plist path; fails for unknown services.
launchd_plist_for_service() {
local service="$1"
case "$service" in
${SERVICE_NAME})
echo "$LAUNCHD_PLIST_MAIN"
;;
${SERVICE_NAME_NBADAPTER})
echo "$LAUNCHD_PLIST_NBADAPTER"
;;
*)
return 1
;;
esac
}
# Status 0 when the service's launchd job is loaded in the user (gui) domain.
is_launchd_service_loaded() {
local service="$1"
local label
label="$(launchd_label_for_service "$service")" || return 1
run_launchctl print "${LAUNCHD_DOMAIN}/${label}" &>/dev/null
}
# Start a service: launchd bootstrap/kickstart on macOS, systemctl on Linux.
start_service() {
local service="$1"
if [[ "$IS_MACOS" == true ]]; then
local label
local plist
label="$(launchd_label_for_service "$service")" || return 1
plist="$(launchd_plist_for_service "$service")" || return 1
# Regenerate plists lazily when an installation exists but plists are missing.
if [[ ! -f "$plist" && -d "${INSTALL_DIR}/MaiBot" ]]; then
create_launchd_services
fi
if [[ ! -f "$plist" ]]; then
echo -e "${RED}未找到服务配置文件:${plist}${RESET}"
return 1
fi
# kickstart -k restarts an already-loaded job; bootstrap loads an unloaded one.
if is_launchd_service_loaded "$service"; then
run_launchctl kickstart -k "${LAUNCHD_DOMAIN}/${label}"
else
run_launchctl bootstrap "${LAUNCHD_DOMAIN}" "$plist"
fi
else
systemctl start "$service"
fi
}
# Stop a service: launchd bootout on macOS (no-op when not loaded), systemctl on Linux.
stop_service() {
local service="$1"
if [[ "$IS_MACOS" == true ]]; then
local label
label="$(launchd_label_for_service "$service")" || return 1
if is_launchd_service_loaded "$service"; then
run_launchctl bootout "${LAUNCHD_DOMAIN}/${label}"
fi
else
systemctl stop "$service"
fi
}
# Restart a service; launchd has no single restart verb, so stop then start.
restart_service() {
local service="$1"
if [[ "$IS_MACOS" == true ]]; then
stop_service "$service"
start_service "$service"
else
systemctl restart "$service"
fi
}
# Write LaunchAgent plists for MaiCore and the NapCat adapter (KeepAlive jobs
# logging to ${INSTALL_DIR}/logs). Heredoc bodies are written verbatim —
# keep them comment-free. Ownership is handed back to the sudo user so
# launchctl can manage them from the user session.
create_launchd_services() {
mkdir -p "${LAUNCHD_AGENT_DIR}"
mkdir -p "${INSTALL_DIR}/logs"
cat > "${LAUNCHD_PLIST_MAIN}" <<EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>${LAUNCHD_LABEL_MAIN}</string>
<key>ProgramArguments</key>
<array>
<string>${INSTALL_DIR}/venv/bin/python3</string>
<string>bot.py</string>
</array>
<key>WorkingDirectory</key>
<string>${INSTALL_DIR}/MaiBot</string>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<true/>
<key>StandardOutPath</key>
<string>${INSTALL_DIR}/logs/${SERVICE_NAME}.log</string>
<key>StandardErrorPath</key>
<string>${INSTALL_DIR}/logs/${SERVICE_NAME}.error.log</string>
</dict>
</plist>
EOF
cat > "${LAUNCHD_PLIST_NBADAPTER}" <<EOF
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>${LAUNCHD_LABEL_NBADAPTER}</string>
<key>ProgramArguments</key>
<array>
<string>${INSTALL_DIR}/venv/bin/python3</string>
<string>main.py</string>
</array>
<key>WorkingDirectory</key>
<string>${INSTALL_DIR}/MaiBot-Napcat-Adapter</string>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<true/>
<key>StandardOutPath</key>
<string>${INSTALL_DIR}/logs/${SERVICE_NAME_NBADAPTER}.log</string>
<key>StandardErrorPath</key>
<string>${INSTALL_DIR}/logs/${SERVICE_NAME_NBADAPTER}.error.log</string>
</dict>
</plist>
EOF
if [[ "$(id -u)" -eq 0 && -n "${SUDO_USER:-}" && "${SUDO_USER}" != "root" ]]; then
chown "${SUDO_USER}" "${LAUNCHD_PLIST_MAIN}" "${LAUNCHD_PLIST_NBADAPTER}" "${LAUNCHD_AGENT_DIR}" 2>/dev/null || true
fi
}
# Report whether MaiCore is already installed: config file on macOS,
# systemd unit on Linux.
check_installed() {
if [[ "$IS_MACOS" == true ]]; then
[[ -f "$INSTALL_CONF" ]]
else
[[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
fi
}
# Load saved install settings, falling back to defaults for fresh installs.
load_install_info() {
if [[ -f "$INSTALL_CONF" ]]; then
source "$INSTALL_CONF"
else
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
BRANCH="refactor"
fi
}
# Interactive management menu shown once MaiCore is installed. Loops until the
# user exits; the `3>&1 1>&2 2>&3` redirection captures whiptail's selection.
# NOTE(review): the menu geometry declares a list height of 7 but lists 9
# entries; whiptail scrolls the extra rows — confirm this is intended.
show_menu() {
while true; do
choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
"1" "启动MaiCore" \
"2" "停止MaiCore" \
"3" "重启MaiCore" \
"4" "启动NapCat Adapter" \
"5" "停止NapCat Adapter" \
"6" "重启NapCat Adapter" \
"7" "拉取最新MaiCore仓库" \
"8" "切换分支" \
"9" "退出" 3>&1 1>&2 2>&3)
# Cancel/Esc makes whiptail exit non-zero: leave the menu entirely.
[[ $? -ne 0 ]] && exit 0
case "$choice" in
1)
start_service "${SERVICE_NAME}"
whiptail --msgbox "✅MaiCore已启动" 10 60
;;
2)
stop_service "${SERVICE_NAME}"
whiptail --msgbox "🛑MaiCore已停止" 10 60
;;
3)
restart_service "${SERVICE_NAME}"
whiptail --msgbox "🔄MaiCore已重启" 10 60
;;
4)
start_service "${SERVICE_NAME_NBADAPTER}"
whiptail --msgbox "✅NapCat Adapter已启动" 10 60
;;
5)
stop_service "${SERVICE_NAME_NBADAPTER}"
whiptail --msgbox "🛑NapCat Adapter已停止" 10 60
;;
6)
restart_service "${SERVICE_NAME_NBADAPTER}"
whiptail --msgbox "🔄NapCat Adapter已重启" 10 60
;;
7)
update_dependencies
;;
8)
switch_branch
;;
9)
exit 0
;;
*)
whiptail --msgbox "无效选项!" 10 60
;;
esac
done
}
# Stop MaiCore, pull the latest commits of the current branch and reinstall
# Python dependencies inside the project virtualenv.
update_dependencies() {
whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
stop_service "${SERVICE_NAME}"
cd "${INSTALL_DIR}/MaiBot" || {
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
return 1
}
if ! git pull origin "${BRANCH}"; then
whiptail --msgbox "🚫 代码更新失败!" 10 60
return 1
fi
# Install inside the project virtualenv, then leave it again.
source "${INSTALL_DIR}/venv/bin/activate"
if ! pip install -r requirements.txt; then
whiptail --msgbox "🚫 依赖安装失败!" 10 60
deactivate
return 1
fi
deactivate
whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
}
# Interactively switch the MaiBot checkout to another remote branch, then
# reinstall dependencies and re-run the EULA confirmation.
switch_branch() {
new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
[[ -z "$new_branch" ]] && {
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
return 1
}
cd "${INSTALL_DIR}/MaiBot" || {
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
return 1
}
# Verify the branch exists on the remote before touching the working tree.
if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
return 1
fi
if ! git checkout "${new_branch}"; then
whiptail --msgbox "🚫 分支切换失败!" 10 60
return 1
fi
if ! git pull origin "${new_branch}"; then
whiptail --msgbox "🚫 代码拉取失败!" 10 60
return 1
fi
stop_service "${SERVICE_NAME}"
source "${INSTALL_DIR}/venv/bin/activate"
pip install -r requirements.txt
deactivate
BRANCH="${new_branch}"
save_install_info
# The new branch may ship a different EULA/privacy text; re-confirm.
check_eula
whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} " 10 60
}
# Ask the user to (re)confirm the EULA and privacy policy whenever either
# document changed since the last confirmation; exits the script on refusal.
check_eula() {
# Fingerprint the current EULA and privacy policy.
current_md5=$(compute_md5 "${INSTALL_DIR}/MaiBot/EULA.md")
current_md5_privacy=$(compute_md5 "${INSTALL_DIR}/MaiBot/PRIVACY.md")
# Bug fix: the original fell through with empty hashes after showing the
# error dialog (despite a comment saying it should return), which could
# write empty *.confirmed files below. Bail out instead.
if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
return 1
fi
# Read the previously confirmed fingerprints, if any.
if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then
confirmed_md5=$(cat "${INSTALL_DIR}/MaiBot/eula.confirmed")
else
confirmed_md5=""
fi
if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then
confirmed_md5_privacy=$(cat "${INSTALL_DIR}/MaiBot/privacy.confirmed")
else
confirmed_md5_privacy=""
fi
# Re-prompt when either document changed; persist new fingerprints on consent,
# exit the installer entirely on refusal.
if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then
whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议 \n\n " 12 70
if [[ $? -eq 0 ]]; then
echo -n "$current_md5" > "${INSTALL_DIR}/MaiBot/eula.confirmed"
echo -n "$current_md5_privacy" > "${INSTALL_DIR}/MaiBot/privacy.confirmed"
else
exit 1
fi
fi
}
# --- PyPI index speed test: use the Aliyun mirror only when it is faster ---
# Probe total connect+transfer time to $1 with curl; echoes the latency in
# seconds, or "999999" with status 1 when the probe fails or output is not numeric.
measure_url_latency() {
local url="$1"
local latency
latency=$(curl -sS -o /dev/null -w "%{time_total}" --connect-timeout 3 --max-time 8 "$url" 2>/dev/null)
if [[ $? -eq 0 && "$latency" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
echo "$latency"
return 0
else
echo "999999"
return 1
fi
}
# Determine the PyPI index the environment would use by default: explicit
# PIP_INDEX_URL / UV_INDEX_URL env vars first, then pip's configured
# index-url (global, then install scope), finally the official simple index.
resolve_default_pypi_index_url() {
local default_url=""
if [[ -n "${PIP_INDEX_URL:-}" ]]; then
default_url="$PIP_INDEX_URL"
elif [[ -n "${UV_INDEX_URL:-}" ]]; then
default_url="$UV_INDEX_URL"
elif command -v pip &>/dev/null; then
default_url=$(pip config get global.index-url 2>/dev/null | head -n 1)
if [[ -z "$default_url" ]]; then
default_url=$(pip config get install.index-url 2>/dev/null | head -n 1)
fi
fi
if [[ -z "$default_url" ]]; then
default_url="https://pypi.org/simple"
fi
echo "$default_url"
}
# Pick the PyPI index for the install: probe both the default index and the
# Aliyun mirror, then use the mirror only when it is reachable and faster.
# Sets PYPI_INDEX_URL, PYPI_INDEX_NAME and UV_PIP_INDEX_OPTION (an argv
# fragment for pip/uv; empty means "use the tool's default index").
select_pypi_index_url() {
local default_url
local aliyun_url="https://mirrors.aliyun.com/pypi/simple"
local default_latency
local aliyun_latency
local default_status
local aliyun_status
default_url=$(resolve_default_pypi_index_url)
default_latency=$(measure_url_latency "$default_url")
default_status=$?
aliyun_latency=$(measure_url_latency "$aliyun_url")
aliyun_status=$?
# Exactly one probe succeeded: use the reachable index.
if [[ $aliyun_status -eq 0 && $default_status -ne 0 ]]; then
PYPI_INDEX_URL="$aliyun_url"
PYPI_INDEX_NAME="阿里云镜像(默认源测速失败)"
UV_PIP_INDEX_OPTION=(-i "$aliyun_url")
echo -e "${RED}默认源测速失败,已选择${PYPI_INDEX_NAME}${PYPI_INDEX_URL}${RESET}"
return
fi
if [[ $aliyun_status -ne 0 && $default_status -eq 0 ]]; then
PYPI_INDEX_URL="$default_url"
PYPI_INDEX_NAME="默认源(阿里云测速失败)"
UV_PIP_INDEX_OPTION=()
echo -e "${RED}阿里云测速失败,已选择${PYPI_INDEX_NAME}:不显式指定 -i 参数${RESET}"
return
fi
# Neither probe succeeded: fall back to the default without forcing -i.
if [[ $aliyun_status -ne 0 && $default_status -ne 0 ]]; then
PYPI_INDEX_URL="$default_url"
PYPI_INDEX_NAME="默认源(双源测速失败)"
UV_PIP_INDEX_OPTION=()
echo -e "${RED}默认源和阿里云测速均失败,回退到${PYPI_INDEX_NAME}:不显式指定 -i 参数${RESET}"
return
fi
# Both probes succeeded: compare latencies via awk (bash lacks float math).
if awk "BEGIN {exit !(${aliyun_latency} < ${default_latency})}"; then
PYPI_INDEX_URL="$aliyun_url"
PYPI_INDEX_NAME="阿里云镜像"
UV_PIP_INDEX_OPTION=(-i "$aliyun_url")
else
PYPI_INDEX_URL="$default_url"
PYPI_INDEX_NAME="默认源"
UV_PIP_INDEX_OPTION=()
fi
if [[ ${#UV_PIP_INDEX_OPTION[@]} -gt 0 ]]; then
echo -e "${GREEN}已选择${PYPI_INDEX_NAME}${PYPI_INDEX_URL}${RESET}"
else
echo -e "${GREEN}已选择${PYPI_INDEX_NAME}:不显式指定 -i 参数${RESET}"
fi
}
# ----------- 主安装流程 -----------
run_installation() {
# 1/6: 检测是否安装 whiptail
if ! command -v whiptail &>/dev/null; then
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
if command -v apt-get &>/dev/null; then
apt-get update && apt-get install -y whiptail
elif command -v pacman &>/dev/null; then
pacman -Syu --noconfirm whiptail
elif command -v yum &>/dev/null; then
yum install -y whiptail
elif command -v brew &>/dev/null || [[ -x /opt/homebrew/bin/brew ]] || [[ -x /usr/local/bin/brew ]]; then
run_brew install newt
# 确保当前 shell 能找到 Homebrew 安装的 whiptail。
[[ -x /opt/homebrew/bin/whiptail ]] && export PATH="/opt/homebrew/bin:${PATH}"
[[ -x /usr/local/bin/whiptail ]] && export PATH="/usr/local/bin:${PATH}"
else
echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
exit 1
fi
if ! command -v whiptail &>/dev/null; then
echo -e "${RED}[Error] whiptail 安装失败或不可用,请手动安装后重试。${RESET}"
exit 1
fi
fi
whiptail --title " 提示" --msgbox "如果您没有特殊需求请优先使用docker方式部署。" 10 60
# 协议确认
if ! (whiptail --title " [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议" 12 70); then
exit 1
fi
# 欢迎信息
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore将自动进入安装流程安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段代码可能随时更改\n文档未完善有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险请自行了解谨慎使用\n由于持续迭代可能存在一些已知或未知的bug\n由于开发中可能消耗较多token\n\n本脚本可能更新不及时如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
# 系统检查
check_system() {
if [[ "$IS_MACOS" == true ]]; then
ID="macos"
VERSION_ID="$(sw_vers -productVersion 2>/dev/null)"
PRETTY_NAME="macOS ${VERSION_ID}"
return
fi
if [[ "$(id -u)" -ne 0 ]]; then
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
exit 1
fi
if [[ -f /etc/os-release ]]; then
source /etc/os-release
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
return
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
return
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
return
elif [[ "$ID" == "arch" ]]; then
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
return
else
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
exit 1
fi
else
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
exit 1
fi
}
check_system
# 设置包管理器
case "$ID" in
debian|ubuntu)
PKG_MANAGER="apt"
;;
centos)
PKG_MANAGER="yum"
;;
arch)
# 添加arch包管理器
PKG_MANAGER="pacman"
;;
macos)
PKG_MANAGER="brew"
;;
esac
# 检查NapCat
check_napcat() {
if command -v napcat &>/dev/null; then
NAPCAT_INSTALLED=true
else
NAPCAT_INSTALLED=false
fi
}
check_napcat
# 安装必要软件包
# Collect the dependency packages missing on this platform and ask the user
# whether to install them. Side effects: fills the global `missing_packages`
# array and may set IS_INSTALL_DEPENDENCIES=true; exits if the user declines
# both installation and continuing without it.
install_packages() {
    missing_packages=()
    # get_required_packages yields the common + per-distro package list
    # (defined earlier in this script).
    for package in $(get_required_packages "$ID"); do
        case "$PKG_MANAGER" in
            apt)
                dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
                ;;
            yum)
                rpm -q "$package" &>/dev/null || missing_packages+=("$package")
                ;;
            pacman)
                pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
                ;;
            brew)
                # Homebrew formula names don't always match the command they
                # install, so probe the well-known commands first.
                case "$package" in
                    git)
                        command -v git &>/dev/null || missing_packages+=("$package")
                        ;;
                    gnupg)
                        command -v gpg &>/dev/null || missing_packages+=("$package")
                        ;;
                    python)
                        command -v python3 &>/dev/null || missing_packages+=("$package")
                        ;;
                    *)
                        run_brew list --formula "$package" &>/dev/null || missing_packages+=("$package")
                        ;;
                esac
                ;;
        esac
    done
    if [[ ${#missing_packages[@]} -gt 0 ]]; then
        whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装" 10 60
        if [[ $? -eq 0 ]]; then
            IS_INSTALL_DEPENDENCIES=true
        else
            # User declined: warn that things may break, offer to abort.
            whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续" 10 60 || exit 1
        fi
    fi
}
install_packages
# 安装NapCat
# Offer to install NapCat locally; sets IS_INSTALL_NAPCAT=true on consent.
# No-op when NapCat is already installed.
install_napcat() {
    if [[ $NAPCAT_INSTALLED == true ]]; then
        return
    fi
    if whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat是否安装\n如果您想使用远程NapCat请跳过此步。" 10 60; then
        IS_INSTALL_NAPCAT=true
    fi
}
# 仅在 Linux 非 Arch 系统上安装 NapCatmacOS 仅支持远程 NapCat。
if [[ "$ID" == "macos" ]]; then
whiptail --title "⚠️ NapCat 安装提示" --msgbox "当前为 macOS暂不支持自动安装 NapCat。\n如需使用 NapCat请配置远程实例后再连接。 " 10 60
elif [[ "$ID" != "arch" ]]; then
install_napcat
fi
# Python版本检查
# Abort with a dialog when the system Python is older than 3.10.
check_python() {
    # Capture "major.minor" only for the error message below.
    PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
    python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)" || {
        whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
        exit 1
    }
}
# 如果没安装python则不检查python版本
if command -v python3 &>/dev/null; then
check_python
fi
# 选择分支
# Interactively pick the MaiBot git branch to install.
# Sets the global BRANCH; exits on cancel or an empty custom branch name.
choose_branch() {
    # fd juggling (3>&1 1>&2 2>&3): whiptail writes the selection to stderr,
    # swap the streams so command substitution can capture it.
    BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
        "main" "稳定版本(推荐)" ON \
        "dev" "开发版(不知道什么意思就别选)" OFF \
        "classical" "经典版0.6.0以前的版本)" OFF \
        "custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
    RETVAL=$?
    if [ $RETVAL -ne 0 ]; then
        whiptail --msgbox "🚫 操作取消!" 10 60
        exit 1
    fi
    # "custom" opens a free-form input box for an arbitrary branch name.
    if [[ "$BRANCH" == "custom" ]]; then
        BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
        RETVAL=$?
        if [ $RETVAL -ne 0 ]; then
            whiptail --msgbox "🚫 输入取消!" 10 60
            exit 1
        fi
        if [[ -z "$BRANCH" ]]; then
            whiptail --msgbox "🚫 分支名称不能为空!" 10 60
            exit 1
        fi
    fi
}
choose_branch
# 选择安装路径
# Ask for the MaiCore install directory. On cancel/blank input, offer to
# abort; if the user continues, falls back to DEFAULT_INSTALL_DIR.
choose_install_dir() {
    INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
    [[ -z "$INSTALL_DIR" ]] && {
        # Empty result means the dialog was cancelled or left blank.
        whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
        INSTALL_DIR="$DEFAULT_INSTALL_DIR"
    }
}
choose_install_dir
# 确认安装
# Show a final summary of everything about to be changed and ask for
# confirmation. Reads INSTALL_DIR/BRANCH/IS_INSTALL_DEPENDENCIES/
# IS_INSTALL_NAPCAT/missing_packages; exits when the user declines.
confirm_install() {
    local confirm_msg="请确认以下更改:\n\n"
    confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
    confirm_msg+="🔀 分支: $BRANCH\n"
    # Fix: use ${missing_packages[*]} to join the array into one string
    # (consistent with the dependency-check dialog); ${arr[@]} inside an
    # assignment is the wrong idiom even though bash happens to join it.
    [[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[*]}\n"
    [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
    [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
    confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
    whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
}
confirm_install
# 开始安装
echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
case "$PKG_MANAGER" in
apt)
apt update && apt install -y "${missing_packages[@]}"
;;
yum)
yum install -y "${missing_packages[@]}" --nobest
;;
pacman)
pacman -S --noconfirm "${missing_packages[@]}"
;;
brew)
run_brew update && run_brew install "${missing_packages[@]}"
;;
esac
fi
if [[ $IS_INSTALL_NAPCAT == true ]]; then
echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
fi
echo -e "${GREEN}创建安装目录...${RESET}"
mkdir -p "$INSTALL_DIR"
cd "$INSTALL_DIR" || exit 1
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
python3 -m venv venv
source venv/bin/activate
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
echo -e "${RED}克隆MaiCore仓库失败${RESET}"
exit 1
}
echo -e "${GREEN}A_Memorix 已内置到源码,无需初始化子模块。${RESET}"
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}安装Python依赖...${RESET}"
select_pypi_index_url
pip install -r MaiBot/requirements.txt
cd MaiBot
pip install uv
uv pip install "${UV_PIP_INDEX_OPTION[@]}" -r requirements.txt
cd ..
echo -e "${GREEN}安装maim_message依赖...${RESET}"
cd maim_message
uv pip install "${UV_PIP_INDEX_OPTION[@]}" -e .
cd ..
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
cd MaiBot-Napcat-Adapter
uv pip install "${UV_PIP_INDEX_OPTION[@]}" -r requirements.txt
cd ..
echo -e "${GREEN}同意协议...${RESET}"
# 首先计算当前EULA的MD5值
current_md5=$(compute_md5 "MaiBot/EULA.md")
# 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(compute_md5 "MaiBot/PRIVACY.md")
echo -n "$current_md5" > MaiBot/eula.confirmed
echo -n "$current_md5_privacy" > MaiBot/privacy.confirmed
if [[ "$IS_MACOS" == true ]]; then
echo -e "${GREEN}创建 launchctl 服务...${RESET}"
create_launchd_services
stop_service "${SERVICE_NAME}" >/dev/null 2>&1 || true
stop_service "${SERVICE_NAME_NBADAPTER}" >/dev/null 2>&1 || true
else
echo -e "${GREEN}创建系统服务...${RESET}"
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
[Unit]
Description=MaiCore
After=network.target ${SERVICE_NAME_NBADAPTER}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
# [Unit]
# Description=MaiCore WebUI
# After=network.target ${SERVICE_NAME}.service
# [Service]
# Type=simple
# WorkingDirectory=${INSTALL_DIR}/MaiBot
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
# Restart=always
# RestartSec=10s
# [Install]
# WantedBy=multi-user.target
# EOF
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
[Unit]
Description=MaiBot Napcat Adapter
After=network.target mongod.service ${SERVICE_NAME}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
systemctl daemon-reload
fi
# 保存安装信息
save_install_info
if [[ "$IS_MACOS" == true ]]; then
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成\n已创建 launchctl 服务:${LAUNCHD_LABEL_MAIN}${LAUNCHD_LABEL_NBADAPTER}\n\n首次加载launchctl bootstrap ${LAUNCHD_DOMAIN} ${LAUNCHD_PLIST_MAIN}\n重启服务launchctl kickstart -k ${LAUNCHD_DOMAIN}/${LAUNCHD_LABEL_MAIN}\n查看状态launchctl print ${LAUNCHD_DOMAIN}/${LAUNCHD_LABEL_MAIN}" 14 100
else
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成\n已创建系统服务${SERVICE_NAME}${SERVICE_NAME_WEB}${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务\n启动服务systemctl start ${SERVICE_NAME}\n查看状态systemctl status ${SERVICE_NAME}" 14 60
fi
}
# ----------- Main execution flow -----------
# Privilege rules: Linux still requires root (systemd units are written);
# macOS uses per-user launchctl and must NOT run as root.
if [[ "$IS_MACOS" == true && $(id -u) -eq 0 ]]; then
    echo -e "${RED}macOS 请勿使用 root/sudo 运行此脚本,请直接以当前登录用户执行。${RESET}"
    exit 1
fi
if [[ "$IS_MACOS" != true && $(id -u) -ne 0 ]]; then
    echo -e "${RED}请使用root用户运行此脚本${RESET}"
    exit 1
fi
# Already installed: load saved state, re-check the EULA, show the menu.
# Otherwise run the full installation.
if check_installed; then
    load_install_info
    check_eula
    show_menu
else
    run_installation
    # After a fresh install, offer to start the service right away.
    if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务" 10 60; then
        start_service "${SERVICE_NAME}"
        if [[ "$IS_MACOS" == true ]]; then
            whiptail --msgbox "✅ 服务已启动!\n使用 launchctl print ${LAUNCHD_DOMAIN}/${LAUNCHD_LABEL_MAIN} 查看状态" 10 80
        else
            whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
        fi
    fi
fi

View File

@@ -0,0 +1,25 @@
from __future__ import annotations
import asyncio
import sys
from pathlib import Path
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.A_memorix.host_service import a_memorix_host_service
from src.webui.webui_server import get_webui_server
async def main() -> None:
    """Run the WebUI server, keeping the A_Memorix host service alive for its lifetime."""
    webui_server = get_webui_server()
    await a_memorix_host_service.start()
    try:
        await webui_server.start()
    finally:
        # Always stop the host service, even if the server exits with an error.
        await a_memorix_host_service.stop()


if __name__ == "__main__":
    asyncio.run(main())

51
scripts/run_lpmm.sh Normal file
View File

@@ -0,0 +1,51 @@
#!/bin/bash
# ==============================================
# Environment Initialization
# ==============================================
# Step 1: Locate project root directory (this script lives in scripts/).
SCRIPTS_DIR="scripts"
SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd)
PROJECT_ROOT=$(cd "$SCRIPT_DIR/.." && pwd)
# Step 2: Verify scripts directory exists
if [ ! -d "$PROJECT_ROOT/$SCRIPTS_DIR" ]; then
    echo "❌ Error: scripts directory not found in project root" >&2
    echo "Current path: $PROJECT_ROOT" >&2
    exit 1
fi
# Step 3: Set up Python environment
# Prepend the project root so `src.*` imports resolve in the child scripts.
export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH"
cd "$PROJECT_ROOT" || {
    echo "❌ Failed to cd to project root: $PROJECT_ROOT" >&2
    exit 1
}
# Debug info
echo "============================"
echo "Project Root: $PROJECT_ROOT"
echo "Python Path: $PYTHONPATH"
echo "Working Dir: $(pwd)"
echo "============================"
# ==============================================
# Python Script Execution
# ==============================================
# Run one LPMM pipeline stage from the scripts directory; abort the whole
# pipeline if any stage exits non-zero.
run_python_script() {
    local script_name=$1
    echo "🔄 Running $script_name"
    python3 "$SCRIPTS_DIR/$script_name" || {
        echo "❌ $script_name failed" >&2
        exit 1
    }
}

# Execute the pipeline stages in their required order.
run_python_script "raw_data_preprocessor.py"
run_python_script "info_extraction.py"
run_python_script "import_openie.py"

echo "✅ All scripts completed successfully"

View File

@@ -0,0 +1,21 @@
#!/usr/bin/env bash
# Sync the vendored A_memorix subtree from its upstream repository.
# Usage: $0 [add|pull] [remote_url] [branch]
set -euo pipefail

MODE="${1:-pull}"
REMOTE_URL="${2:-https://github.com/A-Dawn/A_memorix.git}"
BRANCH="${3:-MaiBot_branch}"
PREFIX="src/A_memorix"

if [[ "$MODE" == "add" ]]; then
    git subtree add --prefix="$PREFIX" "$REMOTE_URL" "$BRANCH" --squash
elif [[ "$MODE" == "pull" ]]; then
    git subtree pull --prefix="$PREFIX" "$REMOTE_URL" "$BRANCH" --squash
else
    echo "Usage: $0 [add|pull] [remote_url] [branch]" >&2
    exit 2
fi

View File

@@ -0,0 +1,459 @@
import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from dataclasses import dataclass
from typing import Optional, Dict, Any
from datetime import datetime
# 强制使用 utf-8避免控制台编码报错
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
# 确保能导入 src.*
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo
try:
from maim_message import ChatStream, UserInfo, GroupInfo
except Exception:
@dataclass
class ChatStream:
stream_id: str
platform: str
user_info: UserInfo
group_info: GroupInfo
logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval():
    """Dynamically import ``src.memory_system.memory_retrieval`` with importlib.

    Works around a circular-import problem by: (1) reusing the module when it
    is already fully initialized, (2) evicting partially-initialized
    ``memory_system`` modules from ``sys.modules``, and (3) pre-loading the
    modules known to participate in the cycle before the real import.

    Returns:
        Tuple of callables: (init_memory_retrieval_prompt,
        _react_agent_solve_question, _process_single_question).

    Raises:
        ImportError, AttributeError: when the module cannot be imported.
    """
    try:
        # Import prompt_builder first so we can check whether the
        # memory_retrieval prompts were already registered.
        from src.chat.utils.prompt_builder import global_prompt_manager
        # If registered, the module was most likely already initialized
        # through some other import path.
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
        module_name = "src.memory_system.memory_retrieval"
        # Fast path: prompts initialized and module fully loaded — reuse it.
        if prompt_already_init and module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if hasattr(existing_module, "init_memory_retrieval_prompt"):
                return (
                    existing_module.init_memory_retrieval_prompt,
                    existing_module._react_agent_solve_question,
                    existing_module._process_single_question,
                )
        # If present but only partially initialized, evict it so importlib
        # can re-import from scratch.
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, "init_memory_retrieval_prompt"):
                logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
                del sys.modules[module_name]
                # Also evict sibling memory_system submodules that may be
                # stuck in a partially-initialized state.
                keys_to_remove = []
                for key in sys.modules.keys():
                    if key.startswith("src.memory_system.") and key != "src.memory_system":
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    try:
                        del sys.modules[key]
                    except KeyError:
                        pass
        # Pre-load every module that might trigger the circular import so it
        # finishes initializing before memory_retrieval is imported.
        try:
            import src.config.config
            import src.chat.utils.prompt_builder
            # These reply generators may import memory_retrieval at module
            # level; import them now so they are fully initialized.
            # Failures here are non-fatal — just proceed without them.
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # ignore and continue
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # ignore and continue
        except Exception as e:
            logger.warning(f"预加载依赖模块时出现警告: {e}")
        # Now import memory_retrieval itself. If this still raises a circular
        # import, some other module imports memory_retrieval at module level.
        memory_retrieval_module = importlib.import_module(module_name)
        return (
            memory_retrieval_module.init_memory_retrieval_prompt,
            memory_retrieval_module._react_agent_solve_question,
            memory_retrieval_module._process_single_question,
        )
    except (ImportError, AttributeError) as e:
        logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
        raise
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
    """Build a throwaway ChatStream pointing at a synthetic test user and group."""
    return ChatStream(
        stream_id=chat_id,
        platform="test",
        user_info=UserInfo(
            platform="test",
            user_id="test_user",
            user_nickname="测试用户",
        ),
        group_info=GroupInfo(
            platform="test",
            group_id="test_group",
            group_name="测试群组",
        ),
    )
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
    """Aggregate memory-related LLM token usage recorded since *start_time*.

    Args:
        start_time: Unix timestamp marking the start of the window.

    Returns:
        Dict with total prompt/completion/total tokens, cost, request count
        and a per-model breakdown; a zeroed-out dict on any query failure.
    """
    try:
        start_datetime = datetime.fromtimestamp(start_time)
        # Select every memory-related usage record in the window. The three
        # equality checks are redundant with the LIKE filter but kept as an
        # explicit allow-list of known request types.
        records = (
            LLMUsage.select()
            .where(
                (LLMUsage.timestamp >= start_datetime)
                & (
                    (LLMUsage.request_type.like("%memory%"))
                    | (LLMUsage.request_type == "memory.question")
                    | (LLMUsage.request_type == "memory.react")
                    | (LLMUsage.request_type == "memory.react.final")
                )
            )
            .order_by(LLMUsage.timestamp.asc())
        )
        total_prompt_tokens = 0
        total_completion_tokens = 0
        total_tokens = 0
        total_cost = 0.0
        request_count = 0
        model_usage = {}  # per-model aggregation
        for record in records:
            # Token/cost columns may be NULL; treat missing values as zero.
            total_prompt_tokens += record.prompt_tokens or 0
            total_completion_tokens += record.completion_tokens or 0
            total_tokens += record.total_tokens or 0
            total_cost += record.cost or 0.0
            request_count += 1
            # Per-model statistics.
            model_name = record.model_name or "unknown"
            if model_name not in model_usage:
                model_usage[model_name] = {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                    "cost": 0.0,
                    "request_count": 0,
                }
            model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
            model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
            model_usage[model_name]["total_tokens"] += record.total_tokens or 0
            model_usage[model_name]["cost"] += record.cost or 0.0
            model_usage[model_name]["request_count"] += 1
        return {
            "total_prompt_tokens": total_prompt_tokens,
            "total_completion_tokens": total_completion_tokens,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "request_count": request_count,
            "model_usage": model_usage,
        }
    except Exception as e:
        logger.error(f"获取token使用情况失败: {e}")
        return {
            "total_prompt_tokens": 0,
            "total_completion_tokens": 0,
            "total_tokens": 0,
            "total_cost": 0.0,
            "request_count": 0,
            "model_usage": {},
        }
def format_thinking_steps(thinking_steps: list) -> str:
    """Format ReAct thinking steps into a readable multi-line string.

    Args:
        thinking_steps: Iteration records as returned by
            ``_react_agent_solve_question``; each item is a dict with
            ``iteration``/``thought``/``actions``/``observations`` keys.

    Returns:
        Human-readable text; a placeholder string when there are no steps.
    """
    if not thinking_steps:
        return "无思考步骤"

    def _truncate(text: str, limit: int = 200) -> str:
        # Unified truncation: only append an ellipsis when content was
        # actually cut. (Fixes the old behavior where thoughts always got
        # a trailing "..." even when shorter than the limit, inconsistent
        # with how observations were handled.)
        return text if len(text) <= limit else text[:limit] + "..."

    lines = []
    for step in thinking_steps:
        iteration = step.get("iteration", "?")
        thought = step.get("thought", "")
        actions = step.get("actions", [])
        observations = step.get("observations", [])
        lines.append(f"\n--- 迭代 {iteration} ---")
        if thought:
            lines.append(f"思考: {_truncate(str(thought))}")
        if actions:
            lines.append("行动:")
            for action in actions:
                action_type = action.get("action_type", "unknown")
                action_params = action.get("action_params", {})
                lines.append(f"  - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
        if observations:
            lines.append("观察:")
            for obs in observations:
                lines.append(f"  - {_truncate(str(obs))}")
    return "\n".join(lines)
async def test_memory_retrieval(
    question: str,
    chat_id: str = "test_memory_retrieval",
    context: str = "",
    max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
    """Run one memory-retrieval query and report timing, iterations and tokens.

    Args:
        question: Question to retrieve memories for.
        chat_id: Chat id the retrieval is scoped to.
        context: Extra context string (accepted but not forwarded to the
            retrieval call below — kept for CLI compatibility).
        max_iterations: ReAct iteration cap; None falls back to global config.

    Returns:
        Dict with the answer, timing, thinking steps and token-usage stats.
    """
    print("\n" + "=" * 80)
    print("[测试] 记忆检索测试")
    print(f"[问题] {question}")
    print("=" * 80)
    # Start timing before module init so setup cost is included.
    start_time = time.time()
    # Lazy import + prompt initialization (loads global_config as a side
    # effect). Must happen inside the function to avoid circular imports at
    # module level.
    try:
        init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
        # Initialize the prompt only once; skip when already registered.
        from src.chat.utils.prompt_builder import global_prompt_manager
        if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
            init_memory_retrieval_prompt()
        else:
            logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
    except Exception as e:
        logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
        raise
    # global_config is guaranteed to be loaded after the import above.
    from src.config.config import global_config
    # Call _react_agent_solve_question directly to capture per-iteration detail.
    if max_iterations is None:
        max_iterations = global_config.memory.max_agent_iterations
    timeout = global_config.memory.agent_timeout_seconds
    print("\n[配置]")
    print(f" 最大迭代次数: {max_iterations}")
    print(f" 超时时间: {timeout}")
    print(f" 聊天ID: {chat_id}")
    # Run the retrieval.
    print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
        question=question,
        chat_id=chat_id,
        max_iterations=max_iterations,
        timeout=timeout,
        initial_info="",
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    # Aggregate token usage recorded during the run.
    token_usage = get_token_usage_since(start_time)
    # Build the result payload returned to the caller (and optionally saved).
    result = {
        "question": question,
        "found_answer": found_answer,
        "answer": answer,
        "is_timeout": is_timeout,
        "elapsed_time": elapsed_time,
        "thinking_steps": thinking_steps,
        "iteration_count": len(thinking_steps),
        "token_usage": token_usage,
    }
    # Human-readable report.
    print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    print("\n[结果]")
    print(f" 是否找到答案: {'' if found_answer else ''}")
    if found_answer and answer:
        print(f" 答案: {answer}")
    else:
        print(" 答案: (未找到答案)")
    print(f" 是否超时: {'' if is_timeout else ''}")
    print(f" 迭代次数: {len(thinking_steps)}")
    print(f" 总耗时: {elapsed_time:.2f}")
    print("\n[Token使用情况]")
    print(f" 总请求数: {token_usage['request_count']}")
    print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
    print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
    print(f" 总Tokens: {token_usage['total_tokens']:,}")
    print(f" 总成本: ${token_usage['total_cost']:.6f}")
    if token_usage["model_usage"]:
        print("\n[按模型统计]")
        for model_name, usage in token_usage["model_usage"].items():
            print(f" {model_name}:")
            print(f" 请求数: {usage['request_count']}")
            print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
            print(f" Completion Tokens: {usage['completion_tokens']:,}")
            print(f" 总Tokens: {usage['total_tokens']:,}")
            print(f" 成本: ${usage['cost']:.6f}")
    print("\n[迭代详情]")
    print(format_thinking_steps(thinking_steps))
    print("\n" + "=" * 80)
    return result
def main() -> None:
    """CLI entry point: prompt for a question, run the retrieval test, report.

    Reads the question and iteration cap interactively from stdin, connects
    the database, runs :func:`test_memory_retrieval`, and optionally saves
    the result to a JSON file given via ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="测试记忆检索功能。可以输入一个问题脚本会使用记忆检索的逻辑进行检索并记录迭代信息、时间和token总消耗。"
    )
    parser.add_argument(
        "--chat-id",
        default="test_memory_retrieval",
        help="测试用的聊天ID默认: test_memory_retrieval",
    )
    parser.add_argument(
        "--context",
        default="",
        help="上下文信息(可选)",
    )
    parser.add_argument(
        "--output",
        "-o",
        help="将结果保存到JSON文件可选",
    )
    args = parser.parse_args()
    # Low-verbosity logging to keep the interactive output readable.
    initialize_logging(verbose=False)
    # Interactively read the question.
    print("\n" + "=" * 80)
    print("记忆检索测试工具")
    print("=" * 80)
    question = input("\n请输入要查询的问题: ").strip()
    if not question:
        print("错误: 问题不能为空")
        return
    # Interactively read the iteration cap; empty/invalid input falls back to
    # the configured default (None).
    max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
    max_iterations = None
    if max_iterations_input:
        try:
            max_iterations = int(max_iterations_input)
            if max_iterations <= 0:
                print("警告: 迭代次数必须大于0将使用配置默认值")
                max_iterations = None
        except ValueError:
            print("警告: 无效的迭代次数,将使用配置默认值")
            max_iterations = None
    # Connect the database (needed for the token-usage query).
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        print(f"错误: 数据库连接失败: {e}")
        return
    # Run the test.
    try:
        result = asyncio.run(
            test_memory_retrieval(
                question=question,
                chat_id=args.chat_id,
                context=args.context,
                max_iterations=max_iterations,
            )
        )
        # Persist the result when an output file was requested.
        if args.output:
            # Copy so serialization tweaks don't mutate the returned dict.
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_result, f, ensure_ascii=False, indent=2)
            print(f"\n[结果已保存] {args.output}")
    except KeyboardInterrupt:
        print("\n\n[中断] 用户中断测试")
    except Exception as e:
        logger.error(f"测试失败: {e}", exc_info=True)
        print(f"\n[错误] 测试失败: {e}")
    finally:
        # Best-effort close; ignore errors on an already-closed connection.
        try:
            db.close()
        except Exception:
            pass

View File

@@ -0,0 +1,845 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
from src.config.config import config_manager # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402
from src.services.llm_service import generate # noqa: E402
from src.services.service_task_resolver import get_available_models # noqa: E402
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ToolCallCase:
"""Tool call 参数测试用例。"""
name: str
description: str
tool_definition: Dict[str, Any]
expected_arguments: Dict[str, Any]
@property
def tool_name(self) -> str:
"""返回工具名称。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
return str(function_definition.get("name", "") or "")
return str(self.tool_definition.get("name", "") or "")
@property
def parameters_schema(self) -> Dict[str, Any]:
"""返回参数 Schema。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
parameters = function_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
parameters = self.tool_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
def build_messages(self) -> List[Dict[str, Any]]:
"""构造测试消息。"""
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
system_prompt = (
"你正在执行严格的工具调用参数兼容性测试。"
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
)
user_prompt = (
f"请立刻调用工具 `{self.tool_name}`。\n"
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
"不要输出任何解释文本,只返回工具调用。\n"
f"{expected_json}"
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
@dataclass(slots=True)
class ProbeTarget:
    """One model to probe, plus provider metadata used when reporting."""

    # Task whose configuration this model was resolved from.
    task_name: str
    # Model name as declared in the model config.
    model_name: str
    # Provider fields copied from the APIProvider entry.
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
@dataclass(slots=True)
class ProbeResult:
    """Outcome of one tool-call probe attempt against one model."""

    task_name: str
    # Model we asked for vs. the model name the service actually reported.
    target_model_name: str
    actual_model_name: str
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
    # Which ToolCallCase ran, and which retry attempt this was.
    case_name: str
    attempt: int
    success: bool
    elapsed_seconds: float
    # Validation errors / non-fatal warnings collected for this attempt.
    errors: List[str]
    warnings: List[str]
    # Raw model output captured for diagnosis.
    response_text: str
    reasoning_text: str
    tool_calls: List[Dict[str, Any]]
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切到 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool 定义。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_default_cases() -> List[ToolCallCase]:
    """Build the default test cases: one scalar-only case and one nested case.

    Both schemas use ``required`` on every property and
    ``additionalProperties: False`` so the validator can detect missing or
    extra fields as hard errors.
    """
    # Case 1: scalar argument types (string/integer/boolean/enum/number).
    simple_expected_arguments = {
        "request_id": "probe-simple-001",
        "count": 7,
        "enabled": True,
        "mode": "strict",
        "ratio": 2.5,
    }
    simple_parameters = {
        "type": "object",
        "properties": {
            "request_id": {"type": "string", "description": "请求 ID"},
            "count": {"type": "integer", "description": "数量"},
            "enabled": {"type": "boolean", "description": "是否启用"},
            "mode": {
                "type": "string",
                "description": "模式",
                "enum": ["strict", "loose"],
            },
            "ratio": {"type": "number", "description": "比例"},
        },
        "required": ["request_id", "count", "enabled", "mode", "ratio"],
        "additionalProperties": False,
    }
    # Case 2: nested object, string array, and array-of-objects arguments.
    nested_expected_arguments = {
        "request_id": "probe-nested-001",
        "notify": False,
        "profile": {
            "channel": "stable",
            "priority": 2,
        },
        "tags": ["alpha", "beta", "gamma"],
        "items": [
            {"count": 2, "name": "apple"},
            {"count": 5, "name": "banana"},
        ],
    }
    nested_parameters = {
        "type": "object",
        "properties": {
            "request_id": {"type": "string", "description": "请求 ID"},
            "notify": {"type": "boolean", "description": "是否通知"},
            "profile": {
                "type": "object",
                "description": "配置对象",
                "properties": {
                    "channel": {"type": "string", "description": "渠道"},
                    "priority": {"type": "integer", "description": "优先级"},
                },
                "required": ["channel", "priority"],
                "additionalProperties": False,
            },
            "tags": {
                "type": "array",
                "description": "标签列表",
                "items": {"type": "string"},
            },
            "items": {
                "type": "array",
                "description": "条目列表",
                "items": {
                    "type": "object",
                    "properties": {
                        "count": {"type": "integer", "description": "数量"},
                        "name": {"type": "string", "description": "名称"},
                    },
                    "required": ["count", "name"],
                    "additionalProperties": False,
                },
            },
        },
        "required": ["request_id", "notify", "profile", "tags", "items"],
        "additionalProperties": False,
    }
    return [
        ToolCallCase(
            name="simple",
            description="标量参数类型校验",
            tool_definition=_build_function_tool(
                name="record_simple_probe",
                description="记录简单参数探测结果",
                parameters=simple_parameters,
            ),
            expected_arguments=simple_expected_arguments,
        ),
        ToolCallCase(
            name="nested",
            description="嵌套对象与数组参数校验",
            tool_definition=_build_function_tool(
                name="record_nested_probe",
                description="记录嵌套参数探测结果",
                parameters=nested_parameters,
            ),
            expected_arguments=nested_expected_arguments,
        ),
    ]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
    """Map each configured model name to its ModelInfo entry."""
    model_map: Dict[str, ModelInfo] = {}
    for model in config_manager.get_model_config().models:
        model_map[model.name] = model
    return model_map
def _build_provider_map() -> Dict[str, APIProvider]:
    """Map each configured provider name to its APIProvider entry."""
    provider_map: Dict[str, APIProvider] = {}
    for provider in config_manager.get_model_config().api_providers:
        provider_map[provider.name] = provider
    return provider_map
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
    """Resolve the set of models to probe from CLI task/model filters.

    Args:
        task_filters: Explicit task names; empty means "all tasks except
            the DEFAULT_SKIP_TASKS set".
        model_filters: Explicit model names; empty means "every model listed
            by the selected tasks".
        fallback_task: Task to attribute a model to when it belongs to no
            selected task.

    Returns:
        De-duplicated ProbeTarget list, one per model.

    Raises:
        ValueError: unknown task/model names, or nothing to test.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    provider_map = _build_provider_map()
    if not available_tasks:
        raise ValueError("未找到任何可用的模型任务配置")
    # Validate explicit task filters, or default to all non-skipped tasks.
    if task_filters:
        selected_task_names = []
        for task_name in task_filters:
            if task_name not in available_tasks:
                raise ValueError(f"未找到任务 `{task_name}`")
            selected_task_names.append(task_name)
    else:
        selected_task_names = [
            task_name
            for task_name in available_tasks
            if task_name not in DEFAULT_SKIP_TASKS
        ]
    if not selected_task_names:
        raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")
    default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
    resolved_targets: List[ProbeTarget] = []
    seen_models: set[str] = set()
    # Candidate models: explicit filter, or the union of the selected tasks'
    # model lists (first occurrence wins, order preserved).
    if model_filters:
        model_names = list(model_filters)
    else:
        model_names = []
        for task_name in selected_task_names:
            task_config = available_tasks[task_name]
            for model_name in task_config.model_list:
                if model_name not in model_names:
                    model_names.append(model_name)
    for model_name in model_names:
        if model_name in seen_models:
            continue
        if model_name not in model_map:
            raise ValueError(f"未找到模型 `{model_name}`")
        # Attribute the model to the first selected task that lists it,
        # falling back to the default task otherwise.
        target_task_name = ""
        for task_name in selected_task_names:
            if model_name in available_tasks[task_name].model_list:
                target_task_name = task_name
                break
        if not target_task_name:
            target_task_name = default_task_name
        model_info = model_map[model_name]
        # NOTE(review): assumes every model's api_provider exists in the
        # provider map — a missing provider would raise KeyError here.
        provider_info = provider_map[model_info.api_provider]
        resolved_targets.append(
            ProbeTarget(
                task_name=target_task_name,
                model_name=model_name,
                provider_name=provider_info.name,
                client_type=provider_info.client_type,
                tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
            )
        )
        seen_models.add(model_name)
    return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
    """Temporarily lock a task's model list to a single model.

    Mutates the live TaskConfig in place and restores the original
    model_list/selection_strategy on exit, even when the body raises.

    Raises:
        ValueError: when *task_name* has no TaskConfig.
    """
    model_task_config = config_manager.get_model_config().model_task_config
    task_config = getattr(model_task_config, task_name, None)
    if not isinstance(task_config, TaskConfig):
        raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
    # Copy the list so the restore below is unaffected by in-place edits.
    original_model_list = list(task_config.model_list)
    original_selection_strategy = task_config.selection_strategy
    task_config.model_list = [model_name]
    # Presumably "balance" over a one-element list always selects that model —
    # verify against service_task_resolver if strategies change.
    task_config.selection_strategy = "balance"
    try:
        yield
    finally:
        task_config.model_list = original_model_list
        task_config.selection_strategy = original_selection_strategy
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
    """Serialize ToolCall objects into OpenAI-style tool-call dicts."""
    serialized: List[Dict[str, Any]] = []
    for tool_call in tool_calls or []:
        serialized.append(
            {
                "id": tool_call.call_id,
                "function": {
                    "name": tool_call.func_name,
                    "arguments": dict(tool_call.args or {}),
                },
            }
        )
    return serialized
def _is_integer_value(value: Any) -> bool:
"""判断是否为整数类型且排除布尔值。"""
return isinstance(value, int) and not isinstance(value, bool)
def _is_number_value(value: Any) -> bool:
"""判断是否为数值类型且排除布尔值。"""
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
def _schema_type(schema: Dict[str, Any]) -> str:
"""解析 Schema 的类型。"""
schema_type = str(schema.get("type", "") or "").strip()
if schema_type:
return schema_type
if "properties" in schema or "required" in schema:
return "object"
return ""
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
"""按简化 JSON Schema 校验工具参数。"""
errors: List[str] = []
schema_type = _schema_type(schema)
if "enum" in schema and actual_value not in schema["enum"]:
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
if schema_type == "string":
if not isinstance(actual_value, str):
errors.append(f"{path} 类型错误,期望 string实际为 {type(actual_value).__name__}")
return errors
if schema_type == "integer":
if not _is_integer_value(actual_value):
errors.append(f"{path} 类型错误,期望 integer实际为 {type(actual_value).__name__}")
return errors
if schema_type == "number":
if not _is_number_value(actual_value):
errors.append(f"{path} 类型错误,期望 number实际为 {type(actual_value).__name__}")
return errors
if schema_type == "boolean":
if not isinstance(actual_value, bool):
errors.append(f"{path} 类型错误,期望 boolean实际为 {type(actual_value).__name__}")
return errors
if schema_type == "array":
if not isinstance(actual_value, list):
errors.append(f"{path} 类型错误,期望 array实际为 {type(actual_value).__name__}")
return errors
item_schema = schema.get("items")
if isinstance(item_schema, dict):
for index, item in enumerate(actual_value):
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
return errors
if schema_type == "object":
if not isinstance(actual_value, dict):
errors.append(f"{path} 类型错误,期望 object实际为 {type(actual_value).__name__}")
return errors
properties = schema.get("properties", {})
required_fields = [str(item) for item in schema.get("required", [])]
for required_field in required_fields:
if required_field not in actual_value:
errors.append(f"{path}.{required_field} 缺少必填字段")
for field_name, field_value in actual_value.items():
field_path = f"{path}.{field_name}"
field_schema = properties.get(field_name)
if isinstance(field_schema, dict):
errors.extend(_validate_schema(field_schema, field_value, field_path))
continue
additional_properties = schema.get("additionalProperties", True)
if additional_properties is False:
errors.append(f"{field_path} 是未定义字段")
elif isinstance(additional_properties, dict):
errors.extend(_validate_schema(additional_properties, field_value, field_path))
return errors
return errors
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
"""递归比较实际值与期望值是否完全一致。"""
errors: List[str] = []
if isinstance(expected_value, dict):
if not isinstance(actual_value, dict):
return [f"{path} 值不一致,期望 object实际为 {type(actual_value).__name__}"]
expected_keys = set(expected_value.keys())
actual_keys = set(actual_value.keys())
for missing_key in sorted(expected_keys - actual_keys):
errors.append(f"{path}.{missing_key} 缺少期望字段")
for extra_key in sorted(actual_keys - expected_keys):
errors.append(f"{path}.{extra_key} 出现了额外字段")
for shared_key in sorted(expected_keys & actual_keys):
errors.extend(
_compare_expected_values(
expected_value[shared_key],
actual_value[shared_key],
f"{path}.{shared_key}",
)
)
return errors
if isinstance(expected_value, list):
if not isinstance(actual_value, list):
return [f"{path} 值不一致,期望 array实际为 {type(actual_value).__name__}"]
if len(expected_value) != len(actual_value):
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
for index, (expected_item, actual_item) in enumerate(
zip(expected_value, actual_value, strict=False)
):
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
return errors
if isinstance(expected_value, bool):
if not isinstance(actual_value, bool) or actual_value is not expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if _is_integer_value(expected_value):
if not _is_integer_value(actual_value) or actual_value != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if isinstance(expected_value, float):
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if expected_value != actual_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
    """Prefer the call matching the expected tool name; fall back to the first."""
    matched = next(
        (tool_call for tool_call in tool_calls if tool_call.func_name == expected_tool_name),
        None,
    )
    return matched if matched is not None else tool_calls[0]
def _validate_service_result(
    service_result: LLMServiceResult,
    target: ProbeTarget,
    case: ToolCallCase,
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Validate the service-layer result of one tool-call parameter probe.

    Returns ``(errors, warnings, serialized_tool_calls)``; the probe counts
    as failed when *errors* is non-empty, warnings are informational only.
    """
    errors: List[str] = []
    warnings: List[str] = []
    completion = service_result.completion
    serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
    if not service_result.success:
        # Transport/provider failure: surface whatever message is available.
        errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
        return errors, warnings, serialized_tool_calls
    # The request was pinned to target.model_name; a mismatch means routing broke.
    if completion.model_name and completion.model_name != target.model_name:
        errors.append(
            f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
        )
    tool_calls = completion.tool_calls or []
    if not tool_calls:
        errors.append("模型未返回 tool_calls")
        if completion.response.strip():
            warnings.append("模型返回了自然语言文本而不是工具调用")
        return errors, warnings, serialized_tool_calls
    if len(tool_calls) != 1:
        errors.append(f"返回了 {len(tool_calls)} 个 tool_calls预期为 1 个")
    # Validate the call matching the expected tool (or the first as fallback).
    selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
    if selected_tool_call.func_name != case.tool_name:
        errors.append(
            f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
        )
    actual_arguments = selected_tool_call.args
    if not isinstance(actual_arguments, dict):
        errors.append("工具参数未被解析为对象")
        return errors, warnings, serialized_tool_calls
    # Structural schema check first, then exact expected-value comparison.
    errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
    errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))
    if completion.response.strip():
        warnings.append("模型同时返回了自然语言文本")
    return errors, warnings, serialized_tool_calls
async def _run_single_probe(
    target: ProbeTarget,
    case: ToolCallCase,
    attempt: int,
    max_tokens: int,
    temperature: float,
) -> ProbeResult:
    """Run a single tool-call parameter probe against one target model.

    Builds the request from the case, pins the task to the target model for
    the duration of the call, validates the result, and records elapsed time.
    """
    request = LLMServiceRequest(
        task_name=target.task_name,
        request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
        prompt=case.build_messages(),
        tool_options=[case.tool_definition],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # Timing covers only the pinned generate() call, not validation.
    started_at = time.perf_counter()
    with _pin_task_to_model(target.task_name, target.model_name):
        service_result = await generate(request)
    elapsed_seconds = time.perf_counter() - started_at
    errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
    completion = service_result.completion
    return ProbeResult(
        task_name=target.task_name,
        target_model_name=target.model_name,
        actual_model_name=completion.model_name,
        provider_name=target.provider_name,
        client_type=target.client_type,
        tool_argument_parse_mode=target.tool_argument_parse_mode,
        case_name=case.name,
        attempt=attempt,
        # A probe passes only when validation produced zero errors.
        success=not errors,
        elapsed_seconds=elapsed_seconds,
        errors=errors,
        warnings=warnings,
        response_text=completion.response,
        reasoning_text=completion.reasoning,
        tool_calls=serialized_tool_calls,
    )
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
    """Print the resolved probe targets, one numbered line each."""
    print("待测试目标:")
    for index, target in enumerate(targets, start=1):
        line = (
            f"{index}. model={target.model_name} | task={target.task_name} | "
            f"provider={target.provider_name} | client={target.client_type} | "
            f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
        )
        print(line)
def _print_available_targets() -> None:
    """Print the currently available tasks and configured models."""
    available_tasks = get_available_models()
    model_map = _build_model_map()
    print("当前可用任务:")
    for task_name, task_config in available_tasks.items():
        print(f"- {task_name}: {list(task_config.model_list)}")
    # Collect every model referenced by at least one task.
    referenced_models: set[str] = set()
    for task_config in available_tasks.values():
        referenced_models.update(task_config.model_list)
    print("\n当前配置中的模型:")
    for model_name, model_info in model_map.items():
        referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
        print(
            f"- {model_name}: provider={model_info.api_provider}, "
            f"identifier={model_info.model_identifier}, {referenced_mark}"
        )
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
    """Filter the default probe cases by name; no filter returns them all."""
    case_index = {case.name: case for case in _build_default_cases()}
    if not case_filters:
        return list(case_index.values())
    missing = [name for name in case_filters if name not in case_index]
    if missing:
        raise ValueError(f"未知测试用例 `{missing[0]}`,可选值: {', '.join(sorted(case_index))}")
    return [case_index[name] for name in case_filters]
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
    """Print one probe result with its errors, warnings and tool calls."""
    status_text = "PASS" if result.success else "FAIL"
    header = (
        f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
        f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
    )
    print(header)
    for error in result.errors:
        print(f"  ERROR: {error}")
    for warning in result.warnings:
        print(f"  WARN: {warning}")
    if result.tool_calls:
        print(f"  tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
    if show_response and result.response_text.strip():
        print(f"  response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
    """Aggregate probe results into a pass/fail summary dict."""
    failed_items: List[Dict[str, Any]] = []
    for result in results:
        if result.success:
            continue
        failed_items.append(
            {
                "model_name": result.target_model_name,
                "case_name": result.case_name,
                "attempt": result.attempt,
                "errors": list(result.errors),
            }
        )
    total_count = len(results)
    failed_count = len(failed_items)
    return {
        "total": total_count,
        "passed": total_count - failed_count,
        "failed": failed_count,
        "failed_items": failed_items,
    }
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
    """Write the full result payload (summary + per-probe records) as JSON."""
    report_path = Path(json_out).expanduser().resolve()
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "summary": _build_summary(results),
        "results": [asdict(result) for result in results],
    }
    report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n结果已写入: {report_path}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
    """Run every selected case against every resolved target, sequentially."""
    # Cases are resolved before targets so bad --case values fail fast.
    selected_cases = _select_cases(_parse_multi_value_args(args.case))
    targets = _resolve_targets(
        _parse_multi_value_args(args.task),
        _parse_multi_value_args(args.model),
        args.fallback_task,
    )
    _print_targets(targets)
    print("")
    results: List[ProbeResult] = []
    for target in targets:
        for attempt in range(1, args.repeat + 1):
            for case in selected_cases:
                print(
                    f"开始测试: model={target.model_name}, task={target.task_name}, "
                    f"case={case.name}, attempt={attempt}"
                )
                probe_result = await _run_single_probe(
                    target=target,
                    case=case,
                    attempt=attempt,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _print_single_result(probe_result, args.show_response)
                print("")
                results.append(probe_result)
    return results
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
"默认会测试所有非 voice / embedding 任务中引用到的模型。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
)
parser.add_argument(
"--case",
action="append",
help="指定测试用例名,可选 simple、nested不传则运行全部默认用例",
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个用例重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0 以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的自然语言文本内容",
)
return parser
def main() -> int:
    """CLI entry point; returns the process exit code (0 = all passed)."""
    _ensure_utf8_console()
    parser = _build_parser()
    args = parser.parse_args()
    if args.repeat < 1:
        parser.error("--repeat 必须大于等于 1")
    if args.max_tokens < 1:
        parser.error("--max-tokens 必须大于等于 1")
    if args.list_targets:
        # Dry-run mode: only show the task/model mapping, no network calls.
        _print_available_targets()
        return 0
    results = asyncio.run(_run_probes(args))
    summary = _build_summary(results)
    print("测试摘要:")
    print(f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}")
    if summary["failed_items"]:
        print("失败明细:")
        for item in summary["failed_items"]:
            print(
                f"- model={item['model_name']} | case={item['case_name']} | "
                f"attempt={item['attempt']} | errors={item['errors']}"
            )
    if args.json_out:
        _write_json_report(args.json_out, results)
    return 0 if summary["failed"] == 0 else 1
# Standard script entry guard; the exit code mirrors main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,777 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
# Make the repository root importable when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult  # noqa: E402
from src.config.config import config_manager  # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig  # noqa: E402
from src.services.llm_service import generate  # noqa: E402
from src.services.service_task_resolver import get_available_models  # noqa: E402
# Tasks that never exercise tool calling and are therefore skipped by default.
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ProbeTarget:
    """A single (task, model, provider) combination to probe."""

    task_name: str  # task the request is routed through
    model_name: str  # model under test
    provider_name: str  # API provider backing the model
    client_type: str  # provider client implementation type
    tool_argument_parse_mode: str  # how the provider parses tool arguments
@dataclass(slots=True)
class ToolCallScenario:
    """Definition of one tool-call API scenario (chat history + tool setup)."""

    name: str
    description: str
    prompt: List[Dict[str, Any]]  # full chat history sent to the model
    tool_options: List[Dict[str, Any]] | None = None  # tools offered this turn; None = none
    expect_tool_calls: bool | None = None  # expected tool-call behaviour; None = no expectation
@dataclass(slots=True)
class ProbeResult:
    """Outcome of a single API probe attempt."""

    task_name: str
    target_model_name: str  # model the probe was pinned to
    actual_model_name: str  # model reported by the completion
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
    case_name: str
    attempt: int
    success: bool  # True when no validation errors were recorded
    elapsed_seconds: float
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    response_text: str = ""
    reasoning_text: str = ""
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切换为 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_probe_tools() -> List[Dict[str, Any]]:
    """Build the two generic probe tools: a weather lookup and a doc search.

    Both schemas mark all fields required and forbid additional properties so
    that argument fidelity can be checked strictly.
    """
    weather_tool = _build_function_tool(
        name="lookup_weather",
        description="查询指定城市天气。",
        parameters={
            "type": "object",
            "properties": {
                "city": {"type": "string", "description": "城市名"},
                "unit": {
                    "type": "string",
                    "description": "温度单位",
                    "enum": ["celsius", "fahrenheit"],
                },
                "include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
            },
            "required": ["city", "unit", "include_forecast"],
            "additionalProperties": False,
        },
    )
    # search_docs exercises nested-object arguments via "filters".
    search_tool = _build_function_tool(
        name="search_docs",
        description="搜索内部知识库。",
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "搜索关键词"},
                "top_k": {"type": "integer", "description": "返回条数"},
                "filters": {
                    "type": "object",
                    "description": "过滤条件",
                    "properties": {
                        "scope": {"type": "string", "description": "搜索范围"},
                        "tag": {"type": "string", "description": "标签"},
                    },
                    "required": ["scope", "tag"],
                    "additionalProperties": False,
                },
            },
            "required": ["query", "top_k", "filters"],
            "additionalProperties": False,
        },
    )
    return [weather_tool, search_tool]
def _build_default_scenarios() -> List[ToolCallScenario]:
    """Build the default API-compatibility scenarios.

    Covers a fresh tool call plus several history-replay shapes: assistant
    messages carrying tool_calls (with and without text), tool result
    messages, multiple parallel calls, and history combined with new tools.
    """
    tools = _build_probe_tools()
    weather_tool = tools[0]
    search_tool = tools[1]
    # Canned historical tool calls reused across the replay scenarios.
    history_tool_call = {
        "id": "call_hist_weather_001",
        "type": "function",
        "function": {
            "name": "lookup_weather",
            "arguments": {
                "city": "上海",
                "unit": "celsius",
                "include_forecast": True,
            },
        },
    }
    nested_history_tool_call = {
        "id": "call_hist_search_001",
        "type": "function",
        "function": {
            "name": "search_docs",
            "arguments": {
                "query": "工具调用兼容性",
                "top_k": 3,
                "filters": {
                    "scope": "internal",
                    "tag": "tool-call",
                },
            },
        },
    }
    return [
        ToolCallScenario(
            name="fresh_tool_call",
            description="首轮普通工具调用请求。",
            prompt=[
                {
                    "role": "system",
                    "content": (
                        "你正在执行工具调用连通性测试。"
                        "如果能调用工具,就优先调用最合适的工具。"
                    ),
                },
                {
                    "role": "user",
                    "content": "请查询上海天气,并使用工具给出参数。",
                },
            ],
            tool_options=[weather_tool],
            expect_tool_calls=True,
        ),
        ToolCallScenario(
            name="history_assistant_tool_calls_with_content",
            description="历史 assistant 同时包含文本和 tool_calls当前轮不再提供 tools。",
            prompt=[
                {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
                {"role": "user", "content": "先帮我查一下上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查询天气,再继续回答。",
                    "tool_calls": [history_tool_call],
                },
                {"role": "user", "content": "继续说,别丢掉上下文。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        ToolCallScenario(
            name="history_assistant_tool_calls_without_content",
            description="历史 assistant 只有 tool_calls没有文本内容。",
            prompt=[
                {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
                {"role": "user", "content": "先帮我查一下上海天气。"},
                {
                    "role": "assistant",
                    "tool_calls": [history_tool_call],
                },
                {"role": "user", "content": "继续。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        ToolCallScenario(
            name="history_tool_result_followup",
            description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
            prompt=[
                {"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
                {"role": "user", "content": "先查上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查询天气。",
                    "tool_calls": [history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "多云",
                            "temperature_c": 24,
                            "forecast": ["", "小雨"],
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "结合上面的查询结果继续总结。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        ToolCallScenario(
            name="history_multiple_tool_calls_and_results",
            description="历史中包含多个 tool_calls 与多条 tool 结果。",
            prompt=[
                {"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
                {"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
                {
                    "role": "assistant",
                    "content": "我分两步查询。",
                    "tool_calls": [history_tool_call, nested_history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "",
                            "temperature_c": 22,
                        },
                        ensure_ascii=False,
                    ),
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_search_001",
                    "content": json.dumps(
                        {
                            "items": [
                                "OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
                                "部分 provider 在历史消息回放时兼容性较弱",
                            ],
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "继续整合上面的两个结果。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        ToolCallScenario(
            name="history_tool_calls_with_current_tools",
            description="保留历史 tool_calls同时当前轮仍然提供 tools。",
            prompt=[
                {"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
                {"role": "user", "content": "先查上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查天气。",
                    "tool_calls": [history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "",
                            "temperature_c": 26,
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
            ],
            tool_options=[search_tool],
            expect_tool_calls=True,
        ),
    ]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
    """Map model name -> model config entry."""
    mapping: Dict[str, ModelInfo] = {}
    for model in config_manager.get_model_config().models:
        mapping[model.name] = model
    return mapping
def _build_provider_map() -> Dict[str, APIProvider]:
    """Map provider name -> provider config entry."""
    mapping: Dict[str, APIProvider] = {}
    for provider in config_manager.get_model_config().api_providers:
        mapping[provider.name] = provider
    return mapping
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
    """Resolve CLI filters into a deduplicated list of probe targets.

    Args:
        task_filters: explicit task names; empty means every task except
            those in DEFAULT_SKIP_TASKS.
        model_filters: explicit model names; empty means every model
            referenced by the selected tasks (first-seen order).
        fallback_task: task used to host a model not referenced by any
            selected task.

    Raises:
        ValueError: on unknown task/model/provider names or when no task
            is usable.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    provider_map = _build_provider_map()
    if not available_tasks:
        raise ValueError("未找到任何可用的模型任务配置")
    if task_filters:
        selected_task_names = []
        for task_name in task_filters:
            if task_name not in available_tasks:
                raise ValueError(f"未找到任务 `{task_name}`")
            selected_task_names.append(task_name)
    else:
        selected_task_names = [
            task_name
            for task_name in available_tasks
            if task_name not in DEFAULT_SKIP_TASKS
        ]
    if not selected_task_names:
        raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")
    default_task_name = (
        fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
    )
    # Candidate models: explicit filters win, otherwise every model referenced
    # by the selected tasks, first-seen order preserved.
    if model_filters:
        model_names = list(model_filters)
    else:
        model_names = []
        for task_name in selected_task_names:
            for model_name in available_tasks[task_name].model_list:
                if model_name not in model_names:
                    model_names.append(model_name)
    resolved_targets: List[ProbeTarget] = []
    seen_models: set[str] = set()
    for model_name in model_names:
        if model_name in seen_models:
            continue
        if model_name not in model_map:
            raise ValueError(f"未找到模型 `{model_name}`")
        # Route through the first selected task referencing the model, falling
        # back to the default task when none does.
        target_task_name = next(
            (
                task_name
                for task_name in selected_task_names
                if model_name in available_tasks[task_name].model_list
            ),
            default_task_name,
        )
        model_info = model_map[model_name]
        provider_info = provider_map.get(model_info.api_provider)
        if provider_info is None:
            # Fix: previously a bare KeyError escaped here; raise the same
            # descriptive ValueError style used for unknown tasks/models.
            raise ValueError(
                f"模型 `{model_name}` 引用了未知的 Provider `{model_info.api_provider}`"
            )
        resolved_targets.append(
            ProbeTarget(
                task_name=target_task_name,
                model_name=model_name,
                provider_name=provider_info.name,
                client_type=provider_info.client_type,
                tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
            )
        )
        seen_models.add(model_name)
    return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
    """Temporarily lock one task's routing to a single model.

    Restores the original model list and selection strategy on exit, even if
    the ``with`` body raises.
    """
    model_task_config = config_manager.get_model_config().model_task_config
    task_config = getattr(model_task_config, task_name, None)
    if not isinstance(task_config, TaskConfig):
        raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
    backup = (list(task_config.model_list), task_config.selection_strategy)
    task_config.model_list = [model_name]
    task_config.selection_strategy = "balance"
    try:
        yield
    finally:
        task_config.model_list, task_config.selection_strategy = backup
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
"""序列化返回中的工具调用。"""
if not tool_calls:
return []
serialized_items: List[Dict[str, Any]] = []
for tool_call in tool_calls:
serialized_items.append(
{
"id": getattr(tool_call, "call_id", ""),
"function": {
"name": getattr(tool_call, "func_name", ""),
"arguments": dict(getattr(tool_call, "args", {}) or {}),
},
**(
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
if getattr(tool_call, "extra_content", None)
else {}
),
}
)
return serialized_items
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Check a service result against the scenario's expectations.

    Hard failures become errors; expectation mismatches and visible text are
    only warnings, because these scenarios probe API compatibility, not
    behaviour.
    """
    errors: List[str] = []
    warnings: List[str] = []
    completion = service_result.completion
    serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
    if not service_result.success:
        errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
        return errors, warnings, serialized_tool_calls
    expectation = scenario.expect_tool_calls
    if expectation is True and not serialized_tool_calls:
        warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
    if expectation is False and serialized_tool_calls:
        warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
    if completion.response.strip():
        warnings.append("模型返回了可见文本")
    return errors, warnings, serialized_tool_calls
async def _run_single_probe(
    target: ProbeTarget,
    scenario: ToolCallScenario,
    attempt: int,
    max_tokens: int,
    temperature: float,
) -> ProbeResult:
    """Run a single API probe for one scenario against one target model.

    Pins the task to the target model for the duration of the request and
    records elapsed time around the pinned generate() call only.
    """
    request = LLMServiceRequest(
        task_name=target.task_name,
        request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
        prompt=scenario.prompt,
        tool_options=scenario.tool_options,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    started_at = time.perf_counter()
    with _pin_task_to_model(target.task_name, target.model_name):
        service_result = await generate(request)
    elapsed_seconds = time.perf_counter() - started_at
    errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
    completion = service_result.completion
    return ProbeResult(
        task_name=target.task_name,
        target_model_name=target.model_name,
        actual_model_name=completion.model_name,
        provider_name=target.provider_name,
        client_type=target.client_type,
        tool_argument_parse_mode=target.tool_argument_parse_mode,
        case_name=scenario.name,
        attempt=attempt,
        # A probe passes when validation produced no errors.
        success=not errors,
        elapsed_seconds=elapsed_seconds,
        errors=errors,
        warnings=warnings,
        response_text=completion.response,
        reasoning_text=completion.reasoning,
        tool_calls=serialized_tool_calls,
    )
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
    """List every resolved probe target on stdout, numbered from 1."""
    print("待测试目标:")
    position = 0
    for target in targets:
        position += 1
        print(
            f"{position}. model={target.model_name} | task={target.task_name} | "
            f"provider={target.provider_name} | client={target.client_type} | "
            f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
        )
def _print_available_targets() -> None:
    """Dump the available tasks and configured models to stdout."""
    available_tasks = get_available_models()
    model_map = _build_model_map()
    print("当前可用任务:")
    for task_name in available_tasks:
        print(f"- {task_name}: {list(available_tasks[task_name].model_list)}")
    referenced_models = {
        model_name
        for task_config in available_tasks.values()
        for model_name in task_config.model_list
    }
    print("\n当前配置中的模型:")
    for model_name, model_info in model_map.items():
        mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
        print(
            f"- {model_name}: provider={model_info.api_provider}, "
            f"identifier={model_info.model_identifier}, {mark}"
        )
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
    """Pick scenarios by name; no filter means every default scenario."""
    catalog = {scenario.name: scenario for scenario in _build_default_scenarios()}
    if not case_filters:
        return list(catalog.values())
    chosen: List[ToolCallScenario] = []
    for case_name in case_filters:
        try:
            chosen.append(catalog[case_name])
        except KeyError:
            raise ValueError(
                f"未知测试场景 `{case_name}`,可选值: {', '.join(sorted(catalog))}"
            ) from None
    return chosen
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
    """Print one probe result: status header, errors, warnings, tool calls."""
    label = "PASS" if result.success else "FAIL"
    print(
        f"[{label}] model={result.target_model_name} | task={result.task_name} | "
        f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
    )
    for message in result.errors:
        print(f"  ERROR: {message}")
    for message in result.warnings:
        print(f"  WARN: {message}")
    if result.tool_calls:
        print(f"  tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
    if show_response and result.response_text.strip():
        print(f"  response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
    """Summarize probe outcomes into counts plus a failure digest."""
    failures = [result for result in results if not result.success]
    failed_items = [
        {
            "model_name": failure.target_model_name,
            "case_name": failure.case_name,
            "attempt": failure.attempt,
            "errors": list(failure.errors),
        }
        for failure in failures
    ]
    return {
        "total": len(results),
        "passed": len(results) - len(failures),
        "failed": len(failures),
        "failed_items": failed_items,
    }
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
    """Persist the summary and per-probe records to a JSON file."""
    destination = Path(json_out).expanduser().resolve()
    destination.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "summary": _build_summary(results),
        "results": [asdict(result) for result in results],
    }
    rendered = json.dumps(payload, ensure_ascii=False, indent=2)
    destination.write_text(rendered, encoding="utf-8")
    print(f"\n结果已写入: {destination}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
    """Execute every selected scenario against every resolved target."""
    # Resolve scenarios first so invalid --case values fail fast.
    scenarios = _select_scenarios(_parse_multi_value_args(args.case))
    targets = _resolve_targets(
        _parse_multi_value_args(args.task),
        _parse_multi_value_args(args.model),
        args.fallback_task,
    )
    _print_targets(targets)
    print("")
    collected: List[ProbeResult] = []
    for target in targets:
        for attempt in range(1, args.repeat + 1):
            for scenario in scenarios:
                print(
                    f"开始测试: model={target.model_name}, task={target.task_name}, "
                    f"case={scenario.name}, attempt={attempt}"
                )
                outcome = await _run_single_probe(
                    target=target,
                    scenario=scenario,
                    attempt=attempt,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _print_single_result(outcome, args.show_response)
                print("")
                collected.append(outcome)
    return collected
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
"重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
)
parser.add_argument(
"--case",
action="append",
help=(
"指定测试场景名,可选值包括 "
"fresh_tool_call、history_assistant_tool_calls_with_content、"
"history_assistant_tool_calls_without_content、history_tool_result_followup、"
"history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
),
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个场景重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0,以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的可见文本内容",
)
return parser
def main() -> int:
    """CLI entry point; returns the process exit code (0 = all passed)."""
    _ensure_utf8_console()
    # Config must be initialized before any task/model lookups.
    config_manager.initialize()
    parser = _build_parser()
    args = parser.parse_args()
    if args.repeat < 1:
        parser.error("--repeat 必须大于等于 1")
    if args.max_tokens < 1:
        parser.error("--max-tokens 必须大于等于 1")
    if args.list_targets:
        _print_available_targets()
        return 0
    results = asyncio.run(_run_probes(args))
    summary = _build_summary(results)
    print("测试摘要:")
    print(f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}")
    if summary["failed_items"]:
        print("失败明细:")
        for item in summary["failed_items"]:
            print(
                f"- model={item['model_name']} | case={item['case_name']} | "
                f"attempt={item['attempt']} | errors={item['errors']}"
            )
    if args.json_out:
        _write_json_report(args.json_out, results)
    return 0 if summary["failed"] == 0 else 1
# Standard script entry guard; the exit code mirrors main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -0,0 +1,83 @@
#!/usr/bin/env bash
# UI snapshot driver: boots the WebUI backend and the dashboard dev server
# (unless reachable already or reuse is requested), then runs the Electron
# validation script against them. All paths/ports are env-overridable.
set -euo pipefail
# Resolve repository-relative paths from this script's own location.
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
DASHBOARD_ROOT="$REPO_ROOT/dashboard"
OUTPUT_DIR="${MAIBOT_UI_SNAPSHOT_DIR:-$REPO_ROOT/tmp/ui-snapshots/a_memorix-electron}"
# NOTE(review): default python is two levels up (../../.venv) — assumes this
# repo is nested inside a workspace with a shared venv; confirm for CI.
PYTHON_BIN="${MAIBOT_PYTHON_BIN:-$REPO_ROOT/../../.venv/bin/python}"
# macOS Electron binary path by default; override MAIBOT_ELECTRON_BIN elsewhere.
ELECTRON_BIN="${MAIBOT_ELECTRON_BIN:-$DASHBOARD_ROOT/node_modules/electron/dist/Electron.app/Contents/MacOS/Electron}"
DRIVER_SCRIPT="$DASHBOARD_ROOT/scripts/a_memorix_electron_validate.cjs"
BACKEND_SCRIPT="$REPO_ROOT/scripts/run_a_memorix_webui_backend.py"
BACKEND_HOST="${MAIBOT_WEBUI_HOST:-127.0.0.1}"
BACKEND_PORT="${MAIBOT_WEBUI_PORT:-8001}"
DASHBOARD_HOST="${MAIBOT_DASHBOARD_HOST:-127.0.0.1}"
DASHBOARD_PORT="${MAIBOT_DASHBOARD_PORT:-7999}"
BACKEND_URL="http://${BACKEND_HOST}:${BACKEND_PORT}"
DASHBOARD_URL="http://${DASHBOARD_HOST}:${DASHBOARD_PORT}"
# Set MAIBOT_UI_REUSE_SERVICES=1 to skip starting backend/dashboard entirely.
REUSE_SERVICES="${MAIBOT_UI_REUSE_SERVICES:-0}"
# PIDs of processes this script started (empty = not started here, not killed).
BACKEND_PID=""
DASHBOARD_PID=""
mkdir -p "$OUTPUT_DIR"
# Exit hook: reap any child processes we spawned (dashboard first, then the
# backend) while preserving the script's original exit status.
cleanup() {
  local exit_code=$?
  local pid
  for pid in "$DASHBOARD_PID" "$BACKEND_PID"; do
    if [[ -n "$pid" ]] && kill -0 "$pid" >/dev/null 2>&1; then
      kill "$pid" >/dev/null 2>&1 || true
      wait "$pid" >/dev/null 2>&1 || true
    fi
  done
  exit "$exit_code"
}
trap cleanup EXIT
# Poll a URL once per second (bypassing any configured proxies) until it
# answers, or until the timeout (seconds, third arg, default 60) elapses.
# Returns 0 on success; prints a message and returns 1 on timeout.
wait_for_url() {
  local url="$1"
  local label="$2"
  local timeout="${3:-60}"
  local deadline
  deadline=$(( $(date +%s) + timeout ))
  until env -u HTTP_PROXY -u HTTPS_PROXY -u ALL_PROXY NO_PROXY=127.0.0.1,localhost \
      curl -fsS "$url" >/dev/null 2>&1; do
    if (( $(date +%s) >= deadline )); then
      echo "Timed out waiting for ${label}: ${url}" >&2
      return 1
    fi
    sleep 1
  done
  return 0
}
# Start the backend and dashboard only when needed: skip entirely in reuse
# mode, and skip each service individually if its probe endpoint already
# answers. Proxies are stripped from every curl so local probes cannot be
# hijacked by HTTP(S)_PROXY settings inherited from the environment.
if [[ "$REUSE_SERVICES" != "1" ]]; then
if ! env -u HTTP_PROXY -u HTTPS_PROXY -u ALL_PROXY NO_PROXY=127.0.0.1,localhost \
curl -fsS "${BACKEND_URL}/api/webui/health" >/dev/null 2>&1; then
# Launch the WebUI backend in the background; stdout/stderr go to a log file
# in the snapshot dir. The subshell keeps the cd local to the child.
(
cd "$REPO_ROOT"
WEBUI_HOST="$BACKEND_HOST" WEBUI_PORT="$BACKEND_PORT" "$PYTHON_BIN" "$BACKEND_SCRIPT"
) >"$OUTPUT_DIR/backend.log" 2>&1 &
BACKEND_PID="$!"
# Under `set -e`, a wait_for_url timeout aborts the script (cleanup traps EXIT).
wait_for_url "${BACKEND_URL}/api/webui/health" "MaiBot WebUI backend" 120
fi
if ! env -u HTTP_PROXY -u HTTPS_PROXY -u ALL_PROXY NO_PROXY=127.0.0.1,localhost \
curl -fsS "${DASHBOARD_URL}/auth" >/dev/null 2>&1; then
# Same pattern for the dashboard dev server (vite-style --host/--port flags).
(
cd "$DASHBOARD_ROOT"
npm run dev -- --host "$DASHBOARD_HOST" --port "$DASHBOARD_PORT"
) >"$OUTPUT_DIR/dashboard.log" 2>&1 &
DASHBOARD_PID="$!"
wait_for_url "${DASHBOARD_URL}/auth" "dashboard dev server" 120
fi
fi
# Run the Electron driver in the foreground; ELECTRON_RUN_AS_NODE is removed
# so the binary starts as a full Electron app rather than a bare Node runtime.
env -u ELECTRON_RUN_AS_NODE \
MAIBOT_DASHBOARD_URL="$DASHBOARD_URL" \
MAIBOT_UI_SNAPSHOT_DIR="$OUTPUT_DIR" \
"$ELECTRON_BIN" "$DRIVER_SCRIPT"