diff --git a/scripts/analyze_tool_usage_by_chat.py b/scripts/analyze_tool_usage_by_chat.py new file mode 100644 index 00000000..2ada9a82 --- /dev/null +++ b/scripts/analyze_tool_usage_by_chat.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from argparse import ArgumentParser, Namespace +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import DefaultDict + +import csv +import json +import sqlite3 +import sys + + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +DEFAULT_DB_PATH = PROJECT_ROOT / "data" / "MaiBot.db" + + +@dataclass(frozen=True) +class ToolUsageRow: + chat_id: str + tool_name: str + count: int + chat_total: int + percent_in_chat: float + percent_in_all: float + + +def parse_datetime_filter(value: str | None) -> str | None: + if value is None: + return None + + normalized_value = value.strip() + if not normalized_value: + return None + + for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d"): + try: + parsed = datetime.strptime(normalized_value, fmt) + except ValueError: + continue + return parsed.strftime("%Y-%m-%d %H:%M:%S") + + raise ValueError(f"无法解析时间: {value!r},请使用 YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS") + + +def connect_readonly(db_path: Path) -> sqlite3.Connection: + if not db_path.exists(): + raise FileNotFoundError(f"数据库文件不存在: {db_path}") + + database_uri = f"file:{db_path.as_posix()}?mode=ro" + connection = sqlite3.connect(database_uri, uri=True) + connection.row_factory = sqlite3.Row + connection.execute("PRAGMA busy_timeout=5000") + return connection + + +def fetch_tool_counts( + db_path: Path, + since: str | None, + until: str | None, + include_empty_chat_id: bool, + include_empty_tool_name: bool, +) -> list[tuple[str, str, int]]: + where_clauses: list[str] = [] + params: list[str] = [] + + if since is not None: + where_clauses.append("timestamp >= ?") + params.append(since) + if until is not None: + where_clauses.append("timestamp < ?") + params.append(until) + if not include_empty_chat_id: + where_clauses.append("COALESCE(session_id, '') != ''") + if not include_empty_tool_name: + where_clauses.append("COALESCE(tool_name, '') != ''") + + where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + query = f""" + SELECT + COALESCE(session_id, '') AS chat_id, + COALESCE(tool_name, '') AS tool_name, + COUNT(*) AS usage_count + FROM tool_records + {where_sql} + GROUP BY COALESCE(session_id, ''), COALESCE(tool_name, '') + ORDER BY COALESCE(session_id, ''), usage_count DESC, COALESCE(tool_name, '') + """ + + with connect_readonly(db_path) as connection: + rows = connection.execute(query, params).fetchall() + + return [(str(row["chat_id"]), str(row["tool_name"]), int(row["usage_count"])) for row in rows] + + +def build_usage_rows( + counts: list[tuple[str, str, int]], + min_chat_total: int, + top_tools_per_chat: int | None, +) -> list[ToolUsageRow]: + chat_totals: DefaultDict[str, int] = defaultdict(int) + for chat_id, _tool_name, count in counts: + chat_totals[chat_id] += count + + all_total = sum(chat_totals.values()) + rows: list[ToolUsageRow] = [] + emitted_per_chat: DefaultDict[str, int] = defaultdict(int) + + sorted_counts = sorted(counts, key=lambda item: (item[0], -item[2], item[1])) + for chat_id, tool_name, count in sorted_counts: + chat_total = chat_totals[chat_id] + if chat_total < min_chat_total: + continue + if top_tools_per_chat is not None and emitted_per_chat[chat_id] >= top_tools_per_chat: + continue + + emitted_per_chat[chat_id] += 1 + rows.append( + ToolUsageRow( + chat_id=chat_id, + tool_name=tool_name, + count=count, + chat_total=chat_total, + percent_in_chat=count / chat_total * 100 if chat_total else 0.0, + percent_in_all=count / all_total * 100 if all_total else 0.0, + ) + ) + + return rows + + +def build_overall_rows(counts: list[tuple[str, str, int]], min_chat_total: int) -> list[tuple[str, int, float]]: + chat_totals: DefaultDict[str, int] = defaultdict(int) + for chat_id, _tool_name, count in counts: + chat_totals[chat_id] += count + + tool_counts: DefaultDict[str, int] = defaultdict(int) + for chat_id, tool_name, count in counts: + if chat_totals[chat_id] < min_chat_total: + continue + tool_counts[tool_name] += count + + total = sum(tool_counts.values()) + sorted_items = sorted(tool_counts.items(), key=lambda item: (-item[1], item[0])) + return [(tool_name, count, count / total * 100 if total else 0.0) for tool_name, count in sorted_items] + + +def print_overall_block(overall_rows: list[tuple[str, int, float]]) -> None: + print("全部统计") + total = sum(count for _tool_name, count, _percent in overall_rows) + print(f"tool_total: {total}") + if not overall_rows: + print(" 无工具调用记录") + return + + tool_width = max(len("tool"), *(len(tool_name) for tool_name, _count, _percent in overall_rows)) + count_width = max(len("count"), *(len(str(count)) for _tool_name, count, _percent in overall_rows)) + percent_width = max(len("全局占比"), *(len(f"{percent:.2f}%") for _tool_name, _count, percent in overall_rows)) + + print(f" {'tool':<{tool_width}} {'count':>{count_width}} {'全局占比':>{percent_width}}") + print(f" {'-' * tool_width} {'-' * count_width} {'-' * percent_width}") + for tool_name, count, percent in overall_rows: + print(f" {tool_name:<{tool_width}} {count:>{count_width}} {percent:>{percent_width - 1}.2f}%") + + +def print_markdown(rows: list[ToolUsageRow], overall_rows: list[tuple[str, int, float]]) -> None: + print_overall_block(overall_rows) + if rows: + print() + + grouped_rows: DefaultDict[str, list[ToolUsageRow]] = defaultdict(list) + for row in rows: + grouped_rows[row.chat_id].append(row) + + first_group = True + for chat_id in sorted(grouped_rows): + chat_rows = grouped_rows[chat_id] + if not chat_rows: + continue + + if not first_group: + print() + first_group = False + + chat_total = chat_rows[0].chat_total + print(f"chat_id: {chat_id}") + print(f"tool_total: {chat_total}") + + tool_width = max(len("tool"), *(len(row.tool_name) for row in chat_rows)) + count_width = max(len("count"), *(len(str(row.count)) for row in chat_rows)) + chat_percent_width = max(len("chat内占比"), *(len(f"{row.percent_in_chat:.2f}%") for row in chat_rows)) + all_percent_width = max(len("全局占比"), *(len(f"{row.percent_in_all:.2f}%") for row in chat_rows)) + + header = ( + f" {'tool':<{tool_width}} " + f"{'count':>{count_width}} " + f"{'chat内占比':>{chat_percent_width}} " + f"{'全局占比':>{all_percent_width}}" + ) + print(header) + print( + f" {'-' * tool_width} " + f"{'-' * count_width} " + f"{'-' * chat_percent_width} " + f"{'-' * all_percent_width}" + ) + for row in chat_rows: + print( + f" {row.tool_name:<{tool_width}} " + f"{row.count:>{count_width}} " + f"{row.percent_in_chat:>{chat_percent_width - 1}.2f}% " + f"{row.percent_in_all:>{all_percent_width - 1}.2f}%" + ) + + +def print_json(rows: list[ToolUsageRow]) -> None: + payload = [ + { + "chat_id": row.chat_id, + "tool_name": row.tool_name, + "count": row.count, + "chat_total": row.chat_total, + "percent_in_chat": round(row.percent_in_chat, 4), + "percent_in_all": round(row.percent_in_all, 4), + } + for row in rows + ] + print(json.dumps(payload, ensure_ascii=False, indent=2)) + + +def print_csv(rows: list[ToolUsageRow]) -> None: + writer = csv.writer(sys.stdout) + writer.writerow(["chat_id", "tool_name", "count", "chat_total", "percent_in_chat", "percent_in_all"]) + for row in rows: + writer.writerow( + [ + row.chat_id, + row.tool_name, + row.count, + row.chat_total, + f"{row.percent_in_chat:.4f}", + f"{row.percent_in_all:.4f}", + ] + ) + + +def parse_args() -> Namespace: + parser = ArgumentParser(description="统计不同 chat_id 的工具使用次数和占比。") + parser.add_argument("--db", type=Path, default=DEFAULT_DB_PATH, help=f"数据库路径,默认: {DEFAULT_DB_PATH}") + parser.add_argument("--since", help="仅统计此时间之后的记录,格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS") + parser.add_argument("--until", help="仅统计此时间之前的记录,格式: YYYY-MM-DD 或 YYYY-MM-DD HH:MM:SS") + parser.add_argument("--min-chat-total", type=int, default=1, help="只显示工具调用总数不低于该值的 chat_id") + parser.add_argument("--top-tools", type=int, help="每个 chat_id 最多显示前 N 个工具") + parser.add_argument("--format", choices=("markdown", "json", "csv"), default="markdown", help="输出格式,markdown 为按 chat_id 分块的终端表") + parser.add_argument("--include-empty-chat-id", action="store_true", help="包含 chat_id 为空的记录") + parser.add_argument("--include-empty-tool-name", action="store_true", help="包含 tool_name 为空的记录") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + since = parse_datetime_filter(args.since) + until = parse_datetime_filter(args.until) + min_chat_total = max(1, int(args.min_chat_total)) + top_tools = args.top_tools if args.top_tools is None else max(1, int(args.top_tools)) + + counts = fetch_tool_counts( + db_path=args.db.resolve(), + since=since, + until=until, + include_empty_chat_id=args.include_empty_chat_id, + include_empty_tool_name=args.include_empty_tool_name, + ) + rows = build_usage_rows( + counts=counts, + min_chat_total=min_chat_total, + top_tools_per_chat=top_tools, + ) + + if args.format == "json": + print_json(rows) + elif args.format == "csv": + print_csv(rows) + else: + overall_rows = build_overall_rows(counts, min_chat_total=min_chat_total) + print_markdown(rows, overall_rows) + + +if __name__ == "__main__": + main() diff --git a/src/config/config.py b/src/config/config.py index 8eef344e..8859816c 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -57,7 +57,7 @@ MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute( LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute() A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute() MMC_VERSION: str = "1.0.0-pre.14" -CONFIG_VERSION: str = "8.10.11" +CONFIG_VERSION: str = "8.10.12" MODEL_CONFIG_VERSION: str = "1.16.0" logger = get_logger("config") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 20e0e694..c3660be2 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -421,6 +421,22 @@ class ChatConfig(ConfigBase): ) """Planner 连续被新消息打断的最大次数,0 表示不启用打断""" + timing_gate_non_continue_cooldown_seconds: float = Field( + default=0, + ge=0, + json_schema_extra={ + "label": { + "zh_CN": "Timing Gate 非 continue 冷却", + "en_US": "Timing Gate non-continue cooldown", + "ja_JP": "Timing Gate 非 continue クールダウン", + }, + "x-widget": "input", + "x-icon": "timer", + "advanced": True, + }, + ) + """Timing Gate 返回 wait/no_reply 时的最小窗口秒数,0 表示不启用冷却""" + group_chat_prompt: str = Field( default=( "你正在qq群里聊天,下面是群里正在聊的内容,其中包含聊天记录和聊天中的图片和表情包。\n" diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 117673ea..4cb700c2 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -511,6 +511,11 @@ class MaisakaReasoningEngine: timing_tool_results, timing_tool_monitor_results, ) = await self._run_timing_gate(anchor_message) + timing_elapsed_seconds = time.time() - timing_started_at + if timing_action != "continue": + await self._runtime._wait_for_timing_gate_non_continue_cooldown( + timing_elapsed_seconds + ) timing_duration_ms = (time.time() - timing_started_at) * 1000 cycle_detail.time_records["timing_gate"] = timing_duration_ms / 1000 await emit_timing_gate_result( diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 0d13136f..4701d7ce 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -116,6 +116,10 @@ class MaisakaHeartFlowChatting: self._force_next_timing_continue = False self._force_next_timing_message_id = "" self._force_next_timing_reason = "" + self._timing_gate_non_continue_cooldown_seconds = max( + 0.0, + float(global_config.chat.timing_gate_non_continue_cooldown_seconds), + ) self._planner_interrupt_flag: Optional[asyncio.Event] = None self._planner_interrupt_requested = False self._planner_interrupt_consecutive_count = 0 @@ -567,6 +571,20 @@ class MaisakaHeartFlowChatting: return self._force_next_timing_continue + async def _wait_for_timing_gate_non_continue_cooldown(self, elapsed_seconds: float) -> None: + """仅对 Timing Gate 的 wait/no_reply 动作应用冷却窗口。""" + + cooldown_seconds = self._timing_gate_non_continue_cooldown_seconds + if cooldown_seconds <= 0: + return + + remaining_seconds = cooldown_seconds - max(0.0, elapsed_seconds) + if remaining_seconds <= 0: + return + + logger.info(f"{self.log_prefix} Timing Gate 非 continue 冷却中,等待 {remaining_seconds:.2f} 秒后结束") + await asyncio.sleep(remaining_seconds) + def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None: """绑定当前可打断请求使用的中断标记。""" self._planner_interrupt_flag = interrupt_flag