")
+ nested_diff = _find_first_structural_diff(previous_value[index], current_value[index], index_path)
+ if nested_diff is not None:
+ return nested_diff
+ return None
+
+ if previous_value == current_value:
+ return None
+
+ if isinstance(previous_value, str) and isinstance(current_value, str):
+ diff_index = _longest_common_prefix_length(previous_value, current_value)
+ return _DynamicDiff(
+ f"{path}@char{diff_index}",
+ _summarize_value(previous_value[diff_index:]),
+ _summarize_value(current_value[diff_index:]),
+ )
+
+ return _DynamicDiff(path, _summarize_value(previous_value), _summarize_value(current_value))
+
+
+def _diagnose_dynamic_diff(previous_prompt_text: str | None, current_prompt_text: str | None) -> _DynamicDiff:
+ if not current_prompt_text:
+ return _DynamicDiff("prompt_text.unavailable", "", "")
+ if not previous_prompt_text:
+ return _DynamicDiff("cache_pool.empty", "", _summarize_value(current_prompt_text))
+
+ try:
+ previous_payload = json.loads(previous_prompt_text)
+ current_payload = json.loads(current_prompt_text)
+ except json.JSONDecodeError:
+ diff_index = _longest_common_prefix_length(previous_prompt_text, current_prompt_text)
+ return _DynamicDiff(
+ f"raw_prompt@char{diff_index}",
+ _summarize_value(previous_prompt_text[diff_index:]),
+ _summarize_value(current_prompt_text[diff_index:]),
+ )
+
+ diff = _find_first_structural_diff(previous_payload, current_payload)
+ if diff is None:
+ return _DynamicDiff("identical", "", "")
+ return diff
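
An illustrative sketch (not part of the patch): when either prompt is not valid JSON, the fallback branch reports the first differing character offset; the exact summarised values depend on _summarize_value and _longest_common_prefix_length, which are defined elsewhere in this module.

diff = _diagnose_dynamic_diff("hello world", "hello there")
# Expected: diff.path == "raw_prompt@char6", with diff.previous_value and
# diff.current_value holding the differing suffixes ("world" vs "there"),
# possibly truncated by _summarize_value.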
+
+
+def _load_prompt_payload(prompt_text: str | None) -> dict[str, Any] | None:
+ if not prompt_text:
+ return None
+ try:
+ payload = json.loads(prompt_text)
+ except json.JSONDecodeError:
+ return None
+ return payload if isinstance(payload, dict) else None
+
+
+def _extract_prompt_messages(prompt_text: str | None) -> list[dict[str, Any]]:
+ payload = _load_prompt_payload(prompt_text)
+ if payload is None:
+ return []
+ messages = payload.get("messages")
+ return [message for message in messages if isinstance(message, dict)] if isinstance(messages, list) else []
+
+
+def _message_fingerprints(messages: list[dict[str, Any]]) -> list[str]:
+ return [json.dumps(message, ensure_ascii=False, sort_keys=True, default=str) for message in messages]
+
+
+def _count_common_prefix_items(left_items: list[str], right_items: list[str]) -> int:
+ common_count = 0
+ for left_item, right_item in zip(left_items, right_items, strict=False):
+ if left_item != right_item:
+ break
+ common_count += 1
+ return common_count
+
+
+def _count_common_suffix_items(left_items: list[str], right_items: list[str]) -> int:
+ common_count = 0
+ max_count = min(len(left_items), len(right_items))
+ while common_count < max_count and left_items[-common_count - 1] == right_items[-common_count - 1]:
+ common_count += 1
+ return common_count
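
For example (illustrative only, not part of the patch), the two helpers count matching items from the respective ends of the fingerprint lists:

_count_common_prefix_items(["a", "b", "c"], ["a", "b", "x"])  # -> 2
_count_common_suffix_items(["a", "b", "c"], ["x", "b", "c"])  # -> 2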
+
+
+def _find_longest_message_alignment(previous_items: list[str], current_items: list[str]) -> tuple[int, int, int]:
+ best_overlap = 0
+ best_previous_start = 0
+ best_current_start = 0
+ for previous_start in range(len(previous_items)):
+ for current_start in range(len(current_items)):
+ overlap = 0
+ while (
+ previous_start + overlap < len(previous_items)
+ and current_start + overlap < len(current_items)
+ and previous_items[previous_start + overlap] == current_items[current_start + overlap]
+ ):
+ overlap += 1
+ if overlap > best_overlap:
+ best_overlap = overlap
+ best_previous_start = previous_start
+ best_current_start = current_start
+ return best_overlap, best_previous_start, best_current_start
+
+
+def _get_message_role(messages: list[dict[str, Any]], index: int) -> str:
+ if not messages:
+ return ""
+ try:
+ value = messages[index].get("role", "")
+ except IndexError:
+ return ""
+ return str(value or "")
+
+
+def _diagnose_prompt_cache_details(
+ *,
+ previous_prompt_text: str | None,
+ current_prompt_text: str | None,
+ common_prefix_chars: int,
+) -> _PromptCacheDiagnostics:
+ current_messages = _extract_prompt_messages(current_prompt_text)
+ previous_messages = _extract_prompt_messages(previous_prompt_text)
+ current_items = _message_fingerprints(current_messages)
+ previous_items = _message_fingerprints(previous_messages)
+ current_prompt_length = len(current_prompt_text or "")
+ previous_prompt_length = len(previous_prompt_text or "")
+ common_prefix_rate = common_prefix_chars / current_prompt_length * 100 if current_prompt_length > 0 else 0.0
+
+ common_prefix_messages = _count_common_prefix_items(previous_items, current_items)
+ common_suffix_messages = _count_common_suffix_items(previous_items, current_items)
+ aligned_overlap, aligned_previous_start, aligned_current_start = _find_longest_message_alignment(
+ previous_items,
+ current_items,
+ )
+ suspected_context_sliding = (
+ aligned_previous_start > aligned_current_start
+ and aligned_overlap > common_prefix_messages
+ )
+ sliding_dropped_head_messages = aligned_previous_start - aligned_current_start if suspected_context_sliding else 0
+
+ return _PromptCacheDiagnostics(
+ current_message_count=len(current_messages),
+ best_match_message_count=len(previous_messages),
+ common_prefix_messages=common_prefix_messages,
+ common_suffix_messages=common_suffix_messages,
+ common_prefix_rate=common_prefix_rate,
+ prompt_growth_chars=current_prompt_length - previous_prompt_length,
+ longest_aligned_message_overlap=aligned_overlap,
+ aligned_previous_start_index=aligned_previous_start,
+ aligned_current_start_index=aligned_current_start,
+ suspected_context_sliding=suspected_context_sliding,
+ sliding_dropped_head_messages=sliding_dropped_head_messages,
+ sliding_aligned_messages=aligned_overlap if suspected_context_sliding else 0,
+ sliding_new_tail_messages=(
+ max(len(current_messages) - aligned_current_start - aligned_overlap, 0)
+ if suspected_context_sliding
+ else 0
+ ),
+ current_first_message_role=_get_message_role(current_messages, 0),
+ best_first_message_role=_get_message_role(previous_messages, 0),
+ current_last_message_role=_get_message_role(current_messages, -1),
+ best_last_message_role=_get_message_role(previous_messages, -1),
+ )
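
A minimal sketch (illustrative, not part of the patch) of how a sliding context window is flagged: the aligned block starts later in the previous prompt than in the current one, so a head message is assumed to have been dropped.

previous = json.dumps({"messages": [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "m1"},
    {"role": "user", "content": "m2"},
    {"role": "user", "content": "m3"},
]})
current = json.dumps({"messages": [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "m2"},
    {"role": "user", "content": "m3"},
    {"role": "user", "content": "m4"},
]})
diag = _diagnose_prompt_cache_details(
    previous_prompt_text=previous,
    current_prompt_text=current,
    common_prefix_chars=0,
)
# ["m2", "m3"] aligns at index 2 of the previous prompt but index 1 of the
# current one, so diag.suspected_context_sliding is True and
# diag.sliding_dropped_head_messages == 1.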
+
+
+def _get_usage_log_path(now: datetime) -> Path:
+ return CACHE_STATS_DIR / f"usage_{now:%Y%m%d}.jsonl"
+
+
+def _get_report_path() -> Path:
+ return CACHE_STATS_DIR / REPORT_FILE_NAME
+
+
+def _get_session_report_path() -> Path:
+ return CACHE_STATS_DIR / SESSION_REPORT_FILE_NAME
+
+
+def _iter_usage_log_paths() -> list[Path]:
+ if not CACHE_STATS_DIR.exists():
+ return []
+ return sorted(CACHE_STATS_DIR.glob("usage_*.jsonl"))
+
+
+def _read_usage_events() -> list[dict[str, Any]]:
+ events: list[dict[str, Any]] = []
+ for file_path in _iter_usage_log_paths():
+ try:
+ lines = file_path.read_text(encoding="utf-8").splitlines()
+ except OSError:
+ continue
+ for line in lines:
+ if not line.strip():
+ continue
+ try:
+ event = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ if isinstance(event, dict):
+ events.append(event)
+ return events
+
+
+def _write_json_line(file_path: Path, payload: Dict[str, int | str | float | bool]) -> None:
+ CACHE_STATS_DIR.mkdir(parents=True, exist_ok=True)
+ with file_path.open("a", encoding="utf-8") as file:
+ file.write(json.dumps(payload, ensure_ascii=False) + "\n")
+
+
+def _format_int(value: int | str | float) -> str:
+ return f"{int(value):,}"
+
+
+def _format_rate(value: int | str | float) -> str:
+ return f"{float(value):.2f}%"
+
+
+def _calculate_rate(hit_tokens: int, miss_tokens: int) -> float:
+ total_tokens = hit_tokens + miss_tokens
+ return hit_tokens / total_tokens * 100 if total_tokens > 0 else 0.0
+
+
+def _normal_cdf(value: float) -> float:
+ return 0.5 * (1.0 + erf(value / sqrt(2.0)))
+
+
+def _confidence_from_z_score(z_score: float) -> float:
+ p_value = 2.0 * (1.0 - _normal_cdf(abs(z_score)))
+ return max(0.0, min(100.0, (1.0 - p_value) * 100.0))
+
+
+def _format_significance_label(confidence: float, *, min_confidence: float = 95.0) -> str:
+ return "显著" if confidence >= min_confidence else "不显著"
+
+
+def _calculate_two_proportion_confidence(
+ *,
+ current_hit: int,
+ current_total: int,
+ baseline_hit: int,
+ baseline_total: int,
+) -> float:
+ if current_total <= 0 or baseline_total <= 0:
+ return 0.0
+ current_rate = current_hit / current_total
+ baseline_rate = baseline_hit / baseline_total
+ pooled_rate = (current_hit + baseline_hit) / (current_total + baseline_total)
+ standard_error = sqrt(pooled_rate * (1.0 - pooled_rate) * (1.0 / current_total + 1.0 / baseline_total))
+ if standard_error <= 0:
+ return 0.0
+ return _confidence_from_z_score((current_rate - baseline_rate) / standard_error)
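
A worked example (illustrative, not part of the patch): a current run that cached 8,000 of 10,000 prompt tokens, compared against a baseline that cached 7,000 of 10,000.

confidence = _calculate_two_proportion_confidence(
    current_hit=8_000, current_total=10_000,
    baseline_hit=7_000, baseline_total=10_000,
)
# pooled rate = 0.75, standard error ≈ 0.0061, z ≈ 16.3, so the two-sided
# p-value is ~0 and confidence ≈ 100.0, which _format_significance_label
# above reports as significant.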
+
+
+def _calculate_sample_variance(*, value_total: float, square_total: float, count: int) -> float:
+ if count <= 1:
+ return 0.0
+ return max((square_total - (value_total * value_total / count)) / (count - 1), 0.0)
+
+
+def _calculate_mean_difference_confidence(
+ *,
+ current_mean: float,
+ current_variance: float,
+ current_count: int,
+ baseline_mean: float,
+ baseline_variance: float,
+ baseline_count: int,
+) -> float:
+ if current_count <= 1 or baseline_count <= 1:
+ return 0.0
+ standard_error = sqrt(current_variance / current_count + baseline_variance / baseline_count)
+ if standard_error <= 0:
+ return 0.0
+ return _confidence_from_z_score((current_mean - baseline_mean) / standard_error)
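
Similarly (illustrative only), comparing mean common-prefix rates using the pre-aggregated sums that _calculate_sample_variance expects:

current_values = [90.0, 92.0, 94.0]    # hypothetical per-call prefix rates
baseline_values = [70.0, 72.0, 74.0]
current_variance = _calculate_sample_variance(
    value_total=sum(current_values),
    square_total=sum(v * v for v in current_values),
    count=len(current_values),
)  # -> 4.0
baseline_variance = _calculate_sample_variance(
    value_total=sum(baseline_values),
    square_total=sum(v * v for v in baseline_values),
    count=len(baseline_values),
)  # -> 4.0
confidence = _calculate_mean_difference_confidence(
    current_mean=92.0, current_variance=current_variance, current_count=3,
    baseline_mean=72.0, baseline_variance=baseline_variance, baseline_count=3,
)
# standard error = sqrt(4/3 + 4/3) ≈ 1.63, z ≈ 12.2, confidence ≈ 100.0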
+
+
+def _normalize_event_run_id(event: dict[str, Any]) -> str:
+ run_id = str(event.get("run_id") or "").strip()
+ return run_id or "legacy"
+
+
+def _aggregate_usage_events_by_run(events: list[dict[str, Any]]) -> list[dict[str, int | str | float]]:
+ grouped: dict[str, dict[str, int | str | float]] = {}
+ for event in events:
+ run_id = _normalize_event_run_id(event)
+ item = grouped.setdefault(
+ run_id,
+ {
+ "run_id": run_id,
+ "process_started_at": str(event.get("process_started_at") or ""),
+ "first_seen_at": str(event.get("created_at") or ""),
+ "last_seen_at": str(event.get("created_at") or ""),
+ "calls": 0,
+ "prompt_tokens": 0,
+ "prompt_cache_hit_tokens": 0,
+ "prompt_cache_miss_tokens": 0,
+ "theoretical_prompt_cache_hit_tokens": 0,
+ "theoretical_prompt_cache_miss_tokens": 0,
+ "common_prefix_rate_total": 0.0,
+ "common_prefix_rate_square_total": 0.0,
+ "suspected_context_sliding_calls": 0,
+ },
+ )
+ created_at = str(event.get("created_at") or "")
+ if created_at:
+ if not item["first_seen_at"] or created_at < str(item["first_seen_at"]):
+ item["first_seen_at"] = created_at
+ if created_at > str(item["last_seen_at"]):
+ item["last_seen_at"] = created_at
+ item["calls"] = int(item["calls"]) + 1
+ item["prompt_tokens"] = int(item["prompt_tokens"]) + int(event.get("prompt_tokens") or 0)
+ item["prompt_cache_hit_tokens"] = int(item["prompt_cache_hit_tokens"]) + int(
+ event.get("prompt_cache_hit_tokens") or 0
+ )
+ item["prompt_cache_miss_tokens"] = int(item["prompt_cache_miss_tokens"]) + int(
+ event.get("prompt_cache_miss_tokens") or 0
+ )
+ item["theoretical_prompt_cache_hit_tokens"] = int(item["theoretical_prompt_cache_hit_tokens"]) + int(
+ event.get("theoretical_prompt_cache_hit_tokens") or 0
+ )
+ item["theoretical_prompt_cache_miss_tokens"] = int(item["theoretical_prompt_cache_miss_tokens"]) + int(
+ event.get("theoretical_prompt_cache_miss_tokens") or 0
+ )
+ item["common_prefix_rate_total"] = float(item["common_prefix_rate_total"]) + float(
+ event.get("theoretical_common_prefix_rate") or 0.0
+ )
+ if bool(event.get("suspected_context_sliding", False)):
+ item["suspected_context_sliding_calls"] = int(item["suspected_context_sliding_calls"]) + 1
+
+ result: list[dict[str, int | str | float]] = []
+ for item in grouped.values():
+ calls = int(item["calls"])
+ hit_tokens = int(item["prompt_cache_hit_tokens"])
+ miss_tokens = int(item["prompt_cache_miss_tokens"])
+ theoretical_hit_tokens = int(item["theoretical_prompt_cache_hit_tokens"])
+ theoretical_miss_tokens = int(item["theoretical_prompt_cache_miss_tokens"])
+ item["prompt_cache_hit_rate"] = round(_calculate_rate(hit_tokens, miss_tokens), 2)
+ item["theoretical_prompt_cache_hit_rate"] = round(
+ _calculate_rate(theoretical_hit_tokens, theoretical_miss_tokens),
+ 2,
+ )
+ item["avg_common_prefix_rate"] = round(float(item["common_prefix_rate_total"]) / calls, 2) if calls else 0.0
+ result.append(item)
+
+ return sorted(result, key=lambda item: str(item["first_seen_at"]))
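
An illustrative sketch (not part of the patch) of the per-run roll-up, using two hypothetical usage events from the same run:

events = [
    {"run_id": "run-a", "created_at": "2024-05-01T10:00:00",
     "prompt_tokens": 1_000, "prompt_cache_hit_tokens": 800,
     "prompt_cache_miss_tokens": 200, "theoretical_prompt_cache_hit_tokens": 900,
     "theoretical_prompt_cache_miss_tokens": 100, "theoretical_common_prefix_rate": 90.0},
    {"run_id": "run-a", "created_at": "2024-05-01T10:01:00",
     "prompt_tokens": 1_000, "prompt_cache_hit_tokens": 600,
     "prompt_cache_miss_tokens": 400, "theoretical_prompt_cache_hit_tokens": 700,
     "theoretical_prompt_cache_miss_tokens": 300, "theoretical_common_prefix_rate": 70.0},
]
(run_stat,) = _aggregate_usage_events_by_run(events)
# calls == 2, prompt_cache_hit_rate == 70.0 (1,400 of 2,000 tokens),
# theoretical_prompt_cache_hit_rate == 80.0, avg_common_prefix_rate == 80.0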
+
+
+def _get_previous_run_id(run_stats: list[dict[str, int | str | float]], current_run_id: str) -> str:
+ run_ids = [str(item["run_id"]) for item in run_stats]
+ if current_run_id not in run_ids:
+ return ""
+ current_index = run_ids.index(current_run_id)
+ if current_index <= 0:
+ return ""
+ return run_ids[current_index - 1]
+
+
+def _aggregate_usage_events_by_call_site(
+ events: list[dict[str, Any]],
+ *,
+ run_id: str,
+ include_session: bool = True,
+) -> dict[tuple[str, ...], dict[str, int | str | float]]:
+ grouped: dict[tuple[str, ...], dict[str, int | str | float]] = {}
+ for event in events:
+ if _normalize_event_run_id(event) != run_id:
+ continue
+ base_key = (
+ str(event.get("task_name") or ""),
+ str(event.get("request_type") or ""),
+ str(event.get("model_name") or ""),
+ )
+ key = (
+ *base_key,
+ _normalize_session_id(str(event.get("session_id") or "")),
+ ) if include_session else base_key
+ item = grouped.setdefault(
+ key,
+ {
+ "task_name": key[0],
+ "request_type": key[1],
+ "model_name": key[2],
+ "session_id": key[3] if include_session else "",
+ "calls": 0,
+ "prompt_cache_hit_tokens": 0,
+ "prompt_cache_miss_tokens": 0,
+ "theoretical_prompt_cache_hit_tokens": 0,
+ "theoretical_prompt_cache_miss_tokens": 0,
+ "common_prefix_rate_total": 0.0,
+ "common_prefix_rate_square_total": 0.0,
+ "suspected_context_sliding_calls": 0,
+ },
+ )
+ item["calls"] = int(item["calls"]) + 1
+ item["prompt_cache_hit_tokens"] = int(item["prompt_cache_hit_tokens"]) + int(
+ event.get("prompt_cache_hit_tokens") or 0
+ )
+ item["prompt_cache_miss_tokens"] = int(item["prompt_cache_miss_tokens"]) + int(
+ event.get("prompt_cache_miss_tokens") or 0
+ )
+ item["theoretical_prompt_cache_hit_tokens"] = int(item["theoretical_prompt_cache_hit_tokens"]) + int(
+ event.get("theoretical_prompt_cache_hit_tokens") or 0
+ )
+ item["theoretical_prompt_cache_miss_tokens"] = int(item["theoretical_prompt_cache_miss_tokens"]) + int(
+ event.get("theoretical_prompt_cache_miss_tokens") or 0
+ )
+ prefix_rate = float(event.get("theoretical_common_prefix_rate") or 0.0)
+ item["common_prefix_rate_total"] = float(item["common_prefix_rate_total"]) + prefix_rate
+ item["common_prefix_rate_square_total"] = float(item["common_prefix_rate_square_total"]) + prefix_rate * prefix_rate
+ if bool(event.get("suspected_context_sliding", False)):
+ item["suspected_context_sliding_calls"] = int(item["suspected_context_sliding_calls"]) + 1
+
+ for item in grouped.values():
+ calls = int(item["calls"])
+ prefix_total = float(item["common_prefix_rate_total"])
+ prefix_square_total = float(item["common_prefix_rate_square_total"])
+ item["prompt_cache_hit_rate"] = round(
+ _calculate_rate(int(item["prompt_cache_hit_tokens"]), int(item["prompt_cache_miss_tokens"])),
+ 2,
+ )
+ item["theoretical_prompt_cache_hit_rate"] = round(
+ _calculate_rate(
+ int(item["theoretical_prompt_cache_hit_tokens"]),
+ int(item["theoretical_prompt_cache_miss_tokens"]),
+ ),
+ 2,
+ )
+ item["avg_common_prefix_rate"] = round(prefix_total / calls, 2) if calls else 0.0
+ item["common_prefix_rate_variance"] = round(
+ _calculate_sample_variance(
+ value_total=prefix_total,
+ square_total=prefix_square_total,
+ count=calls,
+ ),
+ 4,
+ )
+ return grouped
+
+
+def _render_run_rows(run_stats: list[dict[str, int | str | float]], current_run_id: str) -> str:
+ rows: list[str] = []
+ for item in reversed(run_stats[-12:]):
+ current_marker = "当前" if str(item["run_id"]) == current_run_id else ""
+ rows.append(
+ "<tr>"
+ f"<td>{escape(current_marker)}</td>"
+ f"<td>{escape(str(item['run_id']))}</td>"
+ f"<td>{escape(str(item['process_started_at']))}</td>"
+ f"<td>{escape(str(item['first_seen_at']))}</td>"
+ f"<td>{escape(str(item['last_seen_at']))}</td>"
+ f"<td>{_format_int(item['calls'])}</td>"
+ f"<td>{_format_int(item['prompt_tokens'])}</td>"
+ f"<td>{_format_rate(item['prompt_cache_hit_rate'])}</td>"
+ f"<td>{_format_rate(item['theoretical_prompt_cache_hit_rate'])}</td>"
+ f"<td>{_format_rate(item['avg_common_prefix_rate'])}</td>"
+ f"<td>{_format_int(item['suspected_context_sliding_calls'])}</td>"
+ "</tr>"
+ )
+ return "\n".join(rows)
+
+
+def _render_run_comparison_rows(
+ *,
+ current_by_call_site: dict[tuple[str, ...], dict[str, int | str | float]],
+ previous_by_call_site: dict[tuple[str, ...], dict[str, int | str | float]],
+ include_session: bool,
+) -> str:
+ rows: list[str] = []
+ keys = sorted(set(current_by_call_site) | set(previous_by_call_site))
+ for key in keys:
+ current_item = current_by_call_site.get(key, {})
+ previous_item = previous_by_call_site.get(key, {})
+ current_api = float(current_item.get("prompt_cache_hit_rate") or 0.0)
+ previous_api = float(previous_item.get("prompt_cache_hit_rate") or 0.0)
+ current_theory = float(current_item.get("theoretical_prompt_cache_hit_rate") or 0.0)
+ previous_theory = float(previous_item.get("theoretical_prompt_cache_hit_rate") or 0.0)
+ current_prefix = float(current_item.get("avg_common_prefix_rate") or 0.0)
+ previous_prefix = float(previous_item.get("avg_common_prefix_rate") or 0.0)
+ rows.append(
+ "<tr>"
+ f"<td>{escape(key[0])}</td>"
+ f"<td>{escape(key[1])}</td>"
+ f"<td>{escape(key[2])}</td>"
+ + (f"<td>{escape(key[3])}</td>" if include_session and len(key) > 3 else "")
+ + f"<td>{_format_int(current_item.get('calls', 0))}</td>"
+ f"<td>{_format_int(previous_item.get('calls', 0))}</td>"
+ f"<td>{_format_rate(current_api)}</td>"
+ f"<td>{_format_rate(previous_api)}</td>"
+ f"<td>{_format_rate(current_api - previous_api)}</td>"
+ f"<td>{_format_rate(current_theory)}</td>"
+ f"<td>{_format_rate(previous_theory)}</td>"
+ f"<td>{_format_rate(current_theory - previous_theory)}</td>"
+ f"<td>{_format_rate(current_prefix)}</td>"
+ f"<td>{_format_rate(previous_prefix)}</td>"
+ f"<td>{_format_rate(current_prefix - previous_prefix)}</td>"
+ f"<td>{_format_int(current_item.get('suspected_context_sliding_calls', 0))}</td>"
+ f"<td>{_format_int(previous_item.get('suspected_context_sliding_calls', 0))}</td>"
+ "</tr>"
+ )
+ return "\n".join(rows)
+
+
+def _format_run_time_label(run_stat: dict[str, int | str | float] | None) -> str:
+ if not run_stat:
+ return ""
+ first_seen_at = str(run_stat.get("first_seen_at") or "").strip()
+ last_seen_at = str(run_stat.get("last_seen_at") or "").strip()
+ process_started_at = str(run_stat.get("process_started_at") or "").strip()
+ if first_seen_at and last_seen_at and first_seen_at != last_seen_at:
+ return f"{first_seen_at} -> {last_seen_at}"
+ if first_seen_at:
+ return first_seen_at
+ return process_started_at
+
+
+def _get_previous_run_stats(
+ run_stats: list[dict[str, int | str | float]],
+ current_run_id: str,
+) -> list[dict[str, int | str | float]]:
+ return [
+ item
+ for item in run_stats
+ if str(item["run_id"]) != current_run_id
+ ]
+
+
+def _render_run_significance_controls(
+ run_stats: list[dict[str, int | str | float]],
+ current_run_id: str,
+) -> str:
+ previous_run_stats = _get_previous_run_stats(run_stats, current_run_id)
+ if not previous_run_stats:
+ return (
+ "<p>No previous runs to compare.</p>"
+ )
+
+ option_payload = [
+ {
+ "run_id": str(item["run_id"]),
+ "time_label": _format_run_time_label(item),
+ "calls": int(item.get("calls") or 0),
+ }
+ for item in previous_run_stats
+ ]
+ option_json = escape(json.dumps(option_payload, ensure_ascii=False), quote=True)
+ max_index = len(previous_run_stats) - 1
+ return (
+ '<div id="run-significance-controls" '
+ f'data-run-options="{option_json}" '
+ f'data-max-index="{max_index}">'
+ "</div>"
+ )
+
+
+def _render_run_significance_script() -> str:
+ return """
+
+"""
+
+
+def _build_run_significance_rows(
+ *,
+ usage_events: list[dict[str, Any]],
+ run_stats: list[dict[str, int | str | float]],
+ current_run_id: str,
+ include_session: bool,
+) -> str:
+ current_by_call_site = _aggregate_usage_events_by_call_site(
+ usage_events,
+ run_id=current_run_id,
+ include_session=include_session,
+ )
+ rows: list[str] = []
+ previous_run_stats = _get_previous_run_stats(run_stats, current_run_id)
+ for previous_run_stat in previous_run_stats:
+ previous_run_id = str(previous_run_stat["run_id"])
+ baseline_time = _format_run_time_label(previous_run_stat)
+ previous_by_call_site = _aggregate_usage_events_by_call_site(
+ usage_events,
+ run_id=previous_run_id,
+ include_session=include_session,
+ )
+ keys = sorted(set(current_by_call_site) & set(previous_by_call_site))
+ for key in keys:
+ current_item = current_by_call_site[key]
+ previous_item = previous_by_call_site[key]
+ current_hit = int(current_item.get("prompt_cache_hit_tokens") or 0)
+ current_miss = int(current_item.get("prompt_cache_miss_tokens") or 0)
+ previous_hit = int(previous_item.get("prompt_cache_hit_tokens") or 0)
+ previous_miss = int(previous_item.get("prompt_cache_miss_tokens") or 0)
+ current_total = current_hit + current_miss
+ previous_total = previous_hit + previous_miss
+ current_api = _calculate_rate(current_hit, current_miss)
+ previous_api = _calculate_rate(previous_hit, previous_miss)
+ api_confidence = _calculate_two_proportion_confidence(
+ current_hit=current_hit,
+ current_total=current_total,
+ baseline_hit=previous_hit,
+ baseline_total=previous_total,
+ )
+ current_calls = int(current_item.get("calls") or 0)
+ previous_calls = int(previous_item.get("calls") or 0)
+ current_prefix = float(current_item.get("avg_common_prefix_rate") or 0.0)
+ previous_prefix = float(previous_item.get("avg_common_prefix_rate") or 0.0)
+ prefix_confidence = _calculate_mean_difference_confidence(
+ current_mean=current_prefix,
+ current_variance=float(current_item.get("common_prefix_rate_variance") or 0.0),
+ current_count=current_calls,
+ baseline_mean=previous_prefix,
+ baseline_variance=float(previous_item.get("common_prefix_rate_variance") or 0.0),
+ baseline_count=previous_calls,
+ )
+ rows.append(
+ "<tr>"
+ f"<td>{escape(previous_run_id)}</td>"
+ f"<td>{escape(baseline_time)}</td>"
+ f"<td>{escape(key[0])}</td>"
+ f"<td>{escape(key[1])}</td>"
+ f"<td>{escape(key[2])}</td>"
+ + (f"<td>{escape(key[3])}</td>" if include_session and len(key) > 3 else "")
+ + f"<td>{_format_int(current_calls)}</td>"
+ f"<td>{_format_int(previous_calls)}</td>"
+ f"<td>{_format_rate(current_api - previous_api)}</td>"
+ f"<td>{_format_rate(api_confidence)}</td>"
+ f"<td>{escape(_format_significance_label(api_confidence))}</td>"
+ f"<td>{_format_rate(current_prefix - previous_prefix)}</td>"
+ f"<td>{_format_rate(prefix_confidence)}</td>"
+ f"<td>{escape(_format_significance_label(prefix_confidence))}</td>"
+ f"<td>{_format_int(current_item.get('suspected_context_sliding_calls', 0))}</td>"
+ f"<td>{_format_int(previous_item.get('suspected_context_sliding_calls', 0))}</td>"
+ "</tr>"
+ )
+
+ if not rows:
+ return (
+ "<tr><td>当前 run 还没有可与历史 run 比较的同类调用点,"
+ "或历史数据缺少 run_id。</td></tr>"
+ )
+ return "\n".join(rows)
+
+
+def _render_stat_rows(stats: List[Dict[str, int | str | float]], *, include_session: bool) -> str:
+ rows: list[str] = []
+ for item in stats:
+ rows.append(
+ "<tr>"
+ f"<td>{escape(str(item['task_name']))}</td>"
+ f"<td>{escape(str(item['request_type']))}</td>"
+ f"<td>{escape(str(item['model_name']))}</td>"
+ + (f"<td>{escape(str(item.get('session_id', '')))}</td>" if include_session else "")
+ + f"<td>{_format_rate(item['prompt_cache_hit_rate'])}</td>"
+ f"<td>{_format_rate(item['theoretical_prompt_cache_hit_rate'])}</td>"
+ f"<td>{_format_rate(item['prompt_cache_hit_rate_delta'])}</td>"
+ f"<td>{_format_int(item['prompt_cache_hit_tokens'])}</td>"
+ f"<td>{_format_int(item['prompt_cache_miss_tokens'])}</td>"
+ f"<td>{_format_int(item['theoretical_prompt_cache_hit_tokens'])}</td>"
+ f"<td>{_format_int(item['theoretical_prompt_cache_miss_tokens'])}</td>"
+ f"<td>{_format_int(item['prompt_tokens'])}</td>"
+ f"<td>{_format_int(item['calls'])}</td>"
+ f"<td>{_format_int(item['cache_reported_calls'])}</td>"
+ f"<td>{_format_int(item['theoretical_compared_calls'])}</td>"
+ f"<td>{_format_int(item['theoretical_cache_pool_hits'])}</td>"
+ f"<td>{_format_rate(item['avg_common_prefix_rate'])}</td>"
+ f"<td>{_format_int(item['suspected_context_sliding_calls'])}</td>"
+ f"<td>{item['avg_sliding_dropped_messages']}</td>"
+ f"<td>{item['avg_sliding_aligned_messages']}</td>"
+ f"<td>{escape(str(item.get('top_dynamic_diff_paths', '')))}</td>"
+ "</tr>"
+ )
+ return "\n".join(rows)
+
+
+def _aggregate_stats_snapshot(
+ stats_snapshot: List[Dict[str, int | str | float]],
+ *,
+ include_session: bool,
+) -> List[Dict[str, int | str | float]]:
+ grouped: dict[tuple[str, ...], dict[str, int | str | float]] = {}
+ for item in stats_snapshot:
+ base_key = (
+ str(item.get("task_name") or ""),
+ str(item.get("request_type") or ""),
+ str(item.get("model_name") or ""),
+ )
+ key = (*base_key, str(item.get("session_id") or "")) if include_session else base_key
+ target = grouped.setdefault(
+ key,
+ {
+ "task_name": base_key[0],
+ "request_type": base_key[1],
+ "model_name": base_key[2],
+ "session_id": str(item.get("session_id") or "") if include_session else "",
+ "calls": 0,
+ "cache_reported_calls": 0,
+ "prompt_tokens": 0,
+ "prompt_cache_hit_tokens": 0,
+ "prompt_cache_miss_tokens": 0,
+ "theoretical_prompt_cache_hit_tokens": 0,
+ "theoretical_prompt_cache_miss_tokens": 0,
+ "theoretical_compared_calls": 0,
+ "theoretical_cache_pool_hits": 0,
+ "common_prefix_rate_weighted_total": 0.0,
+ "suspected_context_sliding_calls": 0,
+ "sliding_dropped_weighted_total": 0.0,
+ "sliding_aligned_weighted_total": 0.0,
+ "top_dynamic_diff_paths": "",
+ },
+ )
+ calls = int(item.get("calls") or 0)
+ sliding_calls = int(item.get("suspected_context_sliding_calls") or 0)
+ target["calls"] = int(target["calls"]) + calls
+ target["cache_reported_calls"] = int(target["cache_reported_calls"]) + int(item.get("cache_reported_calls") or 0)
+ target["prompt_tokens"] = int(target["prompt_tokens"]) + int(item.get("prompt_tokens") or 0)
+ target["prompt_cache_hit_tokens"] = int(target["prompt_cache_hit_tokens"]) + int(item.get("prompt_cache_hit_tokens") or 0)
+ target["prompt_cache_miss_tokens"] = int(target["prompt_cache_miss_tokens"]) + int(item.get("prompt_cache_miss_tokens") or 0)
+ target["theoretical_prompt_cache_hit_tokens"] = int(target["theoretical_prompt_cache_hit_tokens"]) + int(
+ item.get("theoretical_prompt_cache_hit_tokens") or 0
+ )
+ target["theoretical_prompt_cache_miss_tokens"] = int(target["theoretical_prompt_cache_miss_tokens"]) + int(
+ item.get("theoretical_prompt_cache_miss_tokens") or 0
+ )
+ target["theoretical_compared_calls"] = int(target["theoretical_compared_calls"]) + int(
+ item.get("theoretical_compared_calls") or 0
+ )
+ target["theoretical_cache_pool_hits"] = int(target["theoretical_cache_pool_hits"]) + int(
+ item.get("theoretical_cache_pool_hits") or 0
+ )
+ target["common_prefix_rate_weighted_total"] = float(target["common_prefix_rate_weighted_total"]) + (
+ float(item.get("avg_common_prefix_rate") or 0.0) * calls
+ )
+ target["suspected_context_sliding_calls"] = int(target["suspected_context_sliding_calls"]) + sliding_calls
+ target["sliding_dropped_weighted_total"] = float(target["sliding_dropped_weighted_total"]) + (
+ float(item.get("avg_sliding_dropped_messages") or 0.0) * sliding_calls
+ )
+ target["sliding_aligned_weighted_total"] = float(target["sliding_aligned_weighted_total"]) + (
+ float(item.get("avg_sliding_aligned_messages") or 0.0) * sliding_calls
+ )
+ if include_session:
+ target["top_dynamic_diff_paths"] = item.get("top_dynamic_diff_paths", "")
+
+ result: list[dict[str, int | str | float]] = []
+ for item in grouped.values():
+ calls = int(item["calls"])
+ sliding_calls = int(item["suspected_context_sliding_calls"])
+ hit_tokens = int(item["prompt_cache_hit_tokens"])
+ miss_tokens = int(item["prompt_cache_miss_tokens"])
+ theoretical_hit_tokens = int(item["theoretical_prompt_cache_hit_tokens"])
+ theoretical_miss_tokens = int(item["theoretical_prompt_cache_miss_tokens"])
+ item["prompt_cache_hit_rate"] = round(_calculate_rate(hit_tokens, miss_tokens), 2)
+ item["theoretical_prompt_cache_hit_rate"] = round(
+ _calculate_rate(theoretical_hit_tokens, theoretical_miss_tokens),
+ 2,
+ )
+ item["prompt_cache_hit_rate_delta"] = round(
+ float(item["prompt_cache_hit_rate"]) - float(item["theoretical_prompt_cache_hit_rate"]),
+ 2,
+ )
+ item["avg_common_prefix_rate"] = (
+ round(float(item["common_prefix_rate_weighted_total"]) / calls, 2) if calls else 0.0
+ )
+ item["avg_sliding_dropped_messages"] = (
+ round(float(item["sliding_dropped_weighted_total"]) / sliding_calls, 2) if sliding_calls else 0.0
+ )
+ item["avg_sliding_aligned_messages"] = (
+ round(float(item["sliding_aligned_weighted_total"]) / sliding_calls, 2) if sliding_calls else 0.0
+ )
+ result.append(item)
+ return result
+
+
+def _render_html_report(stats_snapshot: List[Dict[str, int | str | float]], *, include_session: bool = False) -> str:
+ updated_at = datetime.now().isoformat(timespec="seconds")
+ visible_stats_snapshot = _aggregate_stats_snapshot(stats_snapshot, include_session=include_session)
+ usage_events = _read_usage_events()
+ run_stats = _aggregate_usage_events_by_run(usage_events)
+ current_run_id = _store.run_id
+ previous_run_id = _get_previous_run_id(run_stats, current_run_id)
+ current_by_call_site = _aggregate_usage_events_by_call_site(
+ usage_events,
+ run_id=current_run_id,
+ include_session=include_session,
+ )
+ previous_by_call_site = (
+ _aggregate_usage_events_by_call_site(
+ usage_events,
+ run_id=previous_run_id,
+ include_session=include_session,
+ ) if previous_run_id else {}
+ )
+ sorted_by_rate = sorted(
+ visible_stats_snapshot,
+ key=lambda item: (
+ float(item["prompt_cache_hit_rate"]),
+ -int(item["prompt_cache_miss_tokens"]),
+ ),
+ )
+ low_stats = sorted_by_rate[:SUMMARY_LIMIT]
+ high_stats = list(reversed(sorted_by_rate[-SUMMARY_LIMIT:]))
+ all_stats = sorted(
+ visible_stats_snapshot,
+ key=lambda item: (
+ str(item["task_name"]),
+ str(item["request_type"]),
+ str(item["model_name"]),
+ ),
+ )
+ total_calls = sum(int(item["calls"]) for item in visible_stats_snapshot)
+ total_prompt_tokens = sum(int(item["prompt_tokens"]) for item in visible_stats_snapshot)
+ total_hit_tokens = sum(int(item["prompt_cache_hit_tokens"]) for item in visible_stats_snapshot)
+ total_theoretical_hit_tokens = sum(int(item["theoretical_prompt_cache_hit_tokens"]) for item in visible_stats_snapshot)
+ total_miss_tokens = sum(int(item["prompt_cache_miss_tokens"]) for item in visible_stats_snapshot)
+ total_theoretical_miss_tokens = sum(int(item["theoretical_prompt_cache_miss_tokens"]) for item in visible_stats_snapshot)
+ total_cache_tokens = total_hit_tokens + total_miss_tokens
+ total_theoretical_cache_tokens = total_theoretical_hit_tokens + total_theoretical_miss_tokens
+ overall_hit_rate = total_hit_tokens / total_cache_tokens * 100 if total_cache_tokens > 0 else 0.0
+ overall_theoretical_hit_rate = (
+ total_theoretical_hit_tokens / total_theoretical_cache_tokens * 100
+ if total_theoretical_cache_tokens > 0
+ else 0.0
+ )
+ session_head = "<th>Session</th>" if include_session else ""
+ report_title = "LLM Prompt Cache Stats By Session" if include_session else "LLM Prompt Cache Stats"
+ peer_report_link = (
+ f'<a href="{REPORT_FILE_NAME}">Overview report</a>'
+ if include_session
+ else f'<a href="{SESSION_REPORT_FILE_NAME}">Session detail report</a>'
+ )
+ table_head = (
+ f"<tr><th>Task</th><th>Request</th><th>Model</th>{session_head}<th>API hit</th><th>Theory hit</th>"
+ "<th>Delta</th><th>API hit tok</th><th>API miss tok</th><th>Theory hit tok</th><th>Theory miss tok</th>"
+ "<th>Prompt tok</th><th>Calls</th><th>Reported</th><th>Compared</th><th>Pool hits</th>"
+ "<th>Avg prefix</th><th>Sliding calls</th><th>Avg dropped msg</th><th>Avg aligned msg</th>"
+ "<th>Top dynamic diff paths</th></tr>"
+ )
+ run_table_head = (
+ "<tr><th></th><th>Run ID</th><th>Process started</th><th>First event</th><th>Last event</th>"
+ "<th>Calls</th><th>Prompt tok</th><th>API hit</th><th>Theory hit</th><th>Avg prefix</th>"
+ "<th>Sliding calls</th></tr>"
+ )
+ run_compare_head = (
+ f"<tr><th>Task</th><th>Request</th><th>Model</th>{session_head}<th>Current calls</th><th>Previous calls</th>"
+ "<th>Current API</th><th>Previous API</th><th>API delta</th>"
+ "<th>Current Theory</th><th>Previous Theory</th><th>Theory delta</th>"
+ "<th>Current Prefix</th><th>Previous Prefix</th><th>Prefix delta</th>"
+ "<th>Current Sliding</th><th>Previous Sliding</th></tr>"
+ )
+ run_significance_head = (
+ f"<tr><th>Baseline run</th><th>Baseline time</th><th>Task</th><th>Request</th><th>Model</th>{session_head}"
+ "<th>Current calls</th><th>Baseline calls</th>"
+ "<th>API delta</th><th>API confidence</th><th>API significant</th>"
+ "<th>Prefix delta</th><th>Prefix confidence</th><th>Prefix significant</th>"
+ "<th>Current sliding</th><th>Baseline sliding</th></tr>"
+ )
+
+ return f"""<!DOCTYPE html>
+ <html>
+ <head>
+ <meta charset="utf-8" />
+ <title>{escape(report_title)}</title>
+ </head>
+ <body>
+ <h1>{escape(report_title)}</h1>
+ <p>Updated at: {escape(updated_at)}. Current run: {escape(current_run_id)}. Process started at: {escape(_store.process_started_at)}. Grouped by task_name / request_type / model_name{escape(' / session_id' if include_session else '')}. Local prompt pool size: {PROMPT_CACHE_POOL_SIZE}. {peer_report_link}</p>
+ <div>
+ <div>Calls: {_format_int(total_calls)}</div>
+ <div>Prompt tokens: {_format_int(total_prompt_tokens)}</div>
+ <div>API hit tokens: {_format_int(total_hit_tokens)}</div>
+ <div>API hit rate: {_format_rate(overall_hit_rate)}</div>
+ <div>Theory hit tokens: {_format_int(total_theoretical_hit_tokens)}</div>
+ <div>Theory hit rate: {_format_rate(overall_theoretical_hit_rate)}</div>
+ </div>
+ <h2>Run Comparison</h2>
+ <table>
+ {run_table_head}
+ {_render_run_rows(run_stats, current_run_id)}
+ </table>
+ <h2>Current vs Previous Run By Call Site</h2>
+ <table>
+ {run_compare_head}
+ {_render_run_comparison_rows(current_by_call_site=current_by_call_site, previous_by_call_site=previous_by_call_site, include_session=include_session)}
+ </table>
+ <h2>Current vs Every Previous Run Significance</h2>
+ {_render_run_significance_controls(run_stats, current_run_id)}
+ <table>
+ {run_significance_head}
+ {_build_run_significance_rows(usage_events=usage_events, run_stats=run_stats, current_run_id=current_run_id, include_session=include_session)}
+ </table>
+ <h2>Low API Hit Rate</h2>
+ <table>
+ {table_head}
+ {_render_stat_rows(low_stats, include_session=include_session)}
+ </table>
+ <h2>High API Hit Rate</h2>
+ <table>
+ {table_head}
+ {_render_stat_rows(high_stats, include_session=include_session)}
+ </table>
+ <h2>All Call Sites</h2>
+ <table>
+ {table_head}
+ {_render_stat_rows(all_stats, include_session=include_session)}
+ </table>
+ {_render_run_significance_script()}
+ </body>
+ </html>
+ """
+
+
+def _write_html_report(stats_snapshot: List[Dict[str, int | str | float]]) -> None:
+ CACHE_STATS_DIR.mkdir(parents=True, exist_ok=True)
+ _get_report_path().write_text(_render_html_report(stats_snapshot, include_session=False), encoding="utf-8")
+ _get_session_report_path().write_text(_render_html_report(stats_snapshot, include_session=True), encoding="utf-8")
+
+
+def _write_usage_event(event: Dict[str, int | str | float | bool]) -> None:
+ try:
+ _write_json_line(_get_usage_log_path(datetime.now()), event)
+ except Exception as exc:
+ logger.warning(f"写入 LLM prompt cache 明细失败: {exc}")
+
+
+def _write_report(stats_snapshot: List[Dict[str, int | str | float]]) -> None:
+ try:
+ _write_html_report(stats_snapshot)
+ except Exception as exc:
+ logger.warning(f"写入 LLM prompt cache HTML 报告失败: {exc}")
+
+
+def record_llm_cache_usage(
+ *,
+ task_name: str,
+ request_type: str,
+ model_name: str,
+ session_id: str = "",
+ prompt_tokens: int,
+ prompt_cache_hit_tokens: int,
+ prompt_cache_miss_tokens: int,
+ prompt_text: str | None = None,
+) -> None:
+ """Record one LLM prompt cache usage event."""
+
+ if not _is_llm_cache_stats_enabled():
+ return
+
+ normalized_task_name = str(task_name or "").strip()
+ if normalized_task_name not in FOCUSED_TASK_NAMES:
+ return
+
+ normalized_request_type = _normalize_request_type(request_type)
+ if normalized_request_type in EXCLUDED_REQUEST_TYPES:
+ return
+
+ normalized_model_name = _normalize_model_name(model_name)
+ normalized_session_id = _normalize_session_id(session_id)
+ normalized_prompt_tokens = max(int(prompt_tokens or 0), 0)
+ hit_tokens, miss_tokens, has_cache_report = _normalize_cache_tokens(
+ prompt_tokens=normalized_prompt_tokens,
+ prompt_cache_hit_tokens=prompt_cache_hit_tokens,
+ prompt_cache_miss_tokens=prompt_cache_miss_tokens,
+ )
+
+ with _store.lock:
+ key = (normalized_task_name, normalized_request_type, normalized_model_name, normalized_session_id)
+ prompt_pool = _store.prompt_pools.get(key, [])
+ cache_match = _calculate_theoretical_cache_match(
+ prompt_tokens=normalized_prompt_tokens,
+ prompt_text=prompt_text,
+ prompt_pool=prompt_pool,
+ )
+ dynamic_diff = _diagnose_dynamic_diff(cache_match.best_prompt_text, prompt_text)
+ prompt_diagnostics = _diagnose_prompt_cache_details(
+ previous_prompt_text=cache_match.best_prompt_text,
+ current_prompt_text=prompt_text,
+ common_prefix_chars=cache_match.common_prefix_chars,
+ )
+ if prompt_text:
+ next_prompt_pool = [*prompt_pool, prompt_text]
+ if len(next_prompt_pool) > PROMPT_CACHE_POOL_SIZE:
+ next_prompt_pool = next_prompt_pool[-PROMPT_CACHE_POOL_SIZE:]
+ _store.prompt_pools[key] = next_prompt_pool
+
+ stat = _store.stats.get(key)
+ if stat is None:
+ stat = LLMCacheStat(
+ task_name=normalized_task_name,
+ request_type=normalized_request_type,
+ model_name=normalized_model_name,
+ session_id=normalized_session_id,
+ )
+ _store.stats[key] = stat
+
+ stat.calls += 1
+ stat.prompt_tokens += normalized_prompt_tokens
+ stat.prompt_cache_hit_tokens += hit_tokens
+ stat.prompt_cache_miss_tokens += miss_tokens
+ stat.theoretical_prompt_cache_hit_tokens += cache_match.hit_tokens
+ stat.theoretical_prompt_cache_miss_tokens += cache_match.miss_tokens
+ stat.common_prefix_rate_total += prompt_diagnostics.common_prefix_rate
+ if prompt_diagnostics.suspected_context_sliding:
+ stat.suspected_context_sliding_calls += 1
+ stat.sliding_dropped_messages_total += prompt_diagnostics.sliding_dropped_head_messages
+ stat.sliding_aligned_messages_total += prompt_diagnostics.sliding_aligned_messages
+ stat.dynamic_diff_counts[dynamic_diff.path] = stat.dynamic_diff_counts.get(dynamic_diff.path, 0) + 1
+ if has_cache_report:
+ stat.cache_reported_calls += 1
+ if cache_match.compared:
+ stat.theoretical_compared_calls += 1
+ if cache_match.hit_tokens > 0:
+ stat.theoretical_cache_pool_hits += 1
+ _store.total_calls += 1
+ _store.calls_since_report += 1
+ _store.calls_in_run += 1
+
+ api_hit_rate = hit_tokens / (hit_tokens + miss_tokens) * 100 if hit_tokens + miss_tokens > 0 else 0.0
+ event = {
+ "created_at": datetime.now().isoformat(timespec="seconds"),
+ "run_id": _store.run_id,
+ "process_started_at": _store.process_started_at,
+ "call_index_in_run": _store.calls_in_run,
+ "task_name": normalized_task_name,
+ "request_type": normalized_request_type,
+ "model_name": normalized_model_name,
+ "session_id": normalized_session_id,
+ "prompt_tokens": normalized_prompt_tokens,
+ "prompt_chars": len(prompt_text or ""),
+ "prompt_cache_hit_tokens": hit_tokens,
+ "prompt_cache_miss_tokens": miss_tokens,
+ "prompt_cache_hit_rate": round(api_hit_rate, 2),
+ "theoretical_prompt_cache_hit_tokens": cache_match.hit_tokens,
+ "theoretical_prompt_cache_miss_tokens": cache_match.miss_tokens,
+ "theoretical_prompt_cache_hit_rate": round(cache_match.hit_rate, 2),
+ "theoretical_cache_pool_size": cache_match.pool_size,
+ "theoretical_best_match_rank": cache_match.best_match_rank,
+ "theoretical_common_prefix_chars": cache_match.common_prefix_chars,
+ "theoretical_common_prefix_rate": round(prompt_diagnostics.common_prefix_rate, 2),
+ "current_message_count": prompt_diagnostics.current_message_count,
+ "best_match_message_count": prompt_diagnostics.best_match_message_count,
+ "common_prefix_messages": prompt_diagnostics.common_prefix_messages,
+ "common_suffix_messages": prompt_diagnostics.common_suffix_messages,
+ "prompt_growth_chars": prompt_diagnostics.prompt_growth_chars,
+ "longest_aligned_message_overlap": prompt_diagnostics.longest_aligned_message_overlap,
+ "aligned_previous_start_index": prompt_diagnostics.aligned_previous_start_index,
+ "aligned_current_start_index": prompt_diagnostics.aligned_current_start_index,
+ "suspected_context_sliding": prompt_diagnostics.suspected_context_sliding,
+ "sliding_dropped_head_messages": prompt_diagnostics.sliding_dropped_head_messages,
+ "sliding_aligned_messages": prompt_diagnostics.sliding_aligned_messages,
+ "sliding_new_tail_messages": prompt_diagnostics.sliding_new_tail_messages,
+ "current_first_message_role": prompt_diagnostics.current_first_message_role,
+ "best_first_message_role": prompt_diagnostics.best_first_message_role,
+ "current_last_message_role": prompt_diagnostics.current_last_message_role,
+ "best_last_message_role": prompt_diagnostics.best_last_message_role,
+ "prompt_cache_hit_rate_delta": round(api_hit_rate - cache_match.hit_rate, 2),
+ "dynamic_diff_path": dynamic_diff.path,
+ "dynamic_diff_previous": dynamic_diff.previous_value,
+ "dynamic_diff_current": dynamic_diff.current_value,
+ "cache_reported": has_cache_report,
+ "theoretical_compared": cache_match.compared,
+ }
+ stats_snapshot = [stat.to_dict() for stat in _store.stats.values()]
+
+ now = time.time()
+ should_update_report = (
+ _store.last_report_at <= 0
+ or _store.calls_since_report >= REPORT_INTERVAL_CALLS
+ or now - _store.last_report_at >= REPORT_INTERVAL_SECONDS
+ )
+ if should_update_report:
+ _store.last_report_at = now
+ _store.calls_since_report = 0
+ stats_snapshot_to_report = stats_snapshot
+ else:
+ stats_snapshot_to_report = []
+
+ _write_usage_event(event)
+ if stats_snapshot_to_report:
+ _write_report(stats_snapshot_to_report)
+ log_llm_cache_stats_summary(stats_snapshot_to_report)
+
+
+def get_llm_cache_stats_snapshot() -> List[Dict[str, int | str | float]]:
+ """Return current in-process LLM prompt cache stats."""
+
+ with _store.lock:
+ return [stat.to_dict() for stat in _store.stats.values()]
+
+
+def reset_llm_cache_stats() -> None:
+ """Reset in-process stats. Intended for tests and local debugging."""
+
+ with _store.lock:
+ _store.stats.clear()
+ _store.prompt_pools.clear()
+ _store.total_calls = 0
+ _store.calls_in_run = 0
+ _store.last_report_at = 0
+ _store.calls_since_report = 0
+
+
+def log_llm_cache_stats_summary(stats_snapshot: List[Dict[str, int | str | float]] | None = None) -> None:
+ """Log current highest and lowest prompt cache hit-rate call sites."""
+
+ snapshot = stats_snapshot or get_llm_cache_stats_snapshot()
+ if not snapshot:
+ return
+
+ sorted_stats = sorted(
+ snapshot,
+ key=lambda item: (
+ float(item["prompt_cache_hit_rate"]),
+ -int(item["prompt_cache_miss_tokens"]),
+ ),
+ )
+ low_stats = sorted_stats[:SUMMARY_LIMIT]
+ high_stats = list(reversed(sorted_stats[-SUMMARY_LIMIT:]))
+
+ def _format_stat(item: Dict[str, int | str | float]) -> str:
+ return (
+ f"{item['task_name']}/{item['request_type']}/{item['model_name']}: "
+ f"api_hit_rate={float(item['prompt_cache_hit_rate']):.2f}%, "
+ f"theory_hit_rate={float(item['theoretical_prompt_cache_hit_rate']):.2f}%, "
+ f"delta={float(item['prompt_cache_hit_rate_delta']):.2f}%, "
+ f"avg_prefix={float(item['avg_common_prefix_rate']):.2f}%, "
+ f"sliding_calls={item['suspected_context_sliding_calls']}, "
+ f"top_dynamic={item.get('top_dynamic_diff_paths', '')}, "
+ f"hit={item['prompt_cache_hit_tokens']}, "
+ f"miss={item['prompt_cache_miss_tokens']}, "
+ f"prompt={item['prompt_tokens']}, "
+ f"calls={item['calls']}, "
+ f"reported={item['cache_reported_calls']}"
+ )
+
+ logger.info(
+ "LLM prompt cache 统计摘要\n"
+ "低命中调用点:\n- " + "\n- ".join(_format_stat(item) for item in low_stats) + "\n"
+ "高命中调用点:\n- " + "\n- ".join(_format_stat(item) for item in high_stats)
+ )
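
A test-style usage sketch (illustrative, not part of the patch). It assumes stats collection is enabled and that "chat" appears in FOCUSED_TASK_NAMES while "reply" is not excluded; those are configuration details defined elsewhere in this module, and the model name is hypothetical.

reset_llm_cache_stats()
record_llm_cache_usage(
    task_name="chat",
    request_type="reply",
    model_name="deepseek-chat",
    session_id="session-1",
    prompt_tokens=1_000,
    prompt_cache_hit_tokens=800,
    prompt_cache_miss_tokens=200,
    prompt_text='{"messages": []}',
)
for stat in get_llm_cache_stats_snapshot():
    print(stat["task_name"], stat["prompt_cache_hit_rate"])
log_llm_cache_stats_summary()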
diff --git a/src/services/llm_service.py b/src/services/llm_service.py
index 264d2dd2..92da545f 100644
--- a/src/services/llm_service.py
+++ b/src/services/llm_service.py
@@ -6,6 +6,8 @@
from typing import Any, Dict, List, Tuple
+import hashlib
+import inspect
import json
from src.common.data_models.embedding_service_data_models import EmbeddingResult
@@ -26,6 +28,7 @@ from src.llm_models.payload_content.message import Message, MessageBuilder, Role
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMOrchestrator
from src.services.embedding_service import EmbeddingServiceClient
+from src.services.llm_cache_stats import record_llm_cache_usage
from src.services.service_task_resolver import (
get_available_models as _get_available_models,
resolve_task_name as _resolve_task_name,
@@ -46,7 +49,7 @@ class LLMServiceClient:
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`)
"""
- def __init__(self, task_name: str, request_type: str = "") -> None:
+ def __init__(self, task_name: str, request_type: str = "", session_id: str = "") -> None:
"""初始化 LLM 服务门面。
Args:
@@ -55,6 +58,7 @@ class LLMServiceClient:
"""
self.task_name = _resolve_task_name(task_name)
self.request_type = request_type
+ self.session_id = str(session_id or "").strip()
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
@staticmethod
@@ -85,6 +89,70 @@ class LLMServiceClient:
return LLMImageOptions()
return options
+ @staticmethod
+ def _serialize_message_for_cache_stats(message: Message) -> Dict[str, Any]:
+ parts: list[dict[str, Any]] = []
+ for part in message.parts:
+ if hasattr(part, "text"):
+ parts.append({"type": "text", "text": part.text})
+ continue
+
+ image_base64 = getattr(part, "image_base64", "")
+ image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
+ parts.append(
+ {
+ "type": "image",
+ "format": getattr(part, "image_format", ""),
+ "size": len(image_base64),
+ "sha256": image_digest,
+ }
+ )
+
+ return {
+ "role": str(message.role.value if hasattr(message.role, "value") else message.role),
+ "parts": parts,
+ "tool_call_id": message.tool_call_id,
+ "tool_name": message.tool_name,
+ "tool_calls": [
+ {
+ "id": tool_call.call_id,
+ "name": tool_call.func_name,
+ "arguments": tool_call.args,
+ "extra_content": tool_call.extra_content,
+ }
+ for tool_call in (message.tool_calls or [])
+ ],
+ }
+
+ @classmethod
+ def _build_cache_stats_prompt_text(
+ cls,
+ *,
+ messages: List[Message],
+ tool_options: Any,
+ response_format: Any,
+ ) -> str:
+ payload = {
+ "messages": [cls._serialize_message_for_cache_stats(message) for message in messages],
+ "tool_options": tool_options or [],
+ "response_format": response_format,
+ }
+ return json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str)
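
For illustration (not part of the patch): the fingerprint built for a single text message is key-sorted JSON, so identical message sequences produce byte-identical prompt_text values for the cache-pool comparison. MessageBuilder usage follows the pattern used later in this file.

prompt_text = LLMServiceClient._build_cache_stats_prompt_text(
    messages=[MessageBuilder().add_text_content("hi").build()],
    tool_options=None,
    response_format=None,
)
# prompt_text is deterministic, key-sorted JSON describing the messages,
# tool options, and response format.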
+
+ def _record_cache_stats(self, result: LLMResponseResult, prompt_text: str | None = None) -> None:
+ """Record prompt cache statistics for the current call."""
+
+ record_llm_cache_usage(
+ task_name=self.task_name,
+ request_type=self.request_type,
+ model_name=result.model_name,
+ session_id=self.session_id,
+ prompt_tokens=result.prompt_tokens,
+ prompt_cache_hit_tokens=result.prompt_cache_hit_tokens,
+ prompt_cache_miss_tokens=result.prompt_cache_miss_tokens,
+ prompt_text=prompt_text,
+ )
+
async def generate_response(
self,
prompt: str,
@@ -100,7 +168,12 @@ class LLMServiceClient:
LLMResponseResult: 统一文本生成结果。
"""
active_options = self._normalize_generation_options(options)
- return await self._orchestrator.generate_response_async(
+ prompt_text = self._build_cache_stats_prompt_text(
+ messages=[MessageBuilder().add_text_content(prompt).build()],
+ tool_options=active_options.tool_options,
+ response_format=active_options.response_format,
+ )
+ result = await self._orchestrator.generate_response_async(
prompt=prompt,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
@@ -109,6 +182,8 @@ class LLMServiceClient:
raise_when_empty=active_options.raise_when_empty,
interrupt_flag=active_options.interrupt_flag,
)
+ self._record_cache_stats(result, prompt_text=prompt_text)
+ return result
async def generate_response_with_messages(
self,
@@ -125,8 +200,22 @@ class LLMServiceClient:
LLMResponseResult: 统一文本生成结果。
"""
active_options = self._normalize_generation_options(options)
- return await self._orchestrator.generate_response_with_message_async(
- message_factory=message_factory,
+ prompt_text_holder: dict[str, str] = {}
+
+ def cache_stats_message_factory(client: BaseClient, model_info: Any = None) -> List[Message]:
+ if len(inspect.signature(message_factory).parameters) >= 2:
+ messages = message_factory(client, model_info)
+ else:
+ messages = message_factory(client)
+ prompt_text_holder["prompt_text"] = self._build_cache_stats_prompt_text(
+ messages=messages,
+ tool_options=active_options.tool_options,
+ response_format=active_options.response_format,
+ )
+ return messages
+
+ result = await self._orchestrator.generate_response_with_message_async(
+ message_factory=cache_stats_message_factory,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
tools=active_options.tool_options,
@@ -134,6 +223,8 @@ class LLMServiceClient:
raise_when_empty=active_options.raise_when_empty,
interrupt_flag=active_options.interrupt_flag,
)
+ self._record_cache_stats(result, prompt_text=prompt_text_holder.get("prompt_text"))
+ return result
async def generate_response_for_image(
self,
@@ -154,7 +245,30 @@ class LLMServiceClient:
LLMResponseResult: 统一文本生成结果。
"""
active_options = self._normalize_image_options(options)
- return await self._orchestrator.generate_response_for_image(
+ image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
+ prompt_text = json.dumps(
+ {
+ "messages": [
+ {
+ "role": "user",
+ "parts": [
+ {"type": "text", "text": prompt},
+ {
+ "type": "image",
+ "format": image_format,
+ "size": len(image_base64),
+ "sha256": image_digest,
+ },
+ ],
+ }
+ ],
+ "tool_options": [],
+ "response_format": None,
+ },
+ ensure_ascii=False,
+ sort_keys=True,
+ )
+ result = await self._orchestrator.generate_response_for_image(
prompt=prompt,
image_base64=image_base64,
image_format=image_format,
@@ -162,6 +276,8 @@ class LLMServiceClient:
max_tokens=active_options.max_tokens,
interrupt_flag=active_options.interrupt_flag,
)
+ self._record_cache_stats(result, prompt_text=prompt_text)
+ return result
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
"""执行音频转写请求。
diff --git a/src/webui/config_schema.py b/src/webui/config_schema.py
index 862da1e5..1f11faa2 100644
--- a/src/webui/config_schema.py
+++ b/src/webui/config_schema.py
@@ -70,11 +70,15 @@ class ConfigSchemaGenerator:
) -> Dict[str, Any]:
field_docs = config_class.get_class_field_docs()
field_type = cls._map_field_type(annotation)
+ raw_description = field_docs.get(field_name, field_info.description or "")
+ # In config-class docstrings the `_wrap_` marker indicates the note should be a block-level
+ # comment (on its own line); convert it to a newline so the frontend starts or wraps the description there.
+ description = raw_description.replace("_wrap_", "\n").strip("\n")
schema: Dict[str, Any] = {
"name": field_name,
"type": field_type,
"label": field_name,
- "description": field_docs.get(field_name, field_info.description or ""),
+ "description": description,
"required": field_info.is_required(),
}
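
Illustrative example (not part of the patch) of the `_wrap_` marker conversion described in the comment above, using a hypothetical field description:

raw = "Maximum retries._wrap_Set to 0 to disable retrying."
raw.replace("_wrap_", "\n").strip("\n")
# -> 'Maximum retries.\nSet to 0 to disable retrying.'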
diff --git a/src/webui/routers/chat/routes.py b/src/webui/routers/chat/routes.py
index 40988391..3c03227a 100644
--- a/src/webui/routers/chat/routes.py
+++ b/src/webui/routers/chat/routes.py
@@ -13,10 +13,10 @@ from src.config.config import global_config
from src.webui.dependencies import require_auth
from .service import (
- WEBUI_CHAT_GROUP_ID,
WEBUI_CHAT_PLATFORM,
chat_history,
chat_manager,
+ normalize_webui_user_id,
)
logger = get_logger("webui.chat")
@@ -30,10 +30,15 @@ async def get_chat_history(
user_id: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
) -> Dict[str, object]:
- """获取聊天历史记录。"""
- del user_id
- target_group_id = group_id or WEBUI_CHAT_GROUP_ID
- history = chat_history.get_history(limit, target_group_id)
+ """Get chat history.
+
+ Prefer loading virtual group-chat history by ``group_id``; when it is absent, load WebUI private-chat history by the normalized ``user_id``.
+ """
+ if group_id:
+ history = chat_history.get_history(limit, group_id=group_id)
+ else:
+ normalized_user_id = normalize_webui_user_id(user_id)
+ history = chat_history.get_history(limit, user_id=normalized_user_id)
return {"success": True, "messages": history, "total": len(history)}
@@ -100,10 +105,18 @@ async def get_persons_by_platform(
@router.delete("/history")
async def clear_chat_history(
+ user_id: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None),
) -> Dict[str, object]:
- """清空聊天历史记录。"""
- deleted = chat_history.clear_history(group_id)
+ """Clear chat history.
+
+ Prefer clearing virtual group-chat history by ``group_id``; when it is absent, clear WebUI private-chat history by the normalized ``user_id``.
+ """
+ if group_id:
+ deleted = chat_history.clear_history(group_id=group_id)
+ else:
+ normalized_user_id = normalize_webui_user_id(user_id)
+ deleted = chat_history.clear_history(user_id=normalized_user_id)
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
@@ -113,6 +126,5 @@ async def get_chat_info() -> Dict[str, object]:
return {
"bot_name": global_config.bot.nickname,
"platform": WEBUI_CHAT_PLATFORM,
- "group_id": WEBUI_CHAT_GROUP_ID,
"active_sessions": len(chat_manager.active_connections),
}
diff --git a/src/webui/routers/chat/service.py b/src/webui/routers/chat/service.py
index b8433b92..168d3190 100644
--- a/src/webui/routers/chat/service.py
+++ b/src/webui/routers/chat/service.py
@@ -18,6 +18,8 @@ from src.common.message_repository import find_messages
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
+from .serializers import serialize_message_sequence
+
logger = get_logger("webui.chat")
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
@@ -61,7 +63,7 @@ class ChatSessionConnection:
client_session_id: str
user_id: str
user_name: str
- active_group_id: str
+ channel_key: str
virtual_config: Optional[VirtualIdentityConfig]
sender: AsyncMessageSender
@@ -92,6 +94,21 @@ class ChatHistoryManager:
user_id = user_info.user_id or ""
is_bot = is_bot_self(msg.platform, user_id)
+ # Serialize the stored raw_message into rich-text segments the frontend can render,
+ # avoiding the inconsistency where a freshly received bot reply is rich text but turns into plain text after a refresh.
+ segments: List[Dict[str, Any]] = []
+ try:
+ raw_message = getattr(msg, "raw_message", None)
+ if raw_message is not None and getattr(raw_message, "components", None):
+ segments = serialize_message_sequence(raw_message)
+ except Exception as exc:  # only log a warning and fall back to plain text
+ logger.debug(f"序列化历史消息段失败,退化为纯文本: {exc}")
+ segments = []
+
+ is_rich = bool(segments) and not (
+ len(segments) == 1 and segments[0].get("type") == "text"
+ )
+
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
@@ -100,32 +117,119 @@ class ChatHistoryManager:
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
+ "message_type": "rich" if is_rich else "text",
+ "segments": segments if is_rich else None,
}
- def _resolve_session_id(self, group_id: Optional[str]) -> str:
- """根据群组标识解析聊天会话 ID。
+ def _enrich_reply_segments(
+ self,
+ segments: List[Dict[str, Any]],
+ message_index: Dict[str, SessionMessage],
+ session_id: Optional[str],
+ ) -> None:
+ """Backfill sender / original-content fields missing from reply segments in history messages.
+
+ The ReplyComponent persisted in the DB usually keeps only ``target_message_id``;
+ the ``target_message_content`` / ``target_message_sender_*`` fields are empty.
+ Fill them in from the messages already loaded for the current session (querying the DB by ID when necessary).
Args:
- group_id: 群组标识。
+ segments: Segment list of one history message; modified in place.
+ message_index: ``message_id -> SessionMessage`` index of messages already loaded for the current session.
+ session_id: Current session ID, used to narrow the scope when looking a message up by ID.
+ """
+ for segment in segments:
+ if not isinstance(segment, dict) or segment.get("type") != "reply":
+ continue
+ data = segment.get("data")
+ if not isinstance(data, dict):
+ continue
+ target_message_id = data.get("target_message_id")
+ if not target_message_id:
+ continue
+
+ has_content = bool(str(data.get("target_message_content") or "").strip())
+ has_sender = any(
+ str(data.get(key) or "").strip()
+ for key in (
+ "target_message_sender_id",
+ "target_message_sender_nickname",
+ "target_message_sender_cardname",
+ )
+ )
+ if has_content and has_sender:
+ continue
+
+ target_msg = message_index.get(str(target_message_id))
+ if target_msg is None:
+ # Fall back to a single lookup by ID (only pay the DB cost when the target is not in the current window)
+ try:
+ from src.services.message_service import get_message_by_id
+
+ target_msg = get_message_by_id(str(target_message_id), session_id or None)
+ except Exception as exc:
+ logger.debug(f"按 ID 回查 reply 目标消息失败: {exc}")
+ target_msg = None
+ if target_msg is None:
+ continue
+
+ user_info = target_msg.message_info.user_info
+ if not has_content:
+ content_text = (
+ target_msg.processed_plain_text
+ or target_msg.display_message
+ or ""
+ )
+ data["target_message_content"] = content_text
+ if not has_sender:
+ data["target_message_sender_id"] = user_info.user_id or ""
+ data["target_message_sender_nickname"] = user_info.user_nickname or ""
+ data["target_message_sender_cardname"] = (
+ getattr(user_info, "user_cardname", "") or ""
+ )
+
+ def _resolve_session_id(
+ self,
+ group_id: Optional[str] = None,
+ user_id: Optional[str] = None,
+ ) -> Optional[str]:
+ """Resolve the internal chat session ID from the conversation identifiers.
+
+ Resolve as a virtual group chat first; otherwise resolve as a WebUI private chat.
+
+ Args:
+ group_id: Group identifier (virtual group-chat mode).
+ user_id: User identifier (private-chat mode).
Returns:
- str: 内部聊天会话 ID。
+ Optional[str]: Internal chat session ID; ``None`` when neither group_id nor user_id is provided.
"""
- target_group_id = group_id or WEBUI_CHAT_GROUP_ID
- return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
+ if group_id:
+ return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=group_id)
+ if user_id:
+ return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, user_id=user_id)
+ return None
- def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
+ def get_history(
+ self,
+ limit: int = 50,
+ group_id: Optional[str] = None,
+ user_id: Optional[str] = None,
+ ) -> List[Dict[str, Any]]:
"""获取指定会话的历史消息。
Args:
limit: 最大返回条数。
- group_id: 群组标识。
+ group_id: 群组标识(虚拟群聊模式)。
+ user_id: 用户标识(私聊模式)。
Returns:
List[Dict[str, Any]]: 历史消息列表。
"""
- target_group_id = group_id or WEBUI_CHAT_GROUP_ID
- session_id = self._resolve_session_id(target_group_id)
+ session_id = self._resolve_session_id(group_id=group_id, user_id=user_id)
+ if session_id is None:
+ logger.debug("获取聊天历史时缺少 group_id 与 user_id,返回空列表")
+ return []
try:
messages = find_messages(
session_id=session_id,
@@ -133,30 +237,54 @@ class ChatHistoryManager:
limit_mode="latest",
filter_command=False,
)
- result = [self._message_to_dict(msg, target_group_id) for msg in messages]
- logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
+ # Build a message_id -> SessionMessage index for backfilling the sender/content of reply segments in history
+ # (the DB usually stores only target_message_id; target_message_content/sender_* are missing).
+ message_index: Dict[str, SessionMessage] = {}
+ for m in messages:
+ mid = getattr(m, "message_id", None)
+ if mid:
+ message_index[str(mid)] = m
+
+ result: List[Dict[str, Any]] = []
+ for msg in messages:
+ item = self._message_to_dict(msg, group_id)
+ segments = item.get("segments")
+ if segments:
+ self._enrich_reply_segments(segments, message_index, session_id)
+ result.append(item)
+ logger.debug(
+ f"从数据库加载了 {len(result)} 条聊天记录 (group_id={group_id}, user_id={user_id})"
+ )
return result
except Exception as exc:
logger.error(f"从数据库加载聊天记录失败: {exc}")
return []
- def clear_history(self, group_id: Optional[str] = None) -> int:
+ def clear_history(
+ self,
+ group_id: Optional[str] = None,
+ user_id: Optional[str] = None,
+ ) -> int:
"""清空指定会话的历史消息。
Args:
- group_id: 群组标识。
+ group_id: 群组标识(虚拟群聊模式)。
+ user_id: 用户标识(私聊模式)。
Returns:
int: 被删除的消息数量。
"""
- target_group_id = group_id or WEBUI_CHAT_GROUP_ID
- session_id = self._resolve_session_id(target_group_id)
+ session_id = self._resolve_session_id(group_id=group_id, user_id=user_id)
+ if session_id is None:
+ return 0
try:
with get_db_session() as session:
statement = delete(Messages).where(col(Messages.session_id) == session_id)
result = session.exec(statement)
deleted = result.rowcount or 0
- logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
+ logger.info(
+ f"已清空 {deleted} 条聊天记录 (group_id={group_id}, user_id={user_id})"
+ )
return deleted
except Exception as exc:
logger.error(f"清空聊天记录失败: {exc}")
@@ -174,30 +302,30 @@ class ChatConnectionManager:
self.group_sessions: Dict[str, Set[str]] = {}
self.user_sessions: Dict[str, Set[str]] = {}
- def _bind_group(self, session_id: str, group_id: str) -> None:
- """为会话绑定群组索引。
+ def _bind_channel(self, session_id: str, channel_key: str) -> None:
+ """为会话绑定逻辑频道索引。
Args:
session_id: 内部会话 ID。
- group_id: 群组标识。
+ channel_key: 频道键(``group:`` 或 ``private:``)。
"""
- group_session_ids = self.group_sessions.setdefault(group_id, set())
- group_session_ids.add(session_id)
+ channel_session_ids = self.group_sessions.setdefault(channel_key, set())
+ channel_session_ids.add(session_id)
- def _unbind_group(self, session_id: str, group_id: str) -> None:
- """移除会话与群组的索引关系。
+ def _unbind_channel(self, session_id: str, channel_key: str) -> None:
+ """移除会话与逻辑频道的索引关系。
Args:
session_id: 内部会话 ID。
- group_id: 群组标识。
+ channel_key: 频道键。
"""
- group_session_ids = self.group_sessions.get(group_id)
- if group_session_ids is None:
+ channel_session_ids = self.group_sessions.get(channel_key)
+ if channel_session_ids is None:
return
- group_session_ids.discard(session_id)
- if not group_session_ids:
- del self.group_sessions[group_id]
+ channel_session_ids.discard(session_id)
+ if not channel_session_ids:
+ del self.group_sessions[channel_key]
async def connect(
self,
@@ -220,18 +348,39 @@ class ChatConnectionManager:
virtual_config: Current virtual identity configuration.
sender: Async callback used to send messages to the frontend.
"""
+ channel_key = compute_channel_key(virtual_config, user_id)
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
+ if existing_session_id is not None and existing_session_id == session_id:
+ # The same physical connection + frontend session was opened again (typically React StrictMode
+ # double-mounting or a failed client-side debounce): reuse the existing session and only refresh
+ # its mutable fields, avoiding noisy logs from repeated disconnect/reconnect cycles.
+ existing = self.active_connections.get(existing_session_id)
+ if existing is not None:
+ if existing.channel_key != channel_key:
+ self._unbind_channel(existing_session_id, existing.channel_key)
+ self._bind_channel(existing_session_id, channel_key)
+ existing.channel_key = channel_key
+ existing.user_id = user_id
+ existing.user_name = user_name
+ existing.virtual_config = virtual_config
+ existing.sender = sender
+ logger.debug(
+ "WebUI 聊天会话复用: session=%s, connection=%s, client_session=%s, channel=%s",
+ session_id,
+ connection_id,
+ client_session_id,
+ channel_key,
+ )
+ return
if existing_session_id is not None:
self.disconnect(existing_session_id)
- active_group_id = get_current_group_id(virtual_config)
session_connection = ChatSessionConnection(
session_id=session_id,
connection_id=connection_id,
client_session_id=client_session_id,
user_id=user_id,
user_name=user_name,
- active_group_id=active_group_id,
+ channel_key=channel_key,
virtual_config=virtual_config,
sender=sender,
)
@@ -240,14 +389,14 @@ class ChatConnectionManager:
self.client_sessions[(connection_id, client_session_id)] = session_id
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
self.user_sessions.setdefault(user_id, set()).add(session_id)
- self._bind_group(session_id, active_group_id)
+ self._bind_channel(session_id, channel_key)
logger.info(
- "WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
+ "WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, channel=%s",
session_id,
connection_id,
client_session_id,
user_id,
- active_group_id,
+ channel_key,
)
def disconnect(self, session_id: str) -> None:
@@ -261,7 +410,7 @@ class ChatConnectionManager:
return
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
- self._unbind_group(session_id, session_connection.active_group_id)
+ self._unbind_channel(session_id, session_connection.channel_key)
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
if connection_session_ids is not None:
@@ -327,11 +476,11 @@ class ChatConnectionManager:
if session_connection is None:
return
- next_group_id = get_current_group_id(virtual_config)
- if next_group_id != session_connection.active_group_id:
- self._unbind_group(session_id, session_connection.active_group_id)
- self._bind_group(session_id, next_group_id)
- session_connection.active_group_id = next_group_id
+ next_channel_key = compute_channel_key(virtual_config, session_connection.user_id)
+ if next_channel_key != session_connection.channel_key:
+ self._unbind_channel(session_id, session_connection.channel_key)
+ self._bind_channel(session_id, next_channel_key)
+ session_connection.channel_key = next_channel_key
session_connection.user_name = user_name
session_connection.virtual_config = virtual_config
@@ -361,16 +510,40 @@ class ChatConnectionManager:
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
- async def broadcast_to_group(self, group_id: str, message: Dict[str, Any]) -> None:
- """向指定群组下的全部逻辑会话广播消息。
+ async def broadcast_to_channel(self, channel_key: str, message: Dict[str, Any]) -> None:
+ """向指定逻辑频道下的全部会话广播消息。
Args:
- group_id: 群组标识。
+ channel_key: 频道键(``group:`` 或 ``private:``)。
message: 待广播的消息内容。
"""
- for session_id in list(self.group_sessions.get(group_id, set())):
+ for session_id in list(self.group_sessions.get(channel_key, set())):
await self.send_message(session_id, message)
+ async def broadcast_to_group(
+ self,
+ group_id: Optional[str],
+ message: Dict[str, Any],
+ *,
+ user_id: Optional[str] = None,
+ ) -> None:
+ """向指定群组或私聊会话广播消息。
+
+ 当 ``group_id`` 非空时按群聊广播;否则按 ``user_id`` 私聊广播。
+
+ Args:
+ group_id: 群组标识;为空时使用 ``user_id``。
+ message: 待广播的消息内容。
+ user_id: 私聊接收方用户 ID。
+ """
+ if group_id:
+ channel_key = f"group:{group_id}"
+ elif user_id:
+ channel_key = f"private:{user_id}"
+ else:
+ return
+ await self.broadcast_to_channel(channel_key, message)
+
chat_history = ChatHistoryManager()
chat_manager = ChatConnectionManager()
@@ -388,6 +561,24 @@ def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) ->
return bool(virtual_config and virtual_config.enabled)
+def compute_channel_key(virtual_config: Optional[VirtualIdentityConfig], user_id: str) -> str:
+ """计算当前会话的逻辑频道键。
+
+ 虚拟身份启用时使用虚拟群聊 ID,否则使用当前 WebUI 用户 ID 作为私聊频道。
+
+ Args:
+ virtual_config: 虚拟身份配置。
+ user_id: 当前 WebUI 用户 ID。
+
+ Returns:
+ str: 频道键,格式为 ``group:`` 或 ``private:``。
+ """
+ if is_virtual_mode_enabled(virtual_config):
+ assert virtual_config is not None
+ return f"group:{virtual_config.group_id}"
+ return f"private:{user_id}"
+
+
def normalize_webui_user_id(user_id: Optional[str]) -> str:
"""标准化 WebUI 用户 ID。
@@ -500,6 +691,8 @@ def build_session_info_message(
Returns:
Dict[str, Any]: Session info message.
"""
+ # bot_qq lets the frontend fetch the bot avatar from QQ's public avatar endpoint (qq_account == 0 means not configured, so the field is not pushed).
+ bot_qq_account = int(getattr(global_config.bot, "qq_account", 0) or 0)
session_info_data: Dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
@@ -507,6 +700,8 @@ def build_session_info_message(
"user_name": user_name,
"bot_name": global_config.bot.nickname,
}
+ if bot_qq_account > 0:
+ session_info_data["bot_qq"] = str(bot_qq_account)
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
@@ -529,7 +724,7 @@ def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig])
virtual_config: Virtual identity configuration.
Returns:
- Optional[str]: The corresponding group ID when virtual identity is enabled.
+ Optional[str]: The corresponding group ID when virtual identity is enabled; otherwise ``None``, meaning private chat is used.
"""
if is_virtual_mode_enabled(virtual_config):
assert virtual_config is not None
@@ -537,16 +732,16 @@ def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig])
return None
-def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> str:
+def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
"""获取当前会话的有效群组 ID。
Args:
virtual_config: 虚拟身份配置。
Returns:
- str: 当前会话应使用的群组 ID。
+ Optional[str]: 虚拟身份启用时返回对应群组 ID;否则返回 ``None``(默认私聊模式)。
"""
- return get_active_history_group_id(virtual_config) or WEBUI_CHAT_GROUP_ID
+ return get_active_history_group_id(virtual_config)
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
@@ -611,7 +806,12 @@ async def send_initial_chat_state(
)
history_group_id = get_active_history_group_id(virtual_config)
- history = chat_history.get_history(50, history_group_id)
+ history_user_id = None if history_group_id else user_id
+ history = chat_history.get_history(
+ 50,
+ group_id=history_group_id,
+ user_id=history_user_id,
+ )
await chat_manager.send_message(
session_id,
{
@@ -679,37 +879,42 @@ def create_message_data(
if virtual_config and virtual_config.enabled:
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
- group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
- group_name = virtual_config.group_name or "WebUI虚拟群聊"
+ group_id: Optional[str] = (
+ virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
+ )
+ group_name: Optional[str] = virtual_config.group_name or "WebUI虚拟群聊"
actual_user_id = virtual_config.user_id or user_id
- actual_user_name = virtual_config.user_nickname or user_name
+ actual_user_nickname = virtual_config.user_nickname or user_name
else:
platform = WEBUI_CHAT_PLATFORM
- group_id = WEBUI_CHAT_GROUP_ID
- group_name = "WebUI本地聊天室"
+ group_id = None
+ group_name = None
actual_user_id = user_id
- actual_user_name = user_name
+ actual_user_nickname = user_name
+
+ message_info: Dict[str, Any] = {
+ "platform": platform,
+ "message_id": message_id,
+ "time": time.time(),
+ "user_info": {
+ "user_id": actual_user_id,
+ "user_nickname": actual_user_nickname,
+ "user_cardname": actual_user_nickname,
+ "platform": platform,
+ },
+ "additional_config": {
+ "at_bot": is_at_bot,
+ },
+ }
+ if group_id is not None:
+ message_info["group_info"] = {
+ "group_id": group_id,
+ "group_name": group_name,
+ "platform": platform,
+ }
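+ # In private-chat mode (group_id is None) the payload deliberately carries no group_info,
+ # so downstream processing treats it as a direct message instead of a virtual group chat
+ # (sketch of the intended behavior based on the branches above).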
return {
- "message_info": {
- "platform": platform,
- "message_id": message_id,
- "time": time.time(),
- "group_info": {
- "group_id": group_id,
- "group_name": group_name,
- "platform": platform,
- },
- "user_info": {
- "user_id": actual_user_id,
- "user_nickname": actual_user_name,
- "user_cardname": actual_user_name,
- "platform": platform,
- },
- "additional_config": {
- "at_bot": is_at_bot,
- },
- },
+ "message_info": message_info,
"message_segment": {
"type": "seglist",
"data": [
@@ -717,10 +922,6 @@ def create_message_data(
"type": "text",
"data": content,
},
- {
- "type": "mention_bot",
- "data": "1.0",
- },
],
},
"raw_message": content,
@@ -776,6 +977,7 @@ async def handle_chat_message(
},
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
},
+ user_id=normalized_user_id,
)
message_data = create_message_data(
@@ -788,13 +990,21 @@ async def handle_chat_message(
)
try:
- await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
+ await chat_manager.broadcast_to_group(
+ target_group_id,
+ {"type": "typing", "is_typing": True},
+ user_id=normalized_user_id,
+ )
await chat_bot.message_process(message_data)
except Exception as exc:
logger.error(f"处理消息时出错: {exc}")
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
finally:
- await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
+ await chat_manager.broadcast_to_group(
+ target_group_id,
+ {"type": "typing", "is_typing": False},
+ user_id=normalized_user_id,
+ )
return next_user_name
@@ -915,11 +1125,12 @@ async def enable_virtual_identity(
return None
-async def disable_virtual_identity(session_id: str) -> None:
+async def disable_virtual_identity(session_id: str, normalized_user_id: str) -> None:
"""关闭虚拟身份模式。
Args:
session_id: 内部逻辑会话 ID。
+ normalized_user_id: 规范化后的 WebUI 用户 ID,用于加载私聊历史。
"""
await chat_manager.send_message(
session_id,
@@ -933,8 +1144,8 @@ async def disable_virtual_identity(session_id: str) -> None:
session_id,
{
"type": "history",
- "messages": chat_history.get_history(50, WEBUI_CHAT_GROUP_ID),
- "group_id": WEBUI_CHAT_GROUP_ID,
+ "messages": chat_history.get_history(50, user_id=normalized_user_id),
+ "group_id": None,
},
)
await chat_manager.send_message(
@@ -952,6 +1163,7 @@ async def handle_virtual_identity_update(
session_id_prefix: str,
data: Dict[str, Any],
current_virtual_config: Optional[VirtualIdentityConfig],
+ normalized_user_id: str,
) -> Optional[VirtualIdentityConfig]:
"""处理虚拟身份切换请求。
@@ -960,6 +1172,7 @@ async def handle_virtual_identity_update(
session_id_prefix: Session prefix.
data: Data submitted by the frontend.
current_virtual_config: Current virtual identity configuration.
+ normalized_user_id: Normalized WebUI user ID.
Returns:
Optional[VirtualIdentityConfig]: Updated virtual identity configuration.
@@ -969,7 +1182,7 @@ async def handle_virtual_identity_update(
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
return next_config if next_config is not None else current_virtual_config
- await disable_virtual_identity(session_id)
+ await disable_virtual_identity(session_id, normalized_user_id)
return None
@@ -1019,6 +1232,7 @@ async def dispatch_chat_event(
session_id_prefix=session_id_prefix,
data=data,
current_virtual_config=current_virtual_config,
+ normalized_user_id=normalized_user_id,
)
return current_user_name, next_virtual_config