perf:优化webui交互体验,优化统计逻辑,优化log展示 (perf: improve the WebUI interaction experience, the statistics logic, and log output)
@@ -3139,9 +3139,7 @@ class SDKMemoryKernel:
             return {"success": False, "queued": False, "reason": "db_save_failed"}

         logger.debug(
-            "反馈纠错任务入队: query_tool_id=%s due_at=%s",
-            clean_tool_id,
-            due_at.isoformat(),
+            f"反馈纠错任务入队: query_tool_id={clean_tool_id} due_at={due_at.isoformat()}",
         )
         return {
             "success": True,
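
[Editor's note] Throughout this commit, lazy `%s`-style logger calls are rewritten as f-strings. The two are not equivalent in one respect: an f-string builds the final message eagerly, before logging checks the level, while `%s` arguments are only merged into the template when a record is actually emitted (argument evaluation is eager in both styles). A minimal standard-library sketch of the two forms:

```python
import logging

logging.basicConfig(level=logging.INFO)  # DEBUG records are discarded
logger = logging.getLogger("demo")

task_id = 42

# %s style: the template and argument are stored on the record and only
# merged into a string if some handler actually emits it.
logger.debug("task enqueued: id=%s", task_id)

# f-string style (what this commit standardizes on): the final message is
# built before logging checks the level. Easier to read and grep, at the
# cost of always paying the (usually tiny) formatting step.
logger.debug(f"task enqueued: id={task_id}")
```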
@@ -162,20 +162,16 @@ class MetadataStore:
     def _run_runtime_auto_migration(self, *, current_version: int) -> None:
         """对 1.0 之后的已版本化库执行轻量自动迁移。"""
         logger.info(
-            "检测到 metadata schema 需要运行时自动迁移: current=%s, target=%s",
-            current_version,
-            SCHEMA_VERSION,
+            f"检测到 metadata schema 需要运行时自动迁移: current={current_version}, target={SCHEMA_VERSION}",
         )
         self._migrate_schema()
         alias_result = self.rebuild_relation_hash_aliases()
         knowledge_type_result = self.normalize_paragraph_knowledge_types()
         self.set_schema_version(SCHEMA_VERSION)
         logger.info(
-            "metadata schema 运行时自动迁移完成: %s -> %s, alias_inserted=%s, knowledge_normalized=%s",
-            current_version,
-            SCHEMA_VERSION,
-            int(alias_result.get("inserted", 0) or 0),
-            int(knowledge_type_result.get("normalized", 0) or 0),
+            f"metadata schema 运行时自动迁移完成: {current_version} -> {SCHEMA_VERSION}, "
+            f"alias_inserted={int(alias_result.get('inserted', 0) or 0)}, "
+            f"knowledge_normalized={int(knowledge_type_result.get('normalized', 0) or 0)}",
         )

     def _ensure_memory_feedback_task_columns(self, cursor: sqlite3.Cursor) -> None:
@@ -3126,7 +3126,7 @@ class ImportTaskManager:
     ) -> None:
         content = str(processed.chunk.text or "")
         if is_probable_hash_token(content):
-            logger.warning("跳过疑似哈希段落写入: source=%s preview=%s", self._source_label(file_record), content[:32])
+            logger.warning(f"跳过疑似哈希段落写入: source={self._source_label(file_record)} preview={content[:32]}")
             return
         para_hash = self.plugin.metadata_store.add_paragraph(
             content=content,
@@ -3208,10 +3208,7 @@ class ImportTaskManager:
             return ""
         if any(is_probable_hash_token(token) for token in (subject_token, predicate_token, object_token)):
             logger.warning(
-                "跳过疑似哈希关系写入: %s | %s | %s",
-                subject_token[:24],
-                predicate_token[:24],
-                object_token[:24],
+                f"跳过疑似哈希关系写入: {subject_token[:24]} | {predicate_token[:24]} | {object_token[:24]}",
             )
             return ""

@@ -309,7 +309,7 @@ class AMemorixHostService:
         try:
             config_model = _get_config_manager().get_global_config().a_memorix
         except Exception as exc:
-            logger.warning("读取 A_Memorix 主配置失败,使用默认值: %s", exc)
+            logger.warning(f"读取 A_Memorix 主配置失败,使用默认值: {exc}")
             defaults = self._build_default_config()
             self._config_cache = defaults
             return dict(defaults)
@@ -8,13 +8,21 @@ from typing import cast

 from typing_extensions import TypedDict

-from sqlmodel import col, func, select
+from sqlmodel import col, select

 from src.common.logger import get_logger
 from src.common.database.database import get_db_session
-from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
+from src.common.database.database_model import OnlineTime
 from src.manager.async_task_manager import AsyncTask
 from src.manager.local_store_manager import local_storage
+from src.services.statistics_service import (
+    fetch_messages_since,
+    fetch_model_usage_since,
+    fetch_online_time_since,
+    fetch_tool_records_since,
+    get_earliest_statistics_time,
+    refresh_dashboard_statistics_cache,
+)

 logger = get_logger("maibot_statistic")

@@ -249,7 +257,7 @@ class StatisticOutputTask(AsyncTask):
             deploy_time = datetime(2000, 1, 1)
             local_storage["deploy_time"] = now.timestamp()

-        self.all_time_start_time = self._get_all_time_start_time(deploy_time)
+        self.all_time_start_time = get_earliest_statistics_time(deploy_time)

         self.stat_period: list[tuple[str, timedelta, str]] = [
             ("all_time", now - self.all_time_start_time, "自部署以来"),  # 必须保留"all_time"
@@ -265,23 +273,6 @@ class StatisticOutputTask(AsyncTask):
             统计时间段 [(统计名称, 统计时间段, 统计描述), ...]
         """

-    @staticmethod
-    def _get_all_time_start_time(fallback_time: datetime) -> datetime:
-        """获取统计数据的最早时间,避免全量统计展示窗口漏掉历史数据。"""
-        try:
-            with get_db_session(auto_commit=False) as session:
-                start_times = [
-                    session.exec(select(func.min(ModelUsage.timestamp))).first(),
-                    session.exec(select(func.min(Messages.timestamp))).first(),
-                    session.exec(select(func.min(OnlineTime.start_timestamp))).first(),
-                    session.exec(select(func.min(ToolRecord.timestamp))).first(),
-                ]
-            valid_start_times = [item for item in start_times if isinstance(item, datetime)]
-            if valid_start_times:
-                return min(valid_start_times)
-        except Exception as e:
-            logger.warning(f"获取全量统计起始时间失败,将使用部署时间:{e}")
-        return fallback_time

     def _statistic_console_output(self, stats: StatPeriodMapping, now: datetime) -> None:
         """
@@ -324,6 +315,10 @@ class StatisticOutputTask(AsyncTask):

         # 等待数据收集完成
         stats = await collect_task
+        try:
+            await refresh_dashboard_statistics_cache()
+        except Exception as e:
+            logger.warning(f"刷新 WebUI 统计缓存失败,将继续生成 HTML 报告: {e}")
         logger.info("统计数据收集完成")

         # 并行执行控制台输出和HTML报告生成
@@ -354,6 +349,10 @@ class StatisticOutputTask(AsyncTask):
         logger.info("正在后台收集统计数据...")

         stats = await loop.run_in_executor(executor, self._collect_all_statistics, now)
+        try:
+            await refresh_dashboard_statistics_cache()
+        except Exception as e:
+            logger.warning(f"刷新 WebUI 统计缓存失败,将继续生成 HTML 报告: {e}")
         logger.info("统计数据收集完成")

         # 创建并发的输出任务
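
[Editor's note] The two hunks above insert the same best-effort pattern: the WebUI statistics cache is refreshed right after collection, and any failure is logged and swallowed so HTML report generation still proceeds. A runnable sketch of that pattern (the names below are stand-ins, not the project's API):

```python
import asyncio
import logging

logger = logging.getLogger("demo")

async def refresh_cache() -> None:
    # Stand-in for refresh_dashboard_statistics_cache(); assumed flaky here.
    raise RuntimeError("cache backend unavailable")

async def collect_and_report() -> None:
    stats = {"requests": 10}  # stand-in for the collected statistics
    try:
        await refresh_cache()  # best effort: a failure must not abort the report
    except Exception as e:
        logger.warning("cache refresh failed, continuing with report: %s", e)
    print("report generated from", stats)

asyncio.run(collect_and_report())
```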
@@ -453,33 +452,6 @@ class StatisticOutputTask(AsyncTask):
                 counter = cast(defaultdict[str, list[float]], stats_period[key])
                 counter[subkey].append(value)

-    @staticmethod
-    def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
-        with get_db_session(auto_commit=False) as session:
-            statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
-            records = session.exec(statement).all()
-            return [(record.start_timestamp, record.end_timestamp) for record in records]
-
-    @staticmethod
-    def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
-        with get_db_session(auto_commit=False) as session:
-            statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
-            records = session.exec(statement).all()
-            return [
-                {
-                    "timestamp": record.timestamp,
-                    "request_type": record.request_type,
-                    "model_api_provider_name": record.model_api_provider_name,
-                    "model_assign_name": record.model_assign_name,
-                    "model_name": record.model_name,
-                    "prompt_tokens": record.prompt_tokens,
-                    "completion_tokens": record.completion_tokens,
-                    "cost": record.cost,
-                    "time_cost": record.time_cost,
-                }
-                for record in records
-            ]

     @staticmethod
     def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> StatPeriodMapping:
         """
@@ -500,7 +472,7 @@ class StatisticOutputTask(AsyncTask):
         # 以最早的时间戳为起始时间获取记录
         # Assuming LLMUsage.timestamp is a DateTimeField
         query_start_time = collect_period[-1][1]
-        records = StatisticOutputTask._fetch_model_usage_since(query_start_time)
+        records = fetch_model_usage_since(query_start_time)
         for record in records:
             record_timestamp = cast(datetime, record["timestamp"])
             for idx, (_, period_start) in enumerate(collect_period):
@@ -647,7 +619,7 @@ class StatisticOutputTask(AsyncTask):

         query_start_time = collect_period[-1][1]
         # Assuming OnlineTime.end_timestamp is a DateTimeField
-        records = StatisticOutputTask._fetch_online_time_since(query_start_time)
+        records = fetch_online_time_since(query_start_time)
         for record_start_timestamp, record_end_timestamp in records:
             for idx, (_, period_boundary_start) in enumerate(collect_period):
                 if record_end_timestamp >= period_boundary_start:
@@ -688,9 +660,7 @@ class StatisticOutputTask(AsyncTask):
         }

         query_start_timestamp = collect_period[-1][1]
-        with get_db_session(auto_commit=False) as session:
-            statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp)
-            messages = session.exec(statement).all()
+        messages = fetch_messages_since(query_start_timestamp)
         for message in messages:
             message_time_ts = message.timestamp.timestamp()

@@ -737,9 +707,7 @@ class StatisticOutputTask(AsyncTask):
         # 使用 ToolRecord 中的 reply 工具次数作为回复数基准
         try:
             tool_query_start_timestamp = collect_period[-1][1]
-            with get_db_session(auto_commit=False) as session:
-                statement = select(ToolRecord).where(col(ToolRecord.timestamp) >= tool_query_start_timestamp)
-                tool_records = session.exec(statement).all()
+            tool_records = fetch_tool_records_since(tool_query_start_timestamp)
             for tool_record in tool_records:
                 if tool_record.tool_name != "reply":
                     continue
@@ -865,7 +833,7 @@ class StatisticOutputTask(AsyncTask):
             "module": defaultdict(lambda: {"count": 0.0, "sum": 0.0, "sum_sq": 0.0}),
         }

-        records = self._fetch_model_usage_since(self.all_time_start_time)
+        records = fetch_model_usage_since(self.all_time_start_time)
         for record in records:
             time_cost = cast(float | None, record["time_cost"]) or 0.0
             if time_cost <= 0:
@@ -1806,7 +1774,7 @@ class StatisticOutputTask(AsyncTask):

         # 查询LLM使用记录
         query_start_time = start_time
-        records = StatisticOutputTask._fetch_model_usage_since(query_start_time)
+        records = fetch_model_usage_since(query_start_time)
         for record in records:
             record_time = cast(datetime, record["timestamp"])

@@ -1835,9 +1803,7 @@ class StatisticOutputTask(AsyncTask):

         # 查询消息记录
         query_start_timestamp = start_time.timestamp()
-        with get_db_session(auto_commit=False) as session:
-            statement = select(Messages).where(col(Messages.timestamp) >= start_time)
-            messages = session.exec(statement).all()
+        messages = fetch_messages_since(start_time)
         for message in messages:
             message_time_ts = message.timestamp.timestamp()

@@ -2197,7 +2163,7 @@ class StatisticOutputTask(AsyncTask):

         # 查询LLM使用记录
         query_start_time = start_time
-        records = StatisticOutputTask._fetch_model_usage_since(query_start_time)
+        records = fetch_model_usage_since(query_start_time)
         for record in records:
             record_time = cast(datetime, record["timestamp"])

@@ -2216,9 +2182,7 @@ class StatisticOutputTask(AsyncTask):

         # 查询消息记录
         query_start_timestamp = start_time.timestamp()
-        with get_db_session(auto_commit=False) as session:
-            statement = select(Messages).where(col(Messages.timestamp) >= start_time)
-            messages = session.exec(statement).all()
+        messages = fetch_messages_since(start_time)
         for message in messages:
             message_time_ts = message.timestamp.timestamp()

@@ -2232,7 +2196,7 @@ class StatisticOutputTask(AsyncTask):
                 total_replies[interval_index] += 1

         # 查询在线时间记录
-        records = StatisticOutputTask._fetch_online_time_since(start_time)
+        records = fetch_online_time_since(start_time)
         for record_start, record_end in records:
             # 找到记录覆盖的所有时间间隔
             for idx, time_point in enumerate(time_points):
@@ -2491,6 +2455,10 @@ class AsyncStatisticOutputTask(AsyncTask):

         # 数据收集任务
         stats = await loop.run_in_executor(executor, self._statistic_task._collect_all_statistics, now)
+        try:
+            await refresh_dashboard_statistics_cache()
+        except Exception as e:
+            logger.warning(f"刷新 WebUI 统计缓存失败,将继续生成 HTML 报告: {e}")
         logger.info("统计数据收集完成")

         # 创建并发的输出任务
@@ -134,7 +134,7 @@ def _read_metadata_file(metadata_path: Path) -> dict[str, Any]:
         else:
             metadata = parse_toml(metadata_path.read_text(encoding="utf-8"))
     except Exception as exc:
-        logger.warning("读取 Prompt 元信息文件 %s 失败:%s", metadata_path, exc)
+        logger.warning(f"读取 Prompt 元信息文件 {metadata_path} 失败:{exc}")
         return {}

     return dict(metadata) if isinstance(metadata, dict) else {}
@@ -55,7 +55,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
 MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
 LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
 A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
-MMC_VERSION: str = "1.0.0-pre.11"
+MMC_VERSION: str = "1.0.0-pre.13"
 CONFIG_VERSION: str = "8.10.9"
 MODEL_CONFIG_VERSION: str = "1.15.3"

@@ -93,7 +93,7 @@ class BotConfig(ConfigBase):
             "x-icon": "user-circle",
         },
     )
-    """机器人昵称"""
+    """"""

    alias_names: list[str] = Field(
        default_factory=lambda: [],
@@ -130,6 +130,7 @@ class PersonalityConfig(ConfigBase):
             "x-icon": "user-circle",
             "x-textarea-min-height": 40,
             "x-textarea-rows": 1,
+            "x-description-display": "icon",
         },
     )
     """人格,建议200字以内,描述人格特质和身份特征;可以写完整设定。要求第二人称"""
@@ -146,6 +147,7 @@ class PersonalityConfig(ConfigBase):
             "x-icon": "message-square",
             "x-textarea-min-height": 40,
             "x-textarea-rows": 1,
+            "x-description-display": "icon",
         },
     )
     """默认表达风格,描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容,建议1-2行"""
@@ -514,6 +516,7 @@ class MessageReceiveConfig(ConfigBase):
         json_schema_extra={
             "x-widget": "input",
             "x-icon": "image",
+            "advanced": True,
         },
     )
     """
@@ -419,9 +419,7 @@ class ToolRegistry:
             return await provider.invoke(invocation, context)
         except Exception as exc:
             logger.exception(
-                "工具调用异常: tool=%s provider=%s",
-                invocation.tool_name,
-                getattr(provider, "provider_name", ""),
+                f"工具调用异常: tool={invocation.tool_name} provider={getattr(provider, 'provider_name', '')}",
             )
             error_message = str(exc).strip()
             if error_message:
@@ -768,11 +768,8 @@ class EmojiManager:
         selected_emoji, similarity = random.choice(top_emojis)
         self.update_emoji_usage(selected_emoji)
         logger.info(
-            "[获取表情包] 为[%s]选中表情包: %s(%s),相似度: %.4f",
-            emotion_label,
-            selected_emoji.file_name,
-            ",".join(_get_emoji_emotions(selected_emoji)),
-            similarity,
+            f"[获取表情包] 为[{emotion_label}]选中表情包: "
+            f"{selected_emoji.file_name}({','.join(_get_emoji_emotions(selected_emoji))}),相似度: {similarity:.4f}",
         )
         return selected_emoji

@@ -81,6 +81,7 @@ class MainSystem:

         await config_manager.start_file_watcher()
         a_memorix_host_service.register_config_reload_callback()
+        prompt_manager.load_prompts()

         # 添加在线时间统计任务
         await async_task_manager.add_task(OnlineTimeRecordTask())
@@ -121,8 +122,6 @@ class MainSystem:
         self.app.register_message_handler(chat_bot.message_process)
         self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)

-        prompt_manager.load_prompts()
-
         # 触发 ON_START 事件
         from src.core.event_bus import event_bus
         from src.core.types import EventType
@@ -161,10 +160,8 @@ async def main() -> None:
     """主函数"""
     system = MainSystem()
     try:
-        await asyncio.gather(
-            system.initialize(),
-            system.schedule_tasks(),
-        )
+        await system.initialize()
+        await system.schedule_tasks()
     finally:
         disable_stage_status_board()
         emoji_manager.shutdown()
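
[Editor's note] This hunk replaces `asyncio.gather(...)` with two sequential awaits, presumably so that `initialize()` is guaranteed to finish before `schedule_tasks()` starts; `gather` runs both coroutines concurrently. A minimal sketch of the difference:

```python
import asyncio

async def initialize() -> None:
    await asyncio.sleep(0.1)
    print("initialized")

async def schedule_tasks() -> None:
    print("scheduling tasks")

async def main() -> None:
    # gather() starts both coroutines at once, so schedule_tasks() may run
    # before initialize() has finished:
    #   await asyncio.gather(initialize(), schedule_tasks())
    # Awaiting sequentially (as the hunk above does) guarantees ordering:
    await initialize()
    await schedule_tasks()

asyncio.run(main())
```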
@@ -191,7 +191,7 @@ class PluginMessageUtils:
                 return ""
             return base64.b64encode(image_path.read_bytes()).decode("utf-8")
         except Exception as exc:
-            logger.debug("通过 hash 加载历史媒体失败: type=%s hash=%s error=%s", image_type, binary_hash, exc)
+            logger.debug(f"通过 hash 加载历史媒体失败: type={image_type} hash={binary_hash} error={exc}")
             return ""

     @staticmethod
@@ -48,7 +48,7 @@ class PersonFactWritebackService:
         except asyncio.CancelledError:
             pass
         except Exception as exc:
-            logger.warning("关闭人物事实写回 worker 失败: %s", exc)
+            logger.warning(f"关闭人物事实写回 worker 失败: {exc}")

     async def enqueue(self, message: Any) -> None:
         if not bool(global_config.a_memorix.integration.person_fact_writeback_enabled):
@@ -67,7 +67,7 @@ class PersonFactWritebackService:
                 try:
                     await self._handle_message(message)
                 except Exception as exc:
-                    logger.warning("人物事实写回处理失败: %s", exc, exc_info=True)
+                    logger.warning(f"人物事实写回处理失败: {exc}", exc_info=True)
                 finally:
                     self._queue.task_done()
         except asyncio.CancelledError:
@@ -126,7 +126,7 @@ class PersonFactWritebackService:
         try:
             replies = find_messages(message_id=reply_to, limit=1)
         except Exception as exc:
-            logger.debug("查询 reply_to 目标失败: %s", exc)
+            logger.debug(f"查询 reply_to 目标失败: {exc}")
             return None
         if not replies:
             return None
@@ -166,7 +166,7 @@ class PersonFactWritebackService:
         try:
             response_result = await self._extractor.generate_response(prompt)
         except Exception as exc:
-            logger.debug("人物事实提取模型调用失败: %s", exc)
+            logger.debug(f"人物事实提取模型调用失败: {exc}")
             return []
         return self._parse_fact_list(response_result.response)

@@ -248,7 +248,7 @@ class ChatSummaryWritebackService:
         except asyncio.CancelledError:
             pass
         except Exception as exc:
-            logger.warning("关闭聊天摘要写回 worker 失败: %s", exc)
+            logger.warning(f"关闭聊天摘要写回 worker 失败: {exc}")

     async def enqueue(self, message: Any) -> None:
         if not bool(global_config.a_memorix.integration.chat_summary_writeback_enabled):
@@ -267,7 +267,7 @@ class ChatSummaryWritebackService:
                 try:
                     await self._handle_message(message)
                 except Exception as exc:
-                    logger.warning("聊天摘要写回处理失败: %s", exc, exc_info=True)
+                    logger.warning(f"聊天摘要写回处理失败: {exc}", exc_info=True)
                 finally:
                     self._queue.task_done()
         except asyncio.CancelledError:
@@ -319,21 +319,16 @@ class ChatSummaryWritebackService:
         )
         if not getattr(result, "success", False):
             logger.warning(
-                "聊天摘要自动写回失败: session_id=%s detail=%s",
-                session_id,
-                getattr(result, "detail", ""),
+                f"聊天摘要自动写回失败: session_id={session_id} detail={getattr(result, 'detail', '')}",
             )
             return

         state.last_trigger_message_count = total_message_count
         state.last_trigger_time = time.time()
         logger.info(
-            "聊天摘要自动写回成功: session_id=%s trigger=%s total_messages=%s context_length=%s detail=%s",
-            session_id,
-            "message_threshold",
-            total_message_count,
-            context_length,
-            getattr(result, "detail", ""),
+            f"聊天摘要自动写回成功: session_id={session_id} trigger=message_threshold "
+            f"total_messages={total_message_count} context_length={context_length} "
+            f"detail={getattr(result, 'detail', '')}",
         )

     async def _load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
@@ -363,7 +358,7 @@ class ChatSummaryWritebackService:
             # 至少避免重启后立刻重复写入一条相近摘要。
             return total_message_count
         except Exception as exc:
-            logger.debug("恢复聊天摘要写回游标失败: session_id=%s error=%s", session_id, exc)
+            logger.debug(f"恢复聊天摘要写回游标失败: session_id={session_id} error={exc}")
             return 0

     @staticmethod
@@ -230,7 +230,7 @@ class MemoryService:
             )
             return self._coerce_search_result(payload)
         except Exception as exc:
-            logger.warning("长期记忆搜索失败: %s", exc)
+            logger.warning(f"长期记忆搜索失败: {exc}")
             return MemorySearchResult(success=False, error=str(exc))

     async def enqueue_feedback_task(
@@ -253,7 +253,7 @@ class MemoryService:
                 timeout_ms=10000,
             )
         except Exception as exc:
-            logger.warning("反馈纠错任务入队失败: %s", exc)
+            logger.warning(f"反馈纠错任务入队失败: {exc}")
             return {"success": False, "queued": False, "reason": str(exc)}
         return payload if isinstance(payload, dict) else {"success": False, "queued": False, "reason": "invalid_payload"}

@@ -291,7 +291,7 @@ class MemoryService:
             )
             return self._coerce_write_result(payload)
         except Exception as exc:
-            logger.warning("长期记忆写入摘要失败: %s", exc)
+            logger.warning(f"长期记忆写入摘要失败: {exc}")
             return MemoryWriteResult(success=False, detail=str(exc))

     async def ingest_text(
@@ -338,7 +338,7 @@ class MemoryService:
             )
             return self._coerce_write_result(payload)
         except Exception as exc:
-            logger.warning("长期记忆写入文本失败: %s", exc)
+            logger.warning(f"长期记忆写入文本失败: {exc}")
             return MemoryWriteResult(success=False, detail=str(exc))

     async def get_person_profile(self, person_id: str, *, chat_id: str = "", limit: int = 10) -> PersonProfileResult:
@@ -352,7 +352,7 @@ class MemoryService:
             )
             return self._coerce_profile_result(payload)
         except Exception as exc:
-            logger.warning("获取人物画像失败: %s", exc)
+            logger.warning(f"获取人物画像失败: {exc}")
             return PersonProfileResult()

     async def maintain_memory(
@@ -373,7 +373,7 @@ class MemoryService:
                 return MemoryWriteResult(success=False, detail="invalid_payload")
             return MemoryWriteResult(success=bool(payload.get("success")), detail=str(payload.get("detail", "") or ""))
         except Exception as exc:
-            logger.warning("记忆维护失败: %s", exc)
+            logger.warning(f"记忆维护失败: {exc}")
             return MemoryWriteResult(success=False, detail=str(exc))

     async def memory_stats(self) -> Dict[str, Any]:
@@ -381,77 +381,77 @@ class MemoryService:
             payload = await self._invoke("memory_stats", {})
             return payload if isinstance(payload, dict) else {}
         except Exception as exc:
-            logger.warning("获取记忆统计失败: %s", exc)
+            logger.warning(f"获取记忆统计失败: {exc}")
             return {}

     async def graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_graph_admin", action=action, **kwargs)
         except Exception as exc:
-            logger.warning("图谱管理调用失败: %s", exc)
+            logger.warning(f"图谱管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_source_admin", action=action, **kwargs)
         except Exception as exc:
-            logger.warning("来源管理调用失败: %s", exc)
+            logger.warning(f"来源管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_episode_admin", action=action, **kwargs)
         except Exception as exc:
-            logger.warning("Episode 管理调用失败: %s", exc)
+            logger.warning(f"Episode 管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_profile_admin", action=action, **kwargs)
         except Exception as exc:
-            logger.warning("画像管理调用失败: %s", exc)
+            logger.warning(f"画像管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def feedback_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_feedback_admin", action=action, **kwargs)
         except Exception as exc:
-            logger.warning("反馈纠错管理调用失败: %s", exc)
+            logger.warning(f"反馈纠错管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs)
         except Exception as exc:
-            logger.warning("运行时管理调用失败: %s", exc)
+            logger.warning(f"运行时管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def import_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_import_admin", action=action, timeout_ms=timeout_ms, **kwargs)
         except Exception as exc:
-            logger.warning("导入管理调用失败: %s", exc)
+            logger.warning(f"导入管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def tuning_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_tuning_admin", action=action, timeout_ms=timeout_ms, **kwargs)
         except Exception as exc:
-            logger.warning("调优管理调用失败: %s", exc)
+            logger.warning(f"调优管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def v5_admin(self, *, action: str, timeout_ms: int = 30000, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_v5_admin", action=action, timeout_ms=timeout_ms, **kwargs)
         except Exception as exc:
-            logger.warning("V5 记忆管理调用失败: %s", exc)
+            logger.warning(f"V5 记忆管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def delete_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]:
         try:
             return await self._invoke_admin("memory_delete_admin", action=action, timeout_ms=timeout_ms, **kwargs)
         except Exception as exc:
-            logger.warning("删除管理调用失败: %s", exc)
+            logger.warning(f"删除管理调用失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def get_recycle_bin(self, *, limit: int = 50) -> Dict[str, Any]:
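
[Editor's note] Every `*_admin` method in the hunk above repeats the same shape: invoke, catch, log a labelled warning, degrade to `{"success": False, "error": ...}`. A self-contained sketch of that shared shape (the generic helper below is hypothetical, not a refactor this commit performs):

```python
import asyncio
import logging
from typing import Any, Dict

logger = logging.getLogger("demo")

async def _invoke_admin(command: str, **kwargs: Any) -> Dict[str, Any]:
    # Stand-in for MemoryService._invoke_admin; assumed to raise on transport errors.
    return {"success": True, "command": command, **kwargs}

async def call_admin(command: str, label: str, **kwargs: Any) -> Dict[str, Any]:
    """The pattern each *_admin method repeats: invoke, catch, log, degrade."""
    try:
        return await _invoke_admin(command, **kwargs)
    except Exception as exc:
        logger.warning(f"{label}调用失败: {exc}")
        return {"success": False, "error": str(exc)}

print(asyncio.run(call_admin("memory_graph_admin", "图谱管理", action="stats")))
```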
@@ -459,7 +459,7 @@ class MemoryService:
             payload = await self._invoke("maintain_memory", {"action": "recycle_bin", "limit": max(1, int(limit or 50))})
             return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"}
         except Exception as exc:
-            logger.warning("获取回收站失败: %s", exc)
+            logger.warning(f"获取回收站失败: {exc}")
             return {"success": False, "error": str(exc)}

     async def restore_memory(self, *, target: str) -> MemoryWriteResult:

src/services/statistics_service.py (new file, 491 lines)
@@ -0,0 +1,491 @@
+from datetime import datetime, timedelta
+from typing import Any, Dict, List
+
+from sqlalchemy import desc, func, or_
+from sqlmodel import col, select
+
+from src.common.database.database import get_db_session
+from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
+from src.common.logger import get_logger
+from src.common.message_repository import count_messages
+from src.manager.local_store_manager import local_storage
+from src.webui.schemas.statistics import DashboardData, ModelStatistics, StatisticsSummary, TimeSeriesData
+
+logger = get_logger("statistics_service")
+
+DASHBOARD_STATISTICS_CACHE_KEY = "webui_dashboard_statistics_cache"
+DASHBOARD_STATISTICS_CACHE_VERSION = 1
+DEFAULT_DASHBOARD_CACHE_MAX_AGE_SECONDS = 600
+DEFAULT_DASHBOARD_CACHE_HOURS = (24, 168, 720)
+_SPARSE_TIME_SERIES_FIELDS = ("hourly_data", "daily_data")
+
+
+async def get_dashboard_statistics(hours: int = 24, *, use_cache: bool = True) -> DashboardData:
+    """获取 WebUI 仪表盘统计数据。"""
+    if use_cache:
+        cached_data = get_cached_dashboard_statistics(hours)
+        if cached_data is not None:
+            return cached_data
+
+    return build_empty_dashboard_statistics()
+
+
+def build_empty_dashboard_statistics() -> DashboardData:
+    """构造空的 WebUI 仪表盘统计数据。"""
+    return DashboardData(
+        summary=StatisticsSummary(),
+        model_stats=[],
+        hourly_data=[],
+        daily_data=[],
+        recent_activity=[],
+    )
+
+
+async def compute_dashboard_statistics(hours: int = 24) -> DashboardData:
+    """获取 WebUI 仪表盘统计数据。"""
+    now = datetime.now()
+    start_time = now - timedelta(hours=hours)
+
+    summary = await get_summary_statistics(start_time, now)
+    model_stats = await get_model_statistics(start_time)
+    hourly_data = await get_hourly_statistics(start_time, now)
+    daily_data = await get_daily_statistics(now - timedelta(days=7), now)
+    recent_activity = await get_recent_activity(limit=10)
+
+    return DashboardData(
+        summary=summary,
+        model_stats=model_stats,
+        hourly_data=hourly_data,
+        daily_data=daily_data,
+        recent_activity=recent_activity,
+    )
+
+
+def get_cached_dashboard_statistics(
+    hours: int = 24,
+    *,
+    max_age_seconds: int = DEFAULT_DASHBOARD_CACHE_MAX_AGE_SECONDS,
+) -> DashboardData | None:
+    """从本地快照读取 WebUI 仪表盘统计数据。"""
+    raw_cache = local_storage[DASHBOARD_STATISTICS_CACHE_KEY]
+    if not isinstance(raw_cache, dict):
+        return None
+    if raw_cache.get("version") != DASHBOARD_STATISTICS_CACHE_VERSION:
+        return None
+
+    generated_at = raw_cache.get("generated_at")
+    if not isinstance(generated_at, (int, float)):
+        return None
+    if datetime.now().timestamp() - float(generated_at) > max_age_seconds:
+        return None
+
+    entries = raw_cache.get("entries")
+    if not isinstance(entries, dict):
+        return None
+
+    entry = entries.get(str(hours))
+    if not isinstance(entry, dict):
+        return None
+
+    try:
+        expanded_entry = _expand_dashboard_cache_entry(entry, hours=hours, generated_at=float(generated_at))
+        return DashboardData.model_validate(expanded_entry)
+    except Exception as e:
+        logger.warning(f"读取 WebUI 统计缓存失败,将实时计算: {e}")
+        return None
+
+
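[Editor's note] `get_cached_dashboard_statistics` only trusts a snapshot that passes four guards: dict shape, matching cache version, a numeric and sufficiently fresh `generated_at`, and an entry for the requested window. A runnable sketch mirroring those guards (the snapshot literal is hypothetical):

```python
import time

DASHBOARD_STATISTICS_CACHE_VERSION = 1
MAX_AGE_SECONDS = 600

# Hypothetical snapshot in the same shape the service stores in local_storage.
snapshot = {
    "version": 1,
    "generated_at": time.time() - 120,  # written two minutes ago
    "entries": {"24": {"sparse": True, "hourly_data": []}},
}

def is_usable(raw: object, hours: int) -> bool:
    """Mirror the guards above: shape, version, age, then the entry itself."""
    if not isinstance(raw, dict):
        return False
    if raw.get("version") != DASHBOARD_STATISTICS_CACHE_VERSION:
        return False
    generated_at = raw.get("generated_at")
    if not isinstance(generated_at, (int, float)):
        return False
    if time.time() - float(generated_at) > MAX_AGE_SECONDS:
        return False
    entries = raw.get("entries")
    return isinstance(entries, dict) and isinstance(entries.get(str(hours)), dict)

print(is_usable(snapshot, 24))   # True
print(is_usable(snapshot, 168))  # False: no entry for that window
```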
+def store_dashboard_statistics_cache(entries: dict[int, DashboardData], *, generated_at: datetime | None = None) -> None:
+    """保存 WebUI 仪表盘统计数据快照。"""
+    snapshot_time = generated_at or datetime.now()
+    local_storage[DASHBOARD_STATISTICS_CACHE_KEY] = {
+        "version": DASHBOARD_STATISTICS_CACHE_VERSION,
+        "generated_at": snapshot_time.timestamp(),
+        "entries": {str(hours): _compact_dashboard_cache_entry(data) for hours, data in entries.items()},
+    }
+
+
+def update_dashboard_statistics_cache_entry(
+    hours: int,
+    data: DashboardData,
+    *,
+    generated_at: datetime | None = None,
+) -> None:
+    """更新单个 WebUI 仪表盘统计缓存条目。"""
+    raw_cache = local_storage[DASHBOARD_STATISTICS_CACHE_KEY]
+    entries: dict[str, Any] = {}
+    if isinstance(raw_cache, dict) and isinstance(raw_cache.get("entries"), dict):
+        entries.update(raw_cache["entries"])
+
+    snapshot_time = generated_at or datetime.now()
+    entries[str(hours)] = _compact_dashboard_cache_entry(data)
+    local_storage[DASHBOARD_STATISTICS_CACHE_KEY] = {
+        "version": DASHBOARD_STATISTICS_CACHE_VERSION,
+        "generated_at": snapshot_time.timestamp(),
+        "entries": entries,
+    }
+
+
+async def refresh_dashboard_statistics_cache(hours_values: tuple[int, ...] = DEFAULT_DASHBOARD_CACHE_HOURS) -> None:
+    """刷新 WebUI 仪表盘统计数据快照。"""
+    cache_entries: dict[int, DashboardData] = {}
+    for hours in hours_values:
+        cache_entries[hours] = await compute_dashboard_statistics(hours=hours)
+    store_dashboard_statistics_cache(cache_entries)
+
+
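[Editor's note] `refresh_dashboard_statistics_cache` recomputes one dashboard per configured window (24 h, 7 d, 30 d) and persists them together under a single `generated_at` timestamp. A toy mirror of that loop (`compute` is a stand-in for `compute_dashboard_statistics`):

```python
import asyncio

DEFAULT_DASHBOARD_CACHE_HOURS = (24, 168, 720)

async def compute(hours: int) -> dict:
    return {"hours": hours, "requests": hours * 2}  # hypothetical payload

async def refresh() -> dict:
    # One snapshot entry per window, all stamped with one generation time.
    entries = {hours: await compute(hours) for hours in DEFAULT_DASHBOARD_CACHE_HOURS}
    return {"version": 1, "entries": {str(h): d for h, d in entries.items()}}

print(asyncio.run(refresh()))
```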
+def _compact_dashboard_cache_entry(data: DashboardData) -> dict[str, Any]:
+    """压缩 WebUI 仪表盘缓存条目,去掉全 0 时间桶。"""
+    entry = data.model_dump(mode="json")
+    for field_name in _SPARSE_TIME_SERIES_FIELDS:
+        series = entry.get(field_name)
+        if isinstance(series, list):
+            entry[field_name] = [item for item in series if not _is_empty_time_series_item(item)]
+    entry["sparse"] = True
+    return entry
+
+
+def _expand_dashboard_cache_entry(entry: dict[str, Any], *, hours: int, generated_at: float) -> dict[str, Any]:
+    """将稀疏缓存条目展开为前端需要的完整时间序列。"""
+    if entry.get("sparse") is not True:
+        return entry
+
+    expanded = dict(entry)
+    generated_datetime = datetime.fromtimestamp(generated_at)
+    expanded["hourly_data"] = _expand_time_series(
+        sparse_series=entry.get("hourly_data"),
+        start_time=generated_datetime - timedelta(hours=hours),
+        end_time=generated_datetime,
+        step=timedelta(hours=1),
+        timestamp_format="%Y-%m-%dT%H:00:00",
+    )
+    expanded["daily_data"] = _expand_time_series(
+        sparse_series=entry.get("daily_data"),
+        start_time=generated_datetime - timedelta(days=7),
+        end_time=generated_datetime,
+        step=timedelta(days=1),
+        timestamp_format="%Y-%m-%dT00:00:00",
+    )
+    expanded.pop("sparse", None)
+    return expanded
+
+
+def _expand_time_series(
+    *,
+    sparse_series: Any,
+    start_time: datetime,
+    end_time: datetime,
+    step: timedelta,
+    timestamp_format: str,
+) -> list[dict[str, Any]]:
+    sparse_items = sparse_series if isinstance(sparse_series, list) else []
+    sparse_by_timestamp = {
+        item.get("timestamp"): item
+        for item in sparse_items
+        if isinstance(item, dict) and isinstance(item.get("timestamp"), str)
+    }
+
+    result: list[dict[str, Any]] = []
+    current = _floor_time_for_format(start_time, timestamp_format)
+    while current <= end_time:
+        timestamp = current.strftime(timestamp_format)
+        item = sparse_by_timestamp.get(timestamp)
+        if isinstance(item, dict):
+            result.append(item)
+        else:
+            result.append({"timestamp": timestamp, "requests": 0, "cost": 0.0, "tokens": 0})
+        current += step
+    return result
+
+
+def _floor_time_for_format(value: datetime, timestamp_format: str) -> datetime:
+    if "%H" in timestamp_format:
+        return value.replace(minute=0, second=0, microsecond=0)
+    return value.replace(hour=0, minute=0, second=0, microsecond=0)
+
+
+def _is_empty_time_series_item(item: Any) -> bool:
+    if not isinstance(item, dict):
+        return False
+    return (
+        int(item.get("requests") or 0) == 0
+        and float(item.get("cost") or 0.0) == 0.0
+        and int(item.get("tokens") or 0) == 0
+    )
+
+
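[Editor's note] The compact/expand pair above implements a sparse time-series cache: all-zero buckets are dropped before the snapshot is persisted, and the full grid is rebuilt on read by refilling missing timestamps with zero buckets. A round-trip sketch with hypothetical data:

```python
from datetime import datetime, timedelta

# Hypothetical 3-hour series with one empty bucket, mirroring the cache shape.
series = [
    {"timestamp": "2024-01-01T00:00:00", "requests": 5, "cost": 0.01, "tokens": 900},
    {"timestamp": "2024-01-01T01:00:00", "requests": 0, "cost": 0.0, "tokens": 0},
    {"timestamp": "2024-01-01T02:00:00", "requests": 2, "cost": 0.004, "tokens": 310},
]

# Compact: drop all-zero buckets before persisting (_compact_dashboard_cache_entry).
sparse = [b for b in series if b["requests"] or b["cost"] or b["tokens"]]

# Expand: walk the full hour grid and refill missing buckets with zeros
# (_expand_time_series on read).
by_ts = {b["timestamp"]: b for b in sparse}
current, end = datetime(2024, 1, 1, 0), datetime(2024, 1, 1, 2)
expanded = []
while current <= end:
    ts = current.strftime("%Y-%m-%dT%H:00:00")
    expanded.append(by_ts.get(ts, {"timestamp": ts, "requests": 0, "cost": 0.0, "tokens": 0}))
    current += timedelta(hours=1)

assert expanded == series  # the round trip is lossless
```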
+async def get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
+    """获取指定时间范围内的摘要统计数据。"""
+    summary = StatisticsSummary(
+        total_requests=0,
+        total_cost=0.0,
+        total_tokens=0,
+        online_time=0.0,
+        total_messages=0,
+        total_replies=0,
+        avg_response_time=0.0,
+        cost_per_hour=0.0,
+        tokens_per_hour=0.0,
+    )
+
+    with get_db_session(auto_commit=False) as session:
+        statement = select(
+            func.count().label("total_requests"),
+            func.sum(col(ModelUsage.cost)).label("total_cost"),
+            func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
+            func.avg(col(ModelUsage.time_cost)).label("avg_response_time"),
+        ).where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
+        result = session.exec(statement).first()
+
+    if result:
+        total_requests, total_cost, total_tokens, avg_response_time = result
+        summary.total_requests = total_requests or 0
+        summary.total_cost = float(total_cost or 0.0)
+        summary.total_tokens = total_tokens or 0
+        summary.avg_response_time = float(avg_response_time or 0.0)
+
+    with get_db_session(auto_commit=False) as session:
+        statement = select(OnlineTime).where(
+            or_(
+                col(OnlineTime.start_timestamp) >= start_time,
+                col(OnlineTime.end_timestamp) >= start_time,
+            )
+        )
+        online_records = session.exec(statement).all()
+
+    for record in online_records:
+        start = max(record.start_timestamp, start_time)
+        end = min(record.end_timestamp, end_time)
+        if end > start:
+            summary.online_time += (end - start).total_seconds()
+
+    summary.total_messages = count_messages(start_time=start_time.timestamp(), end_time=end_time.timestamp())
+    summary.total_replies = count_messages(
+        start_time=start_time.timestamp(),
+        end_time=end_time.timestamp(),
+        has_reply_to=True,
+    )
+
+    if summary.online_time > 0:
+        online_hours = summary.online_time / 3600.0
+        summary.cost_per_hour = summary.total_cost / online_hours
+        summary.tokens_per_hour = summary.total_tokens / online_hours
+
+    return summary
+
+
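[Editor's note] Online time in `get_summary_statistics` is computed by clamping each recorded session to the query window and summing only the positive overlaps. The same logic in isolation, with made-up sessions:

```python
from datetime import datetime

window_start = datetime(2024, 1, 1, 0)
window_end = datetime(2024, 1, 1, 12)
sessions = [
    (datetime(2023, 12, 31, 22), datetime(2024, 1, 1, 2)),  # straddles window start
    (datetime(2024, 1, 1, 5), datetime(2024, 1, 1, 6)),     # fully inside the window
]

online_seconds = 0.0
for session_start, session_end in sessions:
    # Clamp the interval to the window; skip it if nothing remains.
    start = max(session_start, window_start)
    end = min(session_end, window_end)
    if end > start:
        online_seconds += (end - start).total_seconds()

print(online_seconds / 3600)  # 3.0 hours (2h clamped overlap + 1h)
```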
+async def get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
+    """获取指定时间之后的模型统计数据。"""
+    statement = (
+        select(ModelUsage)
+        .where(col(ModelUsage.timestamp) >= start_time)
+        .order_by(desc(col(ModelUsage.timestamp)))
+        .limit(200)
+    )
+
+    with get_db_session(auto_commit=False) as session:
+        records = session.exec(statement).all()
+
+    aggregates: Dict[str, Dict[str, float | int]] = {}
+    for record in records:
+        model_name = record.model_assign_name or record.model_name or "unknown"
+        if model_name not in aggregates:
+            aggregates[model_name] = {
+                "request_count": 0,
+                "total_cost": 0.0,
+                "total_tokens": 0,
+                "total_time_cost": 0.0,
+                "time_cost_count": 0,
+            }
+
+        bucket = aggregates[model_name]
+        bucket["request_count"] = int(bucket["request_count"]) + 1
+        bucket["total_cost"] = float(bucket["total_cost"]) + float(record.cost or 0.0)
+        bucket["total_tokens"] = int(bucket["total_tokens"]) + int(record.total_tokens or 0)
+        if record.time_cost:
+            bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
+            bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
+
+    result: List[ModelStatistics] = []
+    for model_name, bucket in sorted(
+        aggregates.items(),
+        key=lambda item: float(item[1]["request_count"]),
+        reverse=True,
+    )[:10]:
+        time_cost_count = int(bucket["time_cost_count"])
+        avg_time_cost = float(bucket["total_time_cost"]) / time_cost_count if time_cost_count > 0 else 0.0
+        result.append(
+            ModelStatistics(
+                model_name=model_name,
+                request_count=int(bucket["request_count"]),
+                total_cost=float(bucket["total_cost"]),
+                total_tokens=int(bucket["total_tokens"]),
+                avg_response_time=avg_time_cost,
+            )
+        )
+
+    return result
+
+
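[Editor's note] `get_model_statistics` aggregates the most recent 200 usage rows in Python: one bucket per model, with latency averaged only over rows that actually report a `time_cost`. A toy version of the same bucketing (rows are hypothetical):

```python
from collections import defaultdict

rows = [
    {"model": "gpt-a", "cost": 0.002, "tokens": 120, "time_cost": 0.8},
    {"model": "gpt-a", "cost": 0.001, "tokens": 60, "time_cost": None},
    {"model": "gpt-b", "cost": 0.004, "tokens": 300, "time_cost": 1.2},
]

buckets: dict[str, dict[str, float]] = defaultdict(
    lambda: {"count": 0, "cost": 0.0, "tokens": 0, "time_sum": 0.0, "time_n": 0}
)
for row in rows:
    b = buckets[row["model"]]
    b["count"] += 1
    b["cost"] += row["cost"]
    b["tokens"] += row["tokens"]
    if row["time_cost"]:  # only average over rows that report a latency
        b["time_sum"] += row["time_cost"]
        b["time_n"] += 1

# Sort by request count, descending, like the top-10 selection above.
for model, b in sorted(buckets.items(), key=lambda kv: kv[1]["count"], reverse=True):
    avg = b["time_sum"] / b["time_n"] if b["time_n"] else 0.0
    print(model, b["count"], round(b["cost"], 4), b["tokens"], round(avg, 2))
```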
+async def get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
+    """按小时聚合 LLM 请求、费用和 token。"""
+    hour_expr = func.strftime("%Y-%m-%dT%H:00:00", col(ModelUsage.timestamp))
+    statement = (
+        select(
+            hour_expr.label("hour"),
+            func.count().label("requests"),
+            func.sum(col(ModelUsage.cost)).label("cost"),
+            func.sum(col(ModelUsage.total_tokens)).label("tokens"),
+        )
+        .where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
+        .group_by(hour_expr)
+    )
+
+    with get_db_session(auto_commit=False) as session:
+        rows = session.exec(statement).all()
+
+    data_dict = {row[0]: row for row in rows}
+    result = []
+    current = start_time.replace(minute=0, second=0, microsecond=0)
+    while current <= end_time:
+        hour_str = current.strftime("%Y-%m-%dT%H:00:00")
+        if hour_str in data_dict:
+            row = data_dict[hour_str]
+            result.append(
+                TimeSeriesData(
+                    timestamp=hour_str,
+                    requests=row[1] or 0,
+                    cost=float(row[2] or 0.0),
+                    tokens=row[3] or 0,
+                )
+            )
+        else:
+            result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
+        current += timedelta(hours=1)
+
+    return result
+
+
+async def get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
+    """按天聚合 LLM 请求、费用和 token。"""
+    day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
+    statement = (
+        select(
+            day_expr.label("day"),
+            func.count().label("requests"),
+            func.sum(col(ModelUsage.cost)).label("cost"),
+            func.sum(col(ModelUsage.total_tokens)).label("tokens"),
+        )
+        .where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
+        .group_by(day_expr)
+    )
+
+    with get_db_session(auto_commit=False) as session:
+        rows = session.exec(statement).all()
+
+    data_dict = {row[0]: row for row in rows}
+    result = []
+    current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
+    while current <= end_time:
+        day_str = current.strftime("%Y-%m-%dT00:00:00")
+        if day_str in data_dict:
+            row = data_dict[day_str]
+            result.append(
+                TimeSeriesData(
+                    timestamp=day_str,
+                    requests=row[1] or 0,
+                    cost=float(row[2] or 0.0),
+                    tokens=row[3] or 0,
+                )
+            )
+        else:
+            result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
+        current += timedelta(days=1)
+
+    return result
+
+
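[Editor's note] Both aggregations push time bucketing into SQLite via `strftime`, so only one row per hour (or day) crosses the session boundary instead of every usage record. A self-contained sqlite3 sketch of the same GROUP BY (table and data are hypothetical):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE model_usage (timestamp TEXT, cost REAL, total_tokens INTEGER)")
conn.executemany(
    "INSERT INTO model_usage VALUES (?, ?, ?)",
    [
        ("2024-01-01 09:15:00", 0.002, 120),
        ("2024-01-01 09:40:00", 0.003, 200),
        ("2024-01-01 10:05:00", 0.001, 80),
    ],
)
# strftime() truncates each timestamp to its hour bucket, and GROUP BY
# aggregates inside the database -- one result row per hour.
rows = conn.execute(
    """
    SELECT strftime('%Y-%m-%dT%H:00:00', timestamp) AS hour,
           COUNT(*), SUM(cost), SUM(total_tokens)
    FROM model_usage
    GROUP BY hour
    ORDER BY hour
    """
).fetchall()
print(rows)
# [('2024-01-01T09:00:00', 2, 0.005, 320), ('2024-01-01T10:00:00', 1, 0.001, 80)]
```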
+async def get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
+    """获取最近的 LLM 调用记录。"""
+    with get_db_session(auto_commit=False) as session:
+        statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
+        records = session.exec(statement).all()
+
+    activities = []
+    for record in records:
+        activities.append(
+            {
+                "timestamp": record.timestamp.isoformat(),
+                "model": record.model_assign_name or record.model_name,
+                "request_type": record.request_type,
+                "tokens": record.total_tokens or 0,
+                "cost": record.cost or 0.0,
+                "time_cost": record.time_cost or 0.0,
+                "status": None,
+            }
+        )
+
+    return activities
+
+
+def fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]:
+    """获取指定时间之后仍有覆盖的在线时间区间。"""
+    with get_db_session(auto_commit=False) as session:
+        statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time)
+        records = session.exec(statement).all()
+        return [(record.start_timestamp, record.end_timestamp) for record in records]
+
+
+def fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]:
+    """获取指定时间之后的 LLM 使用记录。"""
+    with get_db_session(auto_commit=False) as session:
+        statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time)
+        records = session.exec(statement).all()
+        return [
+            {
+                "timestamp": record.timestamp,
+                "request_type": record.request_type,
+                "model_api_provider_name": record.model_api_provider_name,
+                "model_assign_name": record.model_assign_name,
+                "model_name": record.model_name,
+                "prompt_tokens": record.prompt_tokens,
+                "completion_tokens": record.completion_tokens,
+                "cost": record.cost,
+                "time_cost": record.time_cost,
+            }
+            for record in records
+        ]
+
+
+def fetch_messages_since(query_start_time: datetime) -> list[Messages]:
+    """获取指定时间之后的消息记录。"""
+    with get_db_session(auto_commit=False) as session:
+        statement = select(Messages).where(col(Messages.timestamp) >= query_start_time)
+        return list(session.exec(statement).all())
+
+
+def fetch_tool_records_since(query_start_time: datetime) -> list[ToolRecord]:
+    """获取指定时间之后的工具调用记录。"""
+    with get_db_session(auto_commit=False) as session:
+        statement = select(ToolRecord).where(col(ToolRecord.timestamp) >= query_start_time)
+        return list(session.exec(statement).all())
+
+
+def get_earliest_statistics_time(fallback_time: datetime) -> datetime:
+    """获取统计数据中最早的记录时间。"""
+    try:
+        with get_db_session(auto_commit=False) as session:
+            start_times = [
+                session.exec(select(func.min(ModelUsage.timestamp))).first(),
+                session.exec(select(func.min(Messages.timestamp))).first(),
+                session.exec(select(func.min(OnlineTime.start_timestamp))).first(),
+                session.exec(select(func.min(ToolRecord.timestamp))).first(),
+            ]
+    except Exception as e:
+        logger.warning(f"获取全量统计起始时间失败,将使用回退时间: {e}")
+        return fallback_time
+
+    valid_start_times = [item for item in start_times if isinstance(item, datetime)]
+    if valid_start_times:
+        return min(valid_start_times)
+    return fallback_time
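
[Editor's note] `get_earliest_statistics_time` takes the oldest per-table minimum and falls back to the supplied deploy time when no table has data or the query fails. The selection logic in isolation (dates are hypothetical):

```python
from datetime import datetime

# One MIN(timestamp) result per table; None means the table is empty.
candidates = [datetime(2024, 3, 1), None, datetime(2024, 2, 15), None]
fallback = datetime(2024, 5, 1)

valid = [c for c in candidates if isinstance(c, datetime)]
print(min(valid) if valid else fallback)  # 2024-02-15 00:00:00
```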
@@ -360,11 +360,8 @@ class ChatConnectionManager:
             existing.virtual_config = virtual_config
             existing.sender = sender
             logger.debug(
-                "WebUI 聊天会话复用: session=%s, connection=%s, client_session=%s, channel=%s",
-                session_id,
-                connection_id,
-                client_session_id,
-                channel_key,
+                f"WebUI 聊天会话复用: session={session_id}, connection={connection_id}, "
+                f"client_session={client_session_id}, channel={channel_key}",
             )
             return
         if existing_session_id is not None:
@@ -387,12 +384,8 @@ class ChatConnectionManager:
         self.user_sessions.setdefault(user_id, set()).add(session_id)
         self._bind_channel(session_id, channel_key)
         logger.info(
-            "WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, channel=%s",
-            session_id,
-            connection_id,
-            client_session_id,
-            user_id,
-            channel_key,
+            f"WebUI 聊天会话已连接: session={session_id}, connection={connection_id}, "
+            f"client_session={client_session_id}, user={user_id}, channel={channel_key}",
         )

     def disconnect(self, session_id: str) -> None:
@@ -420,7 +413,7 @@ class ChatConnectionManager:
         if not user_session_ids:
             del self.user_sessions[session_connection.user_id]

-        logger.info("WebUI 聊天会话已断开: session=%s", session_id)
+        logger.info(f"WebUI 聊天会话已断开: session={session_id}")

     def disconnect_connection(self, connection_id: str) -> None:
         """断开物理连接下的全部逻辑聊天会话。
@@ -495,7 +488,7 @@ class ChatConnectionManager:
         try:
             await session_connection.sender(message)
         except Exception as exc:
-            logger.error("发送聊天消息失败: session=%s, error=%s", session_id, exc)
+            logger.error(f"发送聊天消息失败: session={session_id}, error={exc}")

     async def broadcast(self, message: Dict[str, Any]) -> None:
         """向全部逻辑聊天会话广播消息。
@@ -659,10 +652,8 @@ def resolve_initial_virtual_identity(
             group_name=group_name or "WebUI虚拟群聊",
         )
         logger.info(
-            "虚拟身份模式已通过参数激活: %s @ %s, group_id=%s",
-            virtual_config.user_nickname,
-            virtual_config.platform,
-            virtual_group_id,
+            f"虚拟身份模式已通过参数激活: {virtual_config.user_nickname} @ "
+            f"{virtual_config.platform}, group_id={virtual_group_id}",
        )
        return virtual_config
    except Exception as exc:
@@ -1,349 +1,34 @@
 """统计数据 API 路由"""

-from datetime import datetime, timedelta
-from typing import Any, Dict, List
-
 from fastapi import APIRouter, Depends, HTTPException
-from pydantic import BaseModel, Field
-from sqlalchemy import desc, func, or_
-from sqlmodel import col, select

-from src.common.database.database import get_db_session
-from src.common.database.database_model import ModelUsage, OnlineTime
 from src.common.logger import get_logger
-from src.common.message_repository import count_messages
+from src.services.statistics_service import get_dashboard_statistics, get_model_statistics, get_summary_statistics
 from src.webui.dependencies import require_auth
+from src.webui.schemas.statistics import DashboardData

 logger = get_logger("webui.statistics")

 router = APIRouter(prefix="/statistics", tags=["statistics"], dependencies=[Depends(require_auth)])


-class StatisticsSummary(BaseModel):
-    """统计数据摘要"""
-
-    total_requests: int = Field(0, description="总请求数")
-    total_cost: float = Field(0.0, description="总花费")
-    total_tokens: int = Field(0, description="总token数")
-    online_time: float = Field(0.0, description="在线时间(秒)")
-    total_messages: int = Field(0, description="总消息数")
-    total_replies: int = Field(0, description="总回复数")
-    avg_response_time: float = Field(0.0, description="平均响应时间")
-    cost_per_hour: float = Field(0.0, description="每小时花费")
-    tokens_per_hour: float = Field(0.0, description="每小时token数")
-
-
-class ModelStatistics(BaseModel):
-    """模型统计"""
-
-    model_name: str
-    request_count: int
-    total_cost: float
-    total_tokens: int
-    avg_response_time: float
-
-
-class TimeSeriesData(BaseModel):
-    """时间序列数据"""
-
-    timestamp: str
-    requests: int = 0
-    cost: float = 0.0
-    tokens: int = 0
-
-
-class DashboardData(BaseModel):
-    """仪表盘数据"""
-
-    summary: StatisticsSummary
-    model_stats: List[ModelStatistics]
-    hourly_data: List[TimeSeriesData]
-    daily_data: List[TimeSeriesData]
-    recent_activity: List[Dict[str, Any]]
-
-
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时),默认24小时
|
||||
|
||||
Returns:
|
||||
仪表盘数据
|
||||
"""
|
||||
async def get_dashboard_data(hours: int = 24) -> DashboardData:
|
||||
"""获取仪表盘统计数据。"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
|
||||
# 获取摘要数据
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
|
||||
# 获取模型统计
|
||||
model_stats = await _get_model_statistics(start_time)
|
||||
|
||||
# 获取小时级时间序列数据
|
||||
hourly_data = await _get_hourly_statistics(start_time, now)
|
||||
|
||||
# 获取日级时间序列数据(最近7天)
|
||||
daily_start = now - timedelta(days=7)
|
||||
daily_data = await _get_daily_statistics(daily_start, now)
|
||||
|
||||
# 获取最近活动
|
||||
recent_activity = await _get_recent_activity(limit=10)
|
||||
|
||||
return DashboardData(
|
||||
summary=summary,
|
||||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
return await get_dashboard_statistics(hours=hours)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
-async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
-    """获取摘要统计数据(优化:使用数据库聚合)"""
-    summary = StatisticsSummary(
-        total_requests=0,
-        total_cost=0.0,
-        total_tokens=0,
-        online_time=0.0,
-        total_messages=0,
-        total_replies=0,
-        avg_response_time=0.0,
-        cost_per_hour=0.0,
-        tokens_per_hour=0.0,
-    )
-
-    # 使用聚合查询替代全量加载
-    with get_db_session() as session:
-        statement = select(
-            func.count().label("total_requests"),
-            func.sum(col(ModelUsage.cost)).label("total_cost"),
-            func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
-            func.avg(col(ModelUsage.time_cost)).label("avg_response_time"),
-        ).where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
-        result = session.exec(statement).first()
-
-    if result:
-        total_requests, total_cost, total_tokens, avg_response_time = result
-        summary.total_requests = total_requests or 0
-        summary.total_cost = float(total_cost or 0.0)
-        summary.total_tokens = total_tokens or 0
-        summary.avg_response_time = float(avg_response_time or 0.0)
-
-    # 查询在线时间 - 这个数据量通常不大,保留原逻辑
-    with get_db_session() as session:
-        statement = select(OnlineTime).where(
-            or_(
-                col(OnlineTime.start_timestamp) >= start_time,
-                col(OnlineTime.end_timestamp) >= start_time,
-            )
-        )
-        online_records = session.exec(statement).all()
-
-    for record in online_records:
-        start = max(record.start_timestamp, start_time)
-        end = min(record.end_timestamp, end_time)
-        if end > start:
-            summary.online_time += (end - start).total_seconds()
-
-    summary.total_messages = count_messages(start_time=start_time.timestamp(), end_time=end_time.timestamp())
-    summary.total_replies = count_messages(
-        start_time=start_time.timestamp(),
-        end_time=end_time.timestamp(),
-        has_reply_to=True,
-    )
-
-    # 计算派生指标
-    if summary.online_time > 0:
-        online_hours = summary.online_time / 3600.0
-        summary.cost_per_hour = summary.total_cost / online_hours
-        summary.tokens_per_hour = summary.total_tokens / online_hours
-
-    return summary
-
-
-async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
-    """获取模型统计数据(优化:使用数据库聚合和分组)"""
-    # 使用GROUP BY聚合,避免全量加载
-    statement = (
-        select(ModelUsage)
-        .where(col(ModelUsage.timestamp) >= start_time)
-        .order_by(desc(col(ModelUsage.timestamp)))
-        .limit(200)
-    )
-
-    with get_db_session() as session:
-        rows = session.exec(statement).all()
-
-    aggregates: Dict[str, Dict[str, float | int]] = {}
-    for record in rows:
-        model_name = record.model_assign_name or record.model_name or "unknown"
-        if model_name not in aggregates:
-            aggregates[model_name] = {
-                "request_count": 0,
-                "total_cost": 0.0,
-                "total_tokens": 0,
-                "total_time_cost": 0.0,
-                "time_cost_count": 0,
-            }
-        bucket = aggregates[model_name]
-        bucket["request_count"] = int(bucket["request_count"]) + 1
-        bucket["total_cost"] = float(bucket["total_cost"]) + float(record.cost or 0.0)
-        bucket["total_tokens"] = int(bucket["total_tokens"]) + int(record.total_tokens or 0)
-        if record.time_cost:
-            bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
-            bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
-
-    result: List[ModelStatistics] = []
-    for model_name, bucket in sorted(
-        aggregates.items(),
-        key=lambda item: float(item[1]["request_count"]),
-        reverse=True,
-    )[:10]:
-        time_cost_count = int(bucket["time_cost_count"])
-        avg_time_cost = float(bucket["total_time_cost"]) / time_cost_count if time_cost_count > 0 else 0.0
-        result.append(
-            ModelStatistics(
-                model_name=model_name,
-                request_count=int(bucket["request_count"]),
-                total_cost=float(bucket["total_cost"]),
-                total_tokens=int(bucket["total_tokens"]),
-                avg_response_time=avg_time_cost,
-            )
-        )
-
-    return result
-
-
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
    """Fetch hourly statistics (optimized: aggregate in the database)."""
    # Group by hour using SQLite's date/time functions:
    # strftime truncates the timestamp down to hour granularity
    hour_expr = func.strftime("%Y-%m-%dT%H:00:00", col(ModelUsage.timestamp))
    statement = (
        select(
            hour_expr.label("hour"),
            func.count().label("requests"),
            func.sum(col(ModelUsage.cost)).label("cost"),
            func.sum(col(ModelUsage.total_tokens)).label("tokens"),
        )
        .where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
        .group_by(hour_expr)
    )

    with get_db_session() as session:
        rows = session.exec(statement).all()

    # Convert to a dict for fast lookup
    data_dict = {row[0]: row for row in rows}

    # Fill in every hour (including hours with no data)
    result = []
    current = start_time.replace(minute=0, second=0, microsecond=0)
    while current <= end_time:
        hour_str = current.strftime("%Y-%m-%dT%H:00:00")
        if hour_str in data_dict:
            row = data_dict[hour_str]
            result.append(
                TimeSeriesData(
                    timestamp=hour_str,
                    requests=row[1] or 0,
                    cost=float(row[2] or 0.0),
                    tokens=row[3] or 0,
                )
            )
        else:
            result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
        current += timedelta(hours=1)

    return result

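The zero-fill loop above only matches rows when SQLite's strftime output and Python's strftime output agree character for character, because the lookup joins on the formatted string key. A quick standalone check of that assumption (sqlite3 is in the standard library; the sample datetime is arbitrary):

import sqlite3
from datetime import datetime

dt = datetime(2024, 5, 6, 14, 37, 9)
py_key = dt.strftime("%Y-%m-%dT%H:00:00")
sql_key = sqlite3.connect(":memory:").execute(
    "SELECT strftime('%Y-%m-%dT%H:00:00', ?)", (dt.isoformat(),)
).fetchone()[0]
assert py_key == sql_key == "2024-05-06T14:00:00"
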
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
    """Fetch daily statistics (optimized: aggregate in the database)."""
    # Group by day using strftime
    day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
    statement = (
        select(
            day_expr.label("day"),
            func.count().label("requests"),
            func.sum(col(ModelUsage.cost)).label("cost"),
            func.sum(col(ModelUsage.total_tokens)).label("tokens"),
        )
        .where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
        .group_by(day_expr)
    )

    with get_db_session() as session:
        rows = session.exec(statement).all()

    # Convert to a dict
    data_dict = {row[0]: row for row in rows}

    # Fill in every day
    result = []
    current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
    while current <= end_time:
        day_str = current.strftime("%Y-%m-%dT00:00:00")
        if day_str in data_dict:
            row = data_dict[day_str]
            result.append(
                TimeSeriesData(
                    timestamp=day_str,
                    requests=row[1] or 0,
                    cost=float(row[2] or 0.0),
                    tokens=row[3] or 0,
                )
            )
        else:
            result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
        current += timedelta(days=1)

    return result

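_get_hourly_statistics and _get_daily_statistics differ only in the strftime pattern, the bucket alignment, and the step size, so the shared zero-fill walk could be factored out. A sketch of such a helper under those assumptions (fill_time_buckets is hypothetical; it reuses the TimeSeriesData model from above):

from datetime import datetime, timedelta
from typing import Dict, List, Sequence


def fill_time_buckets(
    data_dict: Dict[str, Sequence],  # formatted bucket -> (bucket, requests, cost, tokens)
    first_bucket: datetime,          # already aligned to the bucket boundary
    end_time: datetime,
    fmt: str,                        # "%Y-%m-%dT%H:00:00" for hours, "%Y-%m-%dT00:00:00" for days
    step: timedelta,                 # timedelta(hours=1) or timedelta(days=1)
) -> List[TimeSeriesData]:
    # Walk the window bucket by bucket, emitting zeros where the DB had no rows.
    result: List[TimeSeriesData] = []
    current = first_bucket
    while current <= end_time:
        key = current.strftime(fmt)
        row = data_dict.get(key)
        if row is not None:
            result.append(
                TimeSeriesData(timestamp=key, requests=row[1] or 0, cost=float(row[2] or 0.0), tokens=row[3] or 0)
            )
        else:
            result.append(TimeSeriesData(timestamp=key, requests=0, cost=0.0, tokens=0))
        current += step
    return result
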
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
    """Fetch recent activity."""
    with get_db_session() as session:
        statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
        records = session.exec(statement).all()

    activities = []
    for record in records:
        activities.append(
            {
                "timestamp": record.timestamp.isoformat(),
                "model": record.model_assign_name or record.model_name,
                "request_type": record.request_type,
                "tokens": record.total_tokens or 0,
                "cost": record.cost or 0.0,
                "time_cost": record.time_cost or 0.0,
                "status": None,
            }
        )

    return activities

@router.get("/summary")
|
||||
async def get_summary(hours: int = 24):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
"""获取统计摘要。"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
return summary
|
||||
return await get_summary_statistics(start_time, now)
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计摘要失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -351,17 +36,10 @@ async def get_summary(hours: int = 24):

@router.get("/models")
async def get_model_stats(hours: int = 24):
    """
    Get model statistics

    Args:
        hours: length of the statistics window, in hours
    """
    """Get model statistics."""
    try:
        now = datetime.now()
        start_time = now - timedelta(hours=hours)
        stats = await _get_model_statistics(start_time)
        return stats
        start_time = datetime.now() - timedelta(hours=hours)
        return await get_model_statistics(start_time)
    except Exception as e:
        logger.error(f"Failed to get model statistics: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

@@ -4,13 +4,17 @@
Provides system restart, status queries, and related functionality
"""

import os
import time
from datetime import datetime
from importlib.metadata import PackageNotFoundError, version as get_package_version
from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

import httpx
import os
import time

from src.common.logger import get_logger
from src.config.config import MMC_VERSION
from src.webui.dependencies import require_auth
@@ -20,6 +24,10 @@ logger = get_logger("webui_system")

# Record the startup time
_start_time = time.time()
_DASHBOARD_PACKAGE_NAME = "maibot-dashboard"
_PYPI_JSON_URL = f"https://pypi.org/pypi/{_DASHBOARD_PACKAGE_NAME}/json"
_PYPI_CACHE_TTL_SECONDS = 60 * 60 * 6
_pypi_version_cache: Dict[str, Any] = {"checked_at": 0.0, "latest_version": None}


class RestartResponse(BaseModel):
@@ -38,6 +46,72 @@ class StatusResponse(BaseModel):
    start_time: str

class DashboardVersionResponse(BaseModel):
    """WebUI version check response"""

    current_version: str
    latest_version: Optional[str] = None
    has_update: bool = False
    package_name: str = _DASHBOARD_PACKAGE_NAME
    pypi_url: str = f"https://pypi.org/project/{_DASHBOARD_PACKAGE_NAME}/"


def _get_installed_dashboard_version() -> str:
    try:
        return get_package_version(_DASHBOARD_PACKAGE_NAME)
    except PackageNotFoundError:
        return "unknown"


def _normalize_version(version: str) -> tuple[int, ...]:
    clean_version = version.strip().lower().removeprefix("v")
    numeric_part = clean_version.split("-", 1)[0].split("+", 1)[0]
    parts = []
    for item in numeric_part.split("."):
        number = ""
        for char in item:
            if not char.isdigit():
                break
            number += char
        parts.append(int(number) if number else 0)
    return tuple(parts)


def _is_newer_version(latest_version: Optional[str], current_version: str) -> bool:
    if not latest_version or not current_version or current_version == "unknown":
        return False

    latest_parts = _normalize_version(latest_version)
    current_parts = _normalize_version(current_version)
    width = max(len(latest_parts), len(current_parts))
    return latest_parts + (0,) * (width - len(latest_parts)) > current_parts + (0,) * (width - len(current_parts))

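A few worked examples of the two helpers above (values chosen for illustration; easy to verify by hand against the parsing rules):

assert _normalize_version("v1.2.3") == (1, 2, 3)
assert _normalize_version("1.10.0-rc1") == (1, 10, 0)   # pre-release suffix is cut off
assert _normalize_version("2.0.1+build5") == (2, 0, 1)  # build metadata is cut off
assert _is_newer_version("v1.10.0", "1.9.3") is True    # 10 > 9 numerically, not lexically
assert _is_newer_version("1.2.0", "unknown") is False   # an unknown local version never flags an update
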
async def _get_latest_dashboard_version_from_pypi() -> Optional[str]:
    now = time.time()
    cached_version = _pypi_version_cache.get("latest_version")
    checked_at = float(_pypi_version_cache.get("checked_at", 0.0))
    if cached_version and now - checked_at < _PYPI_CACHE_TTL_SECONDS:
        return str(cached_version)

    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(_PYPI_JSON_URL)
            response.raise_for_status()
            payload = response.json()
    except Exception as e:
        logger.debug(f"Failed to check the WebUI version on PyPI: {e}")
        return str(cached_version) if cached_version else None

    latest_version = payload.get("info", {}).get("version")
    if isinstance(latest_version, str) and latest_version.strip():
        _pypi_version_cache["checked_at"] = now
        _pypi_version_cache["latest_version"] = latest_version.strip()
        return latest_version.strip()

    return str(cached_version) if cached_version else None

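The function above is a TTL cache with a stale-if-error fallback: a fresh cache hit short-circuits the network call, a successful fetch refreshes the cache, and a failed fetch degrades to the last known value rather than to nothing. The same pattern reduced to a synchronous sketch (cached_fetch and its arguments are hypothetical, not part of this codebase):

import time
from typing import Any, Callable, Dict, Optional


def cached_fetch(cache: Dict[str, Any], ttl: float, fetch: Callable[[], Optional[str]]) -> Optional[str]:
    now = time.time()
    # 1) Fresh cache hit: skip the network entirely.
    if cache.get("value") and now - float(cache.get("checked_at", 0.0)) < ttl:
        return cache["value"]
    # 2) Try the network; a valid result refreshes the cache.
    value = fetch()
    if value:
        cache["checked_at"] = now
        cache["value"] = value
        return value
    # 3) Stale-if-error: an old answer beats no answer.
    return cache.get("value")
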
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot():
|
||||
"""
|
||||
@@ -89,6 +163,19 @@ async def get_maibot_status():
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/dashboard-version", response_model=DashboardVersionResponse)
|
||||
async def get_dashboard_version(current_version: Optional[str] = None):
|
||||
"""获取 WebUI 当前版本和 PyPI 最新版本。"""
|
||||
resolved_current_version = current_version or _get_installed_dashboard_version()
|
||||
latest_version = await _get_latest_dashboard_version_from_pypi()
|
||||
|
||||
return DashboardVersionResponse(
|
||||
current_version=resolved_current_version,
|
||||
latest_version=latest_version,
|
||||
has_update=_is_newer_version(latest_version, resolved_current_version),
|
||||
)
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
|
||||
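A client can poll this endpoint to decide whether to show an update banner. A hedged usage sketch — the base URL and the /api/system mount prefix are assumptions, since the router's prefix does not appear in this diff:

import asyncio

import httpx


async def check_dashboard_update() -> None:
    # NOTE: base URL and route prefix are illustrative assumptions.
    async with httpx.AsyncClient(base_url="http://127.0.0.1:8000") as client:
        response = await client.get("/api/system/dashboard-version")
        response.raise_for_status()
        data = response.json()
        if data["has_update"]:
            print(f"Update available: {data['current_version']} -> {data['latest_version']}")


asyncio.run(check_dashboard_update())
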
@@ -77,7 +77,7 @@ class UnifiedWebSocketManager:
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.error("Unified WebSocket send failed: connection=%s, error=%s", connection.connection_id, exc)
            logger.error(f"Unified WebSocket send failed: connection={connection.connection_id}, error={exc}")

    async def connect(self, connection_id: str, websocket: WebSocket) -> WebSocketConnection:
        """Register a new physical WebSocket connection.
@@ -108,7 +108,7 @@ class UnifiedWebSocketManager:
            try:
                await self._close_websocket(connection)
            except Exception as exc:
                logger.debug("Exception while closing the underlying unified WebSocket connection: connection=%s, error=%s", connection_id, exc)
                logger.debug(f"Exception while closing the underlying unified WebSocket connection: connection={connection_id}, error={exc}")

        await connection.send_queue.put(None)
        if connection.sender_task is not None:
@@ -117,7 +117,7 @@ class UnifiedWebSocketManager:
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                logger.debug("Exception while waiting for the sender coroutine to exit: connection=%s, error=%s", connection_id, exc)
                logger.debug(f"Exception while waiting for the sender coroutine to exit: connection={connection_id}, error={exc}")

    def get_connection(self, connection_id: str) -> Optional[WebSocketConnection]:
        """Get the context of a given connection.

@@ -544,7 +544,7 @@ async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(

    connection_id = uuid.uuid4().hex
    await websocket_manager.connect(connection_id=connection_id, websocket=websocket)
    logger.info("Unified WebSocket client connected: connection=%s", connection_id)
    logger.info(f"Unified WebSocket client connected: connection={connection_id}")
    await websocket_manager.send_event(
        connection_id,
        domain="system",
@@ -565,17 +565,15 @@ async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = Query(
                continue
            await handle_client_message(connection_id, cast(Dict[str, Any], raw_message))
    except WebSocketDisconnect:
        logger.info("Unified WebSocket client disconnected: connection=%s", connection_id)
        logger.info(f"Unified WebSocket client disconnected: connection={connection_id}")
    except asyncio.CancelledError:
        logger.warning("Unified WebSocket connection handling was cancelled: connection=%s", connection_id)
        logger.warning(f"Unified WebSocket connection handling was cancelled: connection={connection_id}")
        raise
    except Exception as exc:
        logger.error("Unified WebSocket handling failed: connection=%s, error=%s", connection_id, exc, exc_info=True)
        logger.error(f"Unified WebSocket handling failed: connection={connection_id}, error={exc}", exc_info=True)
    finally:
        chat_manager.disconnect_connection(connection_id)
        await websocket_manager.disconnect(connection_id)
        logger.info(
            "Unified WebSocket connection cleanup complete: connection=%s, remaining connections=%s",
            connection_id,
            len(websocket_manager.connections),
            f"Unified WebSocket connection cleanup complete: connection={connection_id}, remaining connections={len(websocket_manager.connections)}",
        )