WebUI后端整体重构

This commit is contained in:
墨梓柒
2026-01-13 07:24:27 +08:00
parent 812296590e
commit ffafbf0a26
36 changed files with 927 additions and 294 deletions

View File

@@ -0,0 +1,35 @@
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
from fastapi import APIRouter
def get_api_router() -> APIRouter:
"""获取主 API 路由器(包含所有子路由)"""
from src.webui.routes import router as main_router
return main_router
def get_all_routers() -> list[APIRouter]:
"""获取所有需要独立注册的路由器列表"""
from src.webui.routes import router as main_router
from src.webui.routers.websocket.logs import router as logs_router
from src.webui.routers.knowledge import router as knowledge_router
from src.webui.routers.chat import router as chat_router
from src.webui.api.planner import router as planner_router
from src.webui.api.replier import router as replier_router
return [
main_router,
logs_router,
knowledge_router,
chat_router,
planner_router,
replier_router,
]
__all__ = [
"get_api_router",
"get_all_routers",
]

View File

@@ -0,0 +1,938 @@
"""麦麦 2025 年度总结 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import (
LLMUsage,
OnlineTime,
Messages,
ChatStreams,
PersonInfo,
Emoji,
Expression,
ActionRecords,
Jargon,
)
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.annual_report")
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ==================== Pydantic 模型定义 ====================
class TimeFootprintData(BaseModel):
"""时光足迹数据"""
total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
first_message_time: Optional[str] = Field(None, description="初次消息时间")
first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
is_night_owl: bool = Field(False, description="是否是夜猫子")
class SocialNetworkData(BaseModel):
"""社交网络数据"""
total_groups: int = Field(0, description="加入的群组总数")
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
at_count: int = Field(0, description="被@次数")
mentioned_count: int = Field(0, description="被提及次数")
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
longest_companion_days: int = Field(0, description="陪伴天数")
class BrainPowerData(BaseModel):
"""最强大脑数据"""
total_tokens: int = Field(0, description="年度消耗Token总量")
total_cost: float = Field(0.0, description="年度总花费")
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
total_actions: int = Field(0, description="总动作数")
no_reply_count: int = Field(0, description="选择沉默的次数")
avg_interest_value: float = Field(0.0, description="平均兴趣值")
max_interest_value: float = Field(0.0, description="最高兴趣值")
max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
avg_reasoning_length: float = Field(0.0, description="平均思考长度")
max_reasoning_length: int = Field(0, description="最长思考长度")
max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
class ExpressionVibeData(BaseModel):
"""个性与表达数据"""
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
checked_expression_count: int = Field(0, description="已检查的表达次数")
total_expressions: int = Field(0, description="表达总数")
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
image_processed_count: int = Field(0, description="处理的图片数量")
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
class AchievementData(BaseModel):
"""趣味成就数据"""
new_jargon_count: int = Field(0, description="新学到的黑话数量")
sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
class AnnualReportData(BaseModel):
"""年度报告完整数据"""
year: int = Field(2025, description="报告年份")
bot_name: str = Field("麦麦", description="Bot名称")
generated_at: str = Field(..., description="报告生成时间")
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
achievements: AchievementData = Field(default_factory=AchievementData)
# ==================== 辅助函数 ====================
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
"""获取指定年份的时间戳范围"""
start = datetime(year, 1, 1, 0, 0, 0).timestamp()
end = datetime(year, 12, 31, 23, 59, 59).timestamp()
return start, end
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
"""获取指定年份的 datetime 范围"""
start = datetime(year, 1, 1, 0, 0, 0)
end = datetime(year, 12, 31, 23, 59, 59)
return start, end
# ==================== 维度一:时光足迹 ====================
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
"""获取时光足迹数据"""
data = TimeFootprintData()
start_ts, end_ts = get_year_time_range(year)
start_dt, end_dt = get_year_datetime_range(year)
try:
# 1. 年度在线时长
online_records = list(
OnlineTime.select().where(
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
)
)
total_seconds = 0
for record in online_records:
try:
start = max(record.start_timestamp, start_dt)
end = min(record.end_timestamp, end_dt)
if end > start:
total_seconds += (end - start).total_seconds()
except Exception:
continue
data.total_online_hours = round(total_seconds / 3600, 2)
# 2. 初次相遇 - 年度第一条消息
first_msg = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.order_by(Messages.time.asc())
.first()
)
if first_msg:
data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
content = first_msg.processed_plain_text or first_msg.display_message or ""
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
# 3. 最忙碌的一天
# 使用 SQLite 的 date 函数按日期分组
busiest_query = (
Messages.select(
fn.date(Messages.time, "unixepoch").alias("day"),
fn.COUNT(Messages.id).alias("count"),
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.group_by(fn.date(Messages.time, "unixepoch"))
.order_by(fn.COUNT(Messages.id).desc())
.limit(1)
)
busiest_result = list(busiest_query.dicts())
if busiest_result:
data.busiest_day = busiest_result[0].get("day")
data.busiest_day_count = busiest_result[0].get("count", 0)
# 4. 昼夜节律 - 24小时活跃分布
hourly_query = (
Messages.select(
fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
fn.COUNT(Messages.id).alias("count"),
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.group_by(fn.strftime("%H", Messages.time, "unixepoch"))
)
hourly_distribution = [0] * 24
for row in hourly_query.dicts():
try:
hour = int(row.get("hour", 0))
if 0 <= hour < 24:
hourly_distribution[hour] = row.get("count", 0)
except (ValueError, TypeError):
continue
data.hourly_distribution = hourly_distribution
# 5. 深夜食堂 (0-4点)
data.midnight_chat_count = sum(hourly_distribution[0:5])
# 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点)
night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
morning_activity = sum(hourly_distribution[6:13])
data.is_night_owl = night_activity > morning_activity
except Exception as e:
logger.error(f"获取时光足迹数据失败: {e}")
return data
# ==================== 维度二:社交网络 ====================
async def get_social_network(year: int = 2025) -> SocialNetworkData:
"""获取社交网络数据"""
from src.config.config import global_config
data = SocialNetworkData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于过滤
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 加入的群组总数
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
# 2. 话痨群组 TOP3
top_groups_query = (
Messages.select(
Messages.chat_info_group_id,
Messages.chat_info_group_name,
fn.COUNT(Messages.id).alias("count"),
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.chat_info_group_id.is_null(False))
)
.group_by(Messages.chat_info_group_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(5)
)
data.top_groups = [
{
"group_id": row["chat_info_group_id"],
"group_name": row["chat_info_group_name"] or "未知群组",
"message_count": row["count"],
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
}
for row in top_groups_query.dicts()
]
# 3. 互动最多的用户 TOP5过滤 bot 自身)
top_users_query = (
Messages.select(
Messages.user_id,
Messages.user_nickname,
fn.COUNT(Messages.id).alias("count"),
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id.is_null(False))
& (Messages.user_id != bot_qq) # 过滤 bot 自身
)
.group_by(Messages.user_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(5)
)
data.top_users = [
{
"user_id": row["user_id"],
"user_nickname": row["user_nickname"] or "未知用户",
"message_count": row["count"],
"is_webui": str(row["user_id"]).startswith("webui_"),
}
for row in top_users_query.dicts()
]
# 4. 被@次数
data.at_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_at == True)
)
.count()
)
# 5. 被提及次数
data.mentioned_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_mentioned == True)
)
.count()
)
# 6. 最长情陪伴的用户(过滤 bot 自身)
companion_query = (
ChatStreams.select(
ChatStreams.user_id,
ChatStreams.user_nickname,
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
)
.where(
(ChatStreams.user_id.is_null(False))
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
)
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
.limit(1)
)
companion_result = list(companion_query.dicts())
if companion_result:
data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
duration = companion_result[0].get("duration", 0) or 0
data.longest_companion_days = int(duration / 86400) # 转换为天
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
return data
# ==================== 维度三:最强大脑 ====================
async def get_brain_power(year: int = 2025) -> BrainPowerData:
"""获取最强大脑数据"""
data = BrainPowerData()
start_dt, end_dt = get_year_datetime_range(year)
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 年度消耗 Token 总量和总花费
token_query = LLMUsage.select(
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
result = token_query.dicts().get()
data.total_tokens = int(result.get("total_tokens", 0) or 0)
data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
# 2. 最爱用的模型
model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10)
)
model_results = list(model_query.dicts())
if model_results:
data.favorite_model = model_results[0].get("model")
data.favorite_model_count = model_results[0].get("count", 0)
data.model_distribution = [
{
"model": row["model"],
"count": row["count"],
"tokens": row["tokens"],
"cost": round(row["cost"], 4),
}
for row in model_results
]
# 3. 最昂贵的一次思考
expensive_query = (
LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.order_by(LLMUsage.cost.desc())
.limit(1)
)
expensive_result = expensive_query.first()
if expensive_result:
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
consumer_query = (
LLMUsage.select(
LLMUsage.user_id,
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (LLMUsage.user_id != "system") # 过滤 system 用户
& (LLMUsage.user_id.is_null(False))
)
.group_by(LLMUsage.user_id)
.order_by(fn.SUM(LLMUsage.cost).desc())
.limit(3)
)
data.top_token_consumers = [
{
"user_id": row["user_id"],
"cost": round(row["cost"], 4),
"tokens": row["tokens"],
}
for row in consumer_query.dicts()
]
# 5. 最喜欢的回复模型 TOP5按模型的回复次数统计只统计 replyer 调用)
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
reply_model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (
LLMUsage.model_assign_name.contains("replyer")
| LLMUsage.model_assign_name.contains("回复")
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
)
)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(5)
)
data.top_reply_models = [
{"model": row["model"], "count": row["count"]}
for row in reply_model_query.dicts()
]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
total_actions = ActionRecords.select().where(
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
).count()
no_reply_count = ActionRecords.select().where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "no_reply")
).count()
data.total_actions = total_actions
data.no_reply_count = no_reply_count
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
# 6. 情绪波动 (兴趣值)
interest_query = Messages.select(
fn.AVG(Messages.interest_value).alias("avg_interest"),
fn.MAX(Messages.interest_value).alias("max_interest"),
).where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value.is_null(False))
)
interest_result = interest_query.dicts().get()
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
# 找到最高兴趣值的时间
if data.max_interest_value > 0:
max_interest_msg = (
Messages.select(Messages.time)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value == data.max_interest_value)
)
.first()
)
if max_interest_msg:
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
"%Y-%m-%d %H:%M:%S"
)
# 7. 思考深度 (基于 action_reasoning 长度)
reasoning_records = (
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_reasoning.is_null(False))
& (ActionRecords.action_reasoning != "")
)
)
reasoning_lengths = []
max_len = 0
max_len_time = None
for record in reasoning_records:
if record.action_reasoning:
length = len(record.action_reasoning)
reasoning_lengths.append(length)
if length > max_len:
max_len = length
max_len_time = record.time
if reasoning_lengths:
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
data.max_reasoning_length = max_len
if max_len_time:
data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
return data
# ==================== 维度四:个性与表达 ====================
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"""获取个性与表达数据"""
from src.config.config import global_config
data = ExpressionVibeData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 表情包之王 - 使用次数最多的表情包
top_emoji_query = (
Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
.where(Emoji.is_registered == True)
.order_by(Emoji.usage_count.desc())
.limit(5)
)
top_emojis = list(top_emoji_query.dicts())
if top_emojis:
data.top_emoji = {
"id": top_emojis[0].get("id"),
"path": top_emojis[0].get("full_path"),
"description": top_emojis[0].get("description"),
"usage_count": top_emojis[0].get("usage_count", 0),
"hash": top_emojis[0].get("emoji_hash"),
}
data.top_emojis = [
{
"id": e.get("id"),
"path": e.get("full_path"),
"description": e.get("description"),
"usage_count": e.get("usage_count", 0),
"hash": e.get("emoji_hash"),
}
for e in top_emojis
]
# 2. 百变麦麦 - 最常用的表达风格
expression_query = (
Expression.select(
Expression.style,
fn.SUM(Expression.count).alias("total_count"),
)
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.group_by(Expression.style)
.order_by(fn.SUM(Expression.count).desc())
.limit(5)
)
data.top_expressions = [
{"style": row["style"], "count": row["total_count"]}
for row in expression_query.dicts()
]
# 3. 被拒绝的表达
data.rejected_expression_count = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
& (Expression.rejected == True)
)
.count()
)
# 4. 已检查的表达
data.checked_expression_count = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
& (Expression.checked == True)
)
.count()
)
# 5. 表达总数
data.total_expressions = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.count()
)
# 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
excluded_actions = [
"reply", "no_reply", "no_reply_until_call", "make_question",
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
]
action_query = (
ActionRecords.select(
ActionRecords.action_name,
fn.COUNT(ActionRecords.id).alias("count"),
)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name.not_in(excluded_actions))
)
.group_by(ActionRecords.action_name)
.order_by(fn.COUNT(ActionRecords.id).desc())
.limit(10)
)
data.action_types = [
{"action": row["action_name"], "count": row["count"]}
for row in action_query.dicts()
]
# 7. 处理的图片数量
data.image_processed_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_picid == True)
)
.count()
)
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
import random
import re
def clean_message_content(content: str) -> str:
"""清理消息内容,移除回复引用等标记"""
if not content:
return ""
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
content = re.sub(r'\[回复<[^>]+>\s*的消息[:][^\]]*\]', '', content)
# 移除 [图片] [表情] 等标记
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
# 移除多余的空白
content = re.sub(r'\s+', ' ', content).strip()
return content
# 使用 user_id 判断是否是 bot 发送的消息
late_night_messages = list(
Messages.select(
Messages.time,
Messages.processed_plain_text,
Messages.display_message,
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id == bot_qq) # bot 发送的消息
)
.order_by(Messages.time.desc())
)
# 筛选出0-6点的消息
late_night_filtered = []
for msg in late_night_messages:
msg_dt = datetime.fromtimestamp(msg.time)
hour = msg_dt.hour
if 0 <= hour < 6: # 0点到6点
raw_content = msg.processed_plain_text or msg.display_message or ""
cleaned_content = clean_message_content(raw_content)
# 只保留有意义的内容
if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append({
"time": msg.time,
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"),
})
if len(late_night_filtered) >= 10:
break
if late_night_filtered:
selected = random.choice(late_night_filtered)
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
data.late_night_reply = {
"time": selected["datetime_str"],
"content": content,
}
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
from collections import Counter
import json as json_lib
reply_records = (
ActionRecords.select(ActionRecords.action_data)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "")
)
)
reply_contents = []
for record in reply_records:
try:
action_data = record.action_data
if action_data:
content = None
# 尝试解析 JSON 格式
try:
parsed = json_lib.loads(action_data)
if isinstance(parsed, dict):
# 优先使用 reply_text其次使用 content
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (json_lib.JSONDecodeError, TypeError):
pass
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
# 例如: "{'reply_text': '墨白灵不知道哦'}"
if content is None:
import ast
try:
parsed = ast.literal_eval(action_data)
if isinstance(parsed, dict):
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (ValueError, SyntaxError):
# 无法解析,使用原始字符串
content = action_data
# 只统计有意义的回复长度大于2
if content and len(content) > 2:
reply_contents.append(content)
except Exception:
continue
if reply_contents:
content_counter = Counter(reply_contents)
most_common = content_counter.most_common(1)
if most_common:
fav_content, fav_count = most_common[0]
# 截断过长的内容
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
data.favorite_reply = {
"content": display_content,
"count": fav_count,
}
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
return data
# ==================== 维度五:趣味成就 ====================
async def get_achievements(year: int = 2025) -> AchievementData:
"""获取趣味成就数据"""
data = AchievementData()
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 新学到的黑话数量
# Jargon 表没有时间字段,统计全部已确认的黑话
data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
# 2. 代表性黑话示例
jargon_samples = (
Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
.where(Jargon.is_jargon == True)
.order_by(Jargon.count.desc())
.limit(5)
)
data.sample_jargons = [
{
"content": j.content,
"meaning": j.meaning,
"count": j.count,
}
for j in jargon_samples
]
# 3. 总消息数
data.total_messages = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.count()
)
# 4. 总回复数 (有 reply_to 的消息)
data.total_replies = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.reply_to.is_null(False))
)
.count()
)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")
return data
# ==================== API 路由 ====================
@router.get("/full", response_model=AnnualReportData)
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
"""
获取完整年度报告数据
Args:
year: 报告年份默认2025
Returns:
完整的年度报告数据
"""
try:
from src.config.config import global_config
logger.info(f"开始生成 {year} 年度报告...")
# 获取 bot 名称
bot_name = global_config.bot.nickname or "麦麦"
# 并行获取各维度数据
time_footprint = await get_time_footprint(year)
social_network = await get_social_network(year)
brain_power = await get_brain_power(year)
expression_vibe = await get_expression_vibe(year)
achievements = await get_achievements(year)
report = AnnualReportData(
year=year,
bot_name=bot_name,
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
time_footprint=time_footprint,
social_network=social_network,
brain_power=brain_power,
expression_vibe=expression_vibe,
achievements=achievements,
)
logger.info(f"{year} 年度报告生成完成")
return report
except Exception as e:
logger.error(f"生成年度报告失败: {e}")
raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
@router.get("/time-footprint", response_model=TimeFootprintData)
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取时光足迹数据"""
try:
return await get_time_footprint(year)
except Exception as e:
logger.error(f"获取时光足迹数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/social-network", response_model=SocialNetworkData)
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取社交网络数据"""
try:
return await get_social_network(year)
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/brain-power", response_model=BrainPowerData)
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取最强大脑数据"""
try:
return await get_brain_power(year)
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/expression-vibe", response_model=ExpressionVibeData)
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取个性与表达数据"""
try:
return await get_expression_vibe(year)
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/achievements", response_model=AchievementData)
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取趣味成就数据"""
try:
return await get_achievements(year)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

781
src/webui/routers/chat.py Normal file
View File

@@ -0,0 +1,781 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话
支持两种模式:
1. WebUI 模式:使用 WebUI 平台独立身份聊天
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
"""
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
# 虚拟身份模式的群 ID 前缀
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# 固定的 WebUI 用户 ID 前缀
WEBUI_USER_ID_PREFIX = "webui_user_"
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置"""
enabled: bool = False # 是否启用虚拟身份模式
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
person_id: Optional[str] = None # PersonInfo 的 person_id
user_id: Optional[str] = None # 原始平台用户 ID
user_nickname: Optional[str] = None # 用户昵称
group_id: Optional[str] = None # 虚拟群 ID自动生成或用户指定
group_name: Optional[str] = None # 虚拟群名(用户自定义)
class ChatHistoryMessage(BaseModel):
"""聊天历史消息"""
id: str
type: str # 'user' | 'bot' | 'system'
content: str
timestamp: float
sender_name: str
sender_id: Optional[str] = None
is_bot: bool = False
class ChatHistoryManager:
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
def __init__(self, max_messages: int = 200):
self.max_messages = max_messages
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式
Args:
msg: 数据库消息对象
group_id: 群 ID用于判断是否是虚拟群
"""
# 判断是否是机器人消息
user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
# 虚拟群user_id 等于机器人 QQ 账号的是机器人消息
bot_qq = str(global_config.bot.qq_account)
is_bot = user_id == bot_qq
else:
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"timestamp": msg.time,
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
}
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""从数据库获取最近的历史记录
Args:
limit: 获取的消息数量
group_id: 群 ID默认为 WEBUI_CHAT_GROUP_ID
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
# 查询指定群的消息,按时间排序
messages = (
Messages.select()
.where(Messages.chat_info_group_id == target_group_id)
.order_by(Messages.time.desc())
.limit(limit)
)
# 转换为列表并反转(使最旧的消息在前)
# 传递 group_id 以便正确判断虚拟群中的机器人消息
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
result.reverse()
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空聊天历史记录
Args:
group_id: 群 ID默认清空 WebUI 默认聊天室
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
return 0
# 全局聊天历史管理器
chat_history = ChatHistoryManager()
# 存储 WebSocket 连接
class ChatConnectionManager:
"""聊天连接管理器"""
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def disconnect(self, session_id: str, user_id: str):
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict):
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict):
"""广播消息给所有连接"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
chat_manager = ChatConnectionManager()
def create_message_data(
content: str,
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""创建符合麦麦消息格式的消息数据
Args:
content: 消息内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID可选自动生成
is_at_bot: 是否 @ 机器人
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
"""
if message_id is None:
message_id = str(uuid.uuid4())
# 确定使用的平台、群信息和用户信息
if virtual_config and virtual_config.enabled:
# 虚拟身份模式:使用真实平台身份
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
group_name = virtual_config.group_name or "WebUI虚拟群聊"
actual_user_id = virtual_config.user_id or user_id
actual_user_name = virtual_config.user_nickname or user_name
else:
# 标准 WebUI 模式
platform = WEBUI_CHAT_PLATFORM
group_id = WEBUI_CHAT_GROUP_ID
group_name = "WebUI本地聊天室"
actual_user_id = user_id
actual_user_name = user_name
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": actual_user_id,
"user_nickname": actual_user_name,
"user_cardname": actual_user_name,
"platform": platform,
},
"additional_config": {
"at_bot": is_at_bot,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
{
"type": "mention_bot",
"data": "1.0",
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
@router.get("/history")
async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
如果指定了 group_id则获取该虚拟群的历史记录
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
history = chat_history.get_history(limit, target_group_id)
return {
"success": True,
"messages": history,
"total": len(history),
}
@router.get("/platforms")
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
从 PersonInfo 表中获取所有已知的平台
"""
try:
from peewee import fn
# 查询所有不同的平台
platforms = (
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
.group_by(PersonInfo.platform)
.order_by(fn.COUNT(PersonInfo.id).desc())
)
result = []
for p in platforms:
if p.platform: # 排除空平台
result.append({"platform": p.platform, "count": p.count})
return {"success": True, "platforms": result}
except Exception as e:
logger.error(f"获取平台列表失败: {e}")
return {"success": False, "error": str(e), "platforms": []}
@router.get("/persons")
async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
Args:
platform: 平台名称(如 qq, discord 等)
search: 搜索关键词匹配昵称、用户名、user_id
limit: 返回数量限制
"""
try:
# 构建查询
query = PersonInfo.select().where(PersonInfo.platform == platform)
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 按最后交互时间排序,优先显示活跃用户
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
query = query.limit(limit)
result = []
for person in query:
result.append(
{
"person_id": person.person_id,
"user_id": person.user_id,
"person_name": person.person_name,
"nickname": person.nickname,
"is_known": person.is_known,
"platform": person.platform,
"display_name": person.person_name or person.nickname or person.user_id,
}
)
return {"success": True, "persons": result, "total": len(result)}
except Exception as e:
logger.error(f"获取用户列表失败: {e}")
return {"success": False, "error": str(e), "persons": []}
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
group_id: 可选,指定要清空的群 ID默认清空 WebUI 默认聊天室
"""
deleted = chat_history.clear_history(group_id)
return {
"success": True,
"message": f"已清空 {deleted} 条聊天记录",
}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
Args:
user_id: 用户唯一标识(由前端生成并持久化)
user_name: 用户显示昵称(可修改)
platform: 虚拟身份模式的平台(可选)
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名(可选)
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取)
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
# 如果没有提供 user_id生成一个新的
if not user_id:
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
# 确保 user_id 有正确的前缀
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
# 当前会话的虚拟身份配置(可通过消息动态更新)
current_virtual_config: Optional[VirtualIdentityConfig] = None
# 如果 URL 参数中提供了虚拟身份信息,自动配置
if platform and person_id:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
# 使用前端传递的 group_id如果没有则生成一个稳定的
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
current_virtual_config = VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.nickname or person.user_id,
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
await chat_manager.connect(websocket, session_id, user_id)
try:
# 构建会话信息
session_info_data = {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
"user_name": user_name,
"bot_name": global_config.bot.nickname,
}
# 如果有虚拟身份配置,添加到会话信息中
if current_virtual_config and current_virtual_config.enabled:
session_info_data["virtual_mode"] = True
session_info_data["group_id"] = current_virtual_config.group_id
session_info_data["virtual_identity"] = {
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_name": current_virtual_config.group_name,
}
# 发送会话信息(包含用户 ID前端需要保存
await chat_manager.send_message(session_id, session_info_data)
# 发送历史记录(根据模式选择不同的群)
if current_virtual_config and current_virtual_config.enabled:
history = chat_history.get_history(50, current_virtual_config.group_id)
else:
history = chat_history.get_history(50)
if history:
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
# 发送欢迎消息(不保存到历史)
if current_virtual_config and current_virtual_config.enabled:
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
else:
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": welcome_msg,
"timestamp": time.time(),
},
)
while True:
data = await websocket.receive_json()
if data.get("type") == "message":
content = data.get("content", "").strip()
if not content:
continue
# 用户可以更新昵称
current_user_name = data.get("user_name", user_name)
message_id = str(uuid.uuid4())
timestamp = time.time()
# 确定发送者信息(根据是否使用虚拟身份)
if current_virtual_config and current_virtual_config.enabled:
sender_name = current_virtual_config.user_nickname or current_user_name
sender_user_id = current_virtual_config.user_id or user_id
else:
sender_name = current_user_name
sender_user_id = user_id
# 广播用户消息给所有连接(包括发送者)
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
await chat_manager.broadcast(
{
"type": "user_message",
"content": content,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
"name": sender_name,
"user_id": sender_user_id,
"is_bot": False,
},
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
}
)
# 创建麦麦消息格式
message_data = create_message_data(
content=content,
user_id=user_id,
user_name=current_user_name,
message_id=message_id,
is_at_bot=True,
virtual_config=current_virtual_config,
)
try:
# 显示正在输入状态
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": True,
}
)
# 调用麦麦的消息处理
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(),
},
)
finally:
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": False,
}
)
elif data.get("type") == "ping":
await chat_manager.send_message(
session_id,
{
"type": "pong",
"timestamp": time.time(),
},
)
elif data.get("type") == "update_nickname":
# 允许用户更新昵称
if new_name := data.get("user_name", "").strip():
current_user_name = new_name
await chat_manager.send_message(
session_id,
{
"type": "nickname_updated",
"user_name": current_user_name,
"timestamp": time.time(),
},
)
elif data.get("type") == "set_virtual_identity":
# 设置或更新虚拟身份配置
virtual_data = data.get("config", {})
if virtual_data.get("enabled"):
# 验证必要字段
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
"timestamp": time.time(),
},
)
continue
# 获取用户信息
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
if not person:
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"找不到用户: {virtual_data.get('person_id')}",
"timestamp": time.time(),
},
)
continue
# 生成虚拟群 ID
custom_group_id = virtual_data.get("group_id")
if custom_group_id:
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
else:
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
current_virtual_config = VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.nickname or person.user_id,
group_id=group_id,
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
)
# 发送虚拟身份已激活的消息
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {
"enabled": True,
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_id": current_virtual_config.group_id,
"group_name": current_virtual_config.group_name,
},
"timestamp": time.time(),
},
)
# 加载虚拟群的历史记录
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": virtual_history,
"group_id": current_virtual_config.group_id,
},
)
# 发送系统消息
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
"timestamp": time.time(),
},
)
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"设置虚拟身份失败: {str(e)}",
"timestamp": time.time(),
},
)
else:
# 禁用虚拟身份模式
current_virtual_config = None
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {"enabled": False},
"timestamp": time.time(),
},
)
# 重新加载默认聊天室历史
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": default_history,
"group_id": WEBUI_CHAT_GROUP_ID,
},
)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": "已切换回 WebUI 独立用户模式",
"timestamp": time.time(),
},
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, user_id)
@router.get("/info")
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,
"platform": WEBUI_CHAT_PLATFORM,
"group_id": WEBUI_CHAT_GROUP_ID,
"active_sessions": len(chat_manager.active_connections),
}
def get_webui_chat_broadcaster() -> tuple:
"""获取 WebUI 聊天广播器,供外部模块使用
Returns:
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
"""
return (chat_manager, WEBUI_CHAT_PLATFORM)

597
src/webui/routers/config.py Normal file
View File

@@ -0,0 +1,597 @@
"""
配置管理API路由
"""
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
EmojiConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
VoiceConfig,
)
from src.config.api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
schema = ConfigSchemaGenerator.generate_config_schema(Config)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
@router.get("/schema/model")
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
# ===== 子配置架构获取接口 =====
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
"""
获取指定配置节的架构
支持的section_name:
- bot: BotConfig
- personality: PersonalityConfig
- relationship: RelationshipConfig
- chat: ChatConfig
- message_receive: MessageReceiveConfig
- emoji: EmojiConfig
- expression: ExpressionConfig
- keyword_reaction: KeywordReactionConfig
- chinese_typo: ChineseTypoConfig
- response_post_process: ResponsePostProcessConfig
- response_splitter: ResponseSplitterConfig
- telemetry: TelemetryConfig
- experimental: ExperimentalConfig
- maim_message: MaimMessageConfig
- lpmm_knowledge: LPMMKnowledgeConfig
- tool: ToolConfig
- memory: MemoryConfig
- debug: DebugConfig
- voice: VoiceConfig
- jargon: JargonConfig
- model_task_config: ModelTaskConfig
- api_provider: APIProvider
- model_info: ModelInfo
"""
section_map = {
"bot": BotConfig,
"personality": PersonalityConfig,
"relationship": RelationshipConfig,
"chat": ChatConfig,
"message_receive": MessageReceiveConfig,
"emoji": EmojiConfig,
"expression": ExpressionConfig,
"keyword_reaction": KeywordReactionConfig,
"chinese_typo": ChineseTypoConfig,
"response_post_process": ResponsePostProcessConfig,
"response_splitter": ResponseSplitterConfig,
"telemetry": TelemetryConfig,
"experimental": ExperimentalConfig,
"maim_message": MaimMessageConfig,
"lpmm_knowledge": LPMMKnowledgeConfig,
"tool": ToolConfig,
"memory": MemoryConfig,
"debug": DebugConfig,
"voice": VoiceConfig,
"model_task_config": ModelTaskConfig,
"api_provider": APIProvider,
"model_info": ModelInfo,
}
if section_name not in section_map:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
try:
config_class = section_map[section_name]
schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
# ===== 配置读取接口 =====
@router.get("/bot")
async def get_bot_config(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.get("/model")
async def get_model_config(_auth: bool = Depends(require_auth)):
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
# ===== 配置更新接口 =====
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
save_toml_with_format(config_data, config_path)
logger.info("麦麦主程序配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model")
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新模型配置"""
try:
# 验证配置数据
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
save_toml_with_format(config_data, config_path)
logger.info("模型配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
# ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组类型(如 platforms, aliases直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_toml_doc(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 原始 TOML 文件操作接口 =====
@router.get("/bot/raw")
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
raw_content = f.read()
return {"success": True, "content": raw_content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
try:
config_data = tomlkit.loads(raw_content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 验证配置数据结构
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
f.write(raw_content)
logger.info("麦麦主程序配置已更新(原始模式)")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model/section/{section_name}")
async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组表(如 [[models]], [[api_providers]]),直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_toml_doc(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
raise HTTPException(status_code=400, detail=error_msg) from e
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 适配器配置管理接口 =====
def _normalize_adapter_path(path: str) -> str:
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
if not path:
return path
# 如果已经是绝对路径,直接返回
if os.path.isabs(path):
return path
# 相对路径,转换为相对于项目根目录的绝对路径
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path):
return path
try:
# 尝试获取相对路径
rel_path = os.path.relpath(path, PROJECT_ROOT)
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
if not rel_path.startswith(".."):
return rel_path
except (ValueError, TypeError):
# 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError
pass
# 无法转换为相对路径,返回绝对路径
return path
@router.get("/adapter-config/path")
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
webui_data_path = os.path.join("data", "webui.json")
if not os.path.exists(webui_data_path):
return {"success": True, "path": None}
import json
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
adapter_config_path = webui_data.get("adapter_config_path")
if not adapter_config_path:
return {"success": True, "path": None}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(adapter_config_path)
# 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path):
import datetime
mtime = os.path.getmtime(abs_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
# 返回相对路径(如果可能)
display_path = _to_relative_path(abs_path)
return {"success": True, "path": display_path, "lastModified": last_modified}
else:
# 文件不存在,返回原路径
return {"success": True, "path": adapter_config_path, "lastModified": None}
except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
# 保存到 data/webui.json
webui_data_path = os.path.join("data", "webui.json")
import json
# 读取现有数据
if os.path.exists(webui_data_path):
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
else:
webui_data = {}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 尝试转换为相对路径保存(如果文件在项目目录内)
save_path = _to_relative_path(abs_path)
# 更新路径
webui_data["adapter_config_path"] = save_path
# 保存
os.makedirs("data", exist_ok=True)
with open(webui_data_path, "w", encoding="utf-8") as f:
json.dump(webui_data, f, ensure_ascii=False, indent=2)
logger.info(f"适配器配置路径已保存: {save_path}(绝对路径: {abs_path}")
return {"success": True, "message": "路径已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
@router.get("/adapter-config")
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
"""从指定路径读取适配器配置文件"""
try:
if not path:
raise HTTPException(status_code=400, detail="路径参数不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件是否存在
if not os.path.exists(abs_path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 读取文件内容
with open(abs_path, "r", encoding="utf-8") as f:
content = f.read()
logger.info(f"已读取适配器配置: {path} (绝对路径: {abs_path})")
return {"success": True, "content": content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")
content = data.get("content")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
if content is None:
raise HTTPException(status_code=400, detail="配置内容不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 验证 TOML 格式
try:
tomlkit.loads(content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 确保目录存在
dir_path = os.path.dirname(abs_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
# 保存文件
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
logger.info(f"适配器配置已保存: {path} (绝对路径: {abs_path})")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e

1310
src/webui/routers/emoji.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,773 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from src.webui.core import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
last_active_time: float
chat_id: str
create_date: Optional[float]
checked: bool
rejected: bool
modified_by: Optional[str] = None # 'ai' 或 'user' 或 None
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[ExpressionResponse]
class ExpressionDetailResponse(BaseModel):
"""表达方式详情响应"""
success: bool
data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
chat_id: str
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
chat_id: Optional[str] = None
checked: Optional[bool] = None
rejected: Optional[bool] = None
require_unchecked: Optional[bool] = False # 用于人工审核时的冲突检测
class ExpressionUpdateResponse(BaseModel):
"""表达方式更新响应"""
success: bool
message: str
data: Optional[ExpressionResponse] = None
class ExpressionDeleteResponse(BaseModel):
"""表达方式删除响应"""
success: bool
message: str
class ExpressionCreateResponse(BaseModel):
"""表达方式创建响应"""
success: bool
message: str
data: ExpressionResponse
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def expression_to_response(expression: Expression) -> ExpressionResponse:
"""将 Expression 模型转换为响应对象"""
return ExpressionResponse(
id=expression.id,
situation=expression.situation,
style=expression.style,
last_active_time=expression.last_active_time,
chat_id=expression.chat_id,
create_date=expression.create_date,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by,
)
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream:
# 优先使用群聊名称,否则使用用户昵称
if chat_stream.group_name:
return chat_stream.group_name
elif chat_stream.user_nickname:
return chat_stream.user_nickname
return chat_id # 找不到时返回原始ID
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称"""
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
for cs in chat_streams:
if cs.group_name:
result[cs.stream_id] = cs.group_name
elif cs.user_nickname:
result[cs.stream_id] = cs.user_nickname
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
class ChatInfo(BaseModel):
"""聊天信息"""
chat_id: str
chat_name: str
platform: Optional[str] = None
is_group: bool = False
class ChatListResponse(BaseModel):
"""聊天列表响应"""
success: bool
data: List[ChatInfo]
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取所有聊天列表(用于下拉选择)
Args:
authorization: Authorization header
Returns:
聊天列表
"""
try:
verify_auth_token(maibot_session, authorization)
chat_list = []
for cs in ChatStreams.select():
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
chat_list.append(
ChatInfo(
chat_id=cs.stream_id,
chat_name=chat_name,
platform=cs.platform,
is_group=bool(cs.group_id),
)
)
# 按名称排序
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(maibot_session, authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"创建表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式(只更新提供的字段)
Args:
expression_id: 表达方式ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 冲突检测:如果要求未检查状态,但已经被检查了
if request.require_unchecked and expression.checked:
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
)
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
# 移除 require_unchecked它不是数据库字段
update_data.pop("require_unchecked", None)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 如果更新了 checked 或 rejected标记为用户修改
if "checked" in update_data or "rejected" in update_data:
update_data["modified_by"] = "user"
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int]
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
Args:
request: 包含要删除的ID列表的请求
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
if not_found_ids:
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
logger.info(f"批量删除了 {deleted_count} 个表达方式")
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
# ============ 审核相关接口 ============
class ReviewStatsResponse(BaseModel):
"""审核统计响应"""
total: int
unchecked: int
passed: int
rejected: int
ai_checked: int
user_checked: int
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取审核统计数据
Returns:
审核统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
unchecked = Expression.select().where(Expression.checked == False).count()
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
user_checked = Expression.select().where(Expression.modified_by == "user").count()
return ReviewStatsResponse(
total=total,
unchecked=unchecked,
passed=passed,
rejected=rejected,
ai_checked=ai_checked,
user_checked=user_checked,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取审核统计失败: {e}")
raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
class ReviewListResponse(BaseModel):
"""审核列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[ExpressionResponse]
@router.get("/review/list", response_model=ReviewListResponse)
async def get_review_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取待审核/已审核的表达方式列表
Args:
page: 页码
page_size: 每页数量
filter_type: 筛选类型 (unchecked/passed/rejected/all)
search: 搜索关键词
chat_id: 聊天ID筛选
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
query = Expression.select()
# 根据筛选类型过滤
if filter_type == "unchecked":
query = query.where(Expression.checked == False)
elif filter_type == "passed":
query = query.where((Expression.checked == True) & (Expression.rejected == False))
elif filter_type == "rejected":
query = query.where((Expression.checked == True) & (Expression.rejected == True))
# all 不需要额外过滤
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序:创建时间倒序
from peewee import Case
query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc())
total = query.count()
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
return ReviewListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=[expression_to_response(expr) for expr in expressions],
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取审核列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
class BatchReviewItem(BaseModel):
"""批量审核项"""
id: int
rejected: bool
require_unchecked: bool = True # 默认要求未检查状态
class BatchReviewRequest(BaseModel):
"""批量审核请求"""
items: List[BatchReviewItem]
class BatchReviewResultItem(BaseModel):
"""批量审核结果项"""
id: int
success: bool
message: str
class BatchReviewResponse(BaseModel):
"""批量审核响应"""
success: bool
total: int
succeeded: int
failed: int
results: List[BatchReviewResultItem]
@router.post("/review/batch", response_model=BatchReviewResponse)
async def batch_review_expressions(
request: BatchReviewRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量审核表达方式
Args:
request: 批量审核请求
Returns:
批量审核结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.items:
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")
results = []
succeeded = 0
failed = 0
for item in request.items:
try:
expression = Expression.get_or_none(Expression.id == item.id)
if not expression:
results.append(
BatchReviewResultItem(id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式")
)
failed += 1
continue
# 冲突检测
if item.require_unchecked and expression.checked:
results.append(
BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
)
)
failed += 1
continue
# 更新状态
expression.checked = True
expression.rejected = item.rejected
expression.modified_by = "user"
expression.last_active_time = time.time()
expression.save()
results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
)
succeeded += 1
except Exception as e:
results.append(BatchReviewResultItem(id=item.id, success=False, message=str(e)))
failed += 1
logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
return BatchReviewResponse(
success=True, total=len(request.items), succeeded=succeeded, failed=failed, results=results
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量审核失败: {e}")
raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e

532
src/webui/routers/jargon.py Normal file
View File

@@ -0,0 +1,532 @@
"""黑话(俚语)管理路由"""
import json
from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon")
router = APIRouter(prefix="/jargon", tags=["Jargon"])
# ==================== 辅助函数 ====================
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
"""
解析 chat_id 字段,提取所有 stream_id
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
"""
if not chat_id_str:
return []
try:
# 尝试解析为 JSON
parsed = json.loads(chat_id_str)
if isinstance(parsed, list):
# 格式: [["stream_id", user_id], ...]
stream_ids = []
for item in parsed:
if isinstance(item, list) and len(item) >= 1:
stream_ids.append(str(item[0]))
return stream_ids
else:
# 其他格式,返回原始字符串
return [chat_id_str]
except (json.JSONDecodeError, TypeError):
# 不是有效的 JSON可能是直接的 stream_id
return [chat_id_str]
def get_display_name_for_chat_id(chat_id_str: str) -> str:
"""
获取 chat_id 的显示名称
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
"""
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
# 查询所有 stream_id 对应的名称
names = []
for stream_id in stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream and chat_stream.group_name:
names.append(chat_stream.group_name)
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str
# ==================== 请求/响应模型 ====================
class JargonResponse(BaseModel):
"""黑话信息响应"""
id: int
content: str
raw_content: Optional[str] = None
meaning: Optional[str] = None
chat_id: str
stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
is_global: bool = False
count: int = 0
is_jargon: Optional[bool] = None
is_complete: bool = False
inference_with_context: Optional[str] = None
inference_content_only: Optional[str] = None
class JargonListResponse(BaseModel):
"""黑话列表响应"""
success: bool = True
total: int
page: int
page_size: int
data: List[JargonResponse]
class JargonDetailResponse(BaseModel):
"""黑话详情响应"""
success: bool = True
data: JargonResponse
class JargonCreateRequest(BaseModel):
"""黑话创建请求"""
content: str = Field(..., description="黑话内容")
raw_content: Optional[str] = Field(None, description="原始内容")
meaning: Optional[str] = Field(None, description="含义")
chat_id: str = Field(..., description="聊天ID")
is_global: bool = Field(False, description="是否全局")
class JargonUpdateRequest(BaseModel):
"""黑话更新请求"""
content: Optional[str] = None
raw_content: Optional[str] = None
meaning: Optional[str] = None
chat_id: Optional[str] = None
is_global: Optional[bool] = None
is_jargon: Optional[bool] = None
class JargonCreateResponse(BaseModel):
"""黑话创建响应"""
success: bool = True
message: str
data: JargonResponse
class JargonUpdateResponse(BaseModel):
"""黑话更新响应"""
success: bool = True
message: str
data: Optional[JargonResponse] = None
class JargonDeleteResponse(BaseModel):
"""黑话删除响应"""
success: bool = True
message: str
deleted_count: int = 0
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int] = Field(..., description="要删除的黑话ID列表")
class JargonStatsResponse(BaseModel):
"""黑话统计响应"""
success: bool = True
data: dict
class ChatInfoResponse(BaseModel):
"""聊天信息响应"""
chat_id: str
chat_name: str
platform: Optional[str] = None
is_group: bool = False
class ChatListResponse(BaseModel):
"""聊天列表响应"""
success: bool = True
data: List[ChatInfoResponse]
# ==================== 工具函数 ====================
def jargon_to_dict(jargon: Jargon) -> dict:
"""将 Jargon ORM 对象转换为字典"""
# 解析 chat_id 获取显示名称和 stream_id
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return {
"id": jargon.id,
"content": jargon.content,
"raw_content": jargon.raw_content,
"meaning": jargon.meaning,
"chat_id": jargon.chat_id,
"stream_id": stream_id,
"chat_name": chat_name,
"is_global": jargon.is_global,
"count": jargon.count,
"is_jargon": jargon.is_jargon,
"is_complete": jargon.is_complete,
"inference_with_context": jargon.inference_with_context,
"inference_content_only": jargon.inference_content_only,
}
# ==================== API 端点 ====================
@router.get("/list", response_model=JargonListResponse)
async def get_jargon_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
):
"""获取黑话列表"""
try:
# 构建查询
query = Jargon.select()
# 搜索过滤
if search:
query = query.where(
(Jargon.content.contains(search))
| (Jargon.meaning.contains(search))
| (Jargon.raw_content.contains(search))
)
# 按聊天ID筛选使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
# 使用第一个 stream_id 进行模糊匹配
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
else:
# 如果无法解析,使用精确匹配
query = query.where(Jargon.chat_id == chat_id)
# 按是否是黑话筛选
if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon)
# 按是否全局筛选
if is_global is not None:
query = query.where(Jargon.is_global == is_global)
# 获取总数
total = query.count()
# 分页和排序(按使用次数降序)
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
return JargonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data,
)
except Exception as e:
logger.error(f"获取黑话列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
seen_stream_ids.add(stream_ids[0])
result = []
for stream_id in seen_stream_ids:
# 尝试从 ChatStreams 表获取聊天名称
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id方便筛选匹配
chat_name=chat_stream.group_name or stream_id,
platform=chat_stream.platform,
is_group=True,
)
)
else:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
platform=None,
is_group=False,
)
)
return ChatListResponse(success=True, data=result)
except Exception as e:
logger.error(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/stats/summary", response_model=JargonStatsResponse)
async def get_jargon_stats():
"""获取黑话统计数据"""
try:
# 总数量
total = Jargon.select().count()
# 已确认是黑话的数量
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
# 已确认不是黑话的数量
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
# 未判定的数量
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
# 全局黑话数量
global_count = Jargon.select().where(Jargon.is_global).count()
# 已完成推断的数量
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
return JargonStatsResponse(
success=True,
data={
"total": total,
"confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon,
"pending": pending,
"global_count": global_count,
"complete_count": complete_count,
"chat_count": chat_count,
"top_chats": top_chats_dict,
},
)
except Exception as e:
logger.error(f"获取黑话统计失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
async def get_jargon_detail(jargon_id: int):
"""获取黑话详情"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
except HTTPException:
raise
except Exception as e:
logger.error(f"获取黑话详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
@router.post("/", response_model=JargonCreateResponse)
async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话
jargon = Jargon.create(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning,
chat_id=request.chat_id,
is_global=request.is_global,
count=0,
is_jargon=None,
is_complete=False,
)
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
return JargonCreateResponse(
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"创建黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
jargon.save()
logger.info(f"更新黑话成功: id={jargon_id}")
return JargonUpdateResponse(
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"更新黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
async def delete_jargon(jargon_id: int):
"""删除黑话"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content
jargon.delete_instance()
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse(
success=True,
message="删除成功",
deleted_count=1,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"删除黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
@router.post("/batch/delete", response_model=JargonDeleteResponse)
async def batch_delete_jargons(request: BatchDeleteRequest):
"""批量删除黑话"""
try:
if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse(
success=True,
message=f"成功删除 {deleted_count} 条黑话",
deleted_count=deleted_count,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量删除黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
async def batch_set_jargon_status(
ids: Annotated[List[int], Query(description="黑话ID列表")],
is_jargon: Annotated[bool, Query(description="是否是黑话")],
):
"""批量设置黑话状态"""
try:
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse(
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量更新黑话状态失败: {e}")
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e

View File

@@ -0,0 +1,390 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.config.config import global_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
# 延迟初始化的轻量级 embedding store只读仅用于获取段落完整文本
_paragraph_store_cache = None
def _get_paragraph_store():
"""延迟加载段落 embedding store只读模式轻量级
Returns:
EmbeddingStore | None: 如果配置启用则返回store否则返回None
"""
# 检查配置是否启用
if not global_config.webui.enable_paragraph_content:
return None
global _paragraph_store_cache
if _paragraph_store_cache is not None:
return _paragraph_store_cache
try:
from src.chat.knowledge.embedding_store import EmbeddingStore
import os
# 获取数据路径
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", ".."))
embedding_dir = os.path.join(root_path, "data/embedding")
# 只加载段落 embedding store轻量级
paragraph_store = EmbeddingStore(
namespace="paragraph",
dir_path=embedding_dir,
max_workers=1, # 只读不需要多线程
chunk_size=100
)
paragraph_store.load_from_file()
_paragraph_store_cache = paragraph_store
logger.info(f"成功加载段落 embedding store包含 {len(paragraph_store.store)} 个段落")
return paragraph_store
except Exception as e:
logger.warning(f"加载段落 embedding store 失败: {e}")
return None
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
"""从 embedding store 获取段落完整内容
Args:
node_id: 段落节点ID格式为 'paragraph-{hash}'
Returns:
tuple[str | None, bool]: (段落完整内容或None, 是否启用了功能)
"""
try:
paragraph_store = _get_paragraph_store()
if paragraph_store is None:
# 功能未启用
return None, False
# 从 store 中获取完整内容
paragraph_item = paragraph_store.store.get(node_id)
if paragraph_item is not None:
# paragraph_item 是 EmbeddingStoreItem其 str 属性包含完整文本
content: str = getattr(paragraph_item, 'str', '')
if content:
return content, True
return None, True
except Exception as e:
logger.debug(f"获取段落内容失败: {e}")
return None, True
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
id: str
type: str # 'entity' or 'paragraph'
content: str
create_time: Optional[float] = None
class KnowledgeEdge(BaseModel):
"""知识边"""
source: str
target: str
weight: float
create_time: Optional[float] = None
update_time: Optional[float] = None
class KnowledgeGraph(BaseModel):
"""知识图谱"""
nodes: List[KnowledgeNode]
edges: List[KnowledgeEdge]
class KnowledgeStats(BaseModel):
"""知识库统计信息"""
total_nodes: int
total_edges: int
entity_nodes: int
paragraph_nodes: int
avg_connections: float
def _load_kg_manager():
"""延迟加载 KGManager"""
try:
from src.chat.knowledge.kg_manager import KGManager
kg_manager = KGManager()
kg_manager.load_from_file()
return kg_manager
except Exception as e:
logger.error(f"加载 KGManager 失败: {e}")
return None
def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
"""将 DiGraph 转换为 JSON 格式"""
if kg_manager is None or kg_manager.graph is None:
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
nodes = []
edges = []
# 转换节点
node_list = graph.get_node_list()
for node_id in node_list:
try:
node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 转换边
edge_list = graph.get_edge_list()
for edge_tuple in edge_list:
try:
# edge_tuple 是 (source, target) 元组
source, target = edge_tuple[0], edge_tuple[1]
# 通过 graph[source, target] 获取边的属性数据
edge_data = graph[source, target]
# edge_data 支持 [] 操作符但不支持 .get()
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
return KnowledgeGraph(nodes=nodes, edges=edges)
@router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
Args:
limit: 返回的最大节点数,默认 100,最大 10000
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
Returns:
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None:
logger.warning("KGManager 未初始化,返回空图谱")
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
all_node_list = graph.get_node_list()
# 按类型过滤节点
if node_type == "entity":
all_node_list = [
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
]
elif node_type == "paragraph":
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
# 限制节点数量
total_nodes = len(all_node_list)
if len(all_node_list) > limit:
node_list = all_node_list[:limit]
else:
node_list = all_node_list
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
# 转换节点
nodes = []
node_ids = set()
for node_id in node_list:
try:
node_data = graph[node_id]
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
node_ids.add(node_id)
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 只获取涉及当前节点集的边(保证图的完整性)
edges = []
edge_list = graph.get_edge_list()
for edge_tuple in edge_list:
try:
source, target = edge_tuple[0], edge_tuple[1]
# 只包含两端都在当前节点集中的边
if source not in node_ids or target not in node_ids:
continue
edge_data = graph[source, target]
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
return graph_data
except Exception as e:
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
return KnowledgeGraph(nodes=[], edges=[])
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
KnowledgeStats: 统计信息
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
graph = kg_manager.graph
node_list = graph.get_node_list()
edge_list = graph.get_edge_list()
total_nodes = len(node_list)
total_edges = len(edge_list)
# 统计节点类型
entity_nodes = 0
paragraph_nodes = 0
for node_id in node_list:
try:
node_data = graph[node_id]
node_type = node_data["type"] if "type" in node_data else "ent"
if node_type == "ent":
entity_nodes += 1
elif node_type == "pg":
paragraph_nodes += 1
except Exception:
continue
# 计算平均连接数
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
return KnowledgeStats(
total_nodes=total_nodes,
total_edges=total_edges,
entity_nodes=entity_nodes,
paragraph_nodes=paragraph_nodes,
avg_connections=round(avg_connections, 2),
)
except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True)
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:
query: 搜索关键词
Returns:
List[KnowledgeNode]: 匹配的节点列表
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return []
graph = kg_manager.graph
node_list = graph.get_node_list()
results = []
query_lower = query.lower()
# 在节点内容中搜索
for node_id in node_list:
try:
node_data = graph[node_id]
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data["create_time"] if "create_time" in node_data else None
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception:
continue
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
return results[:50] # 限制返回数量
except Exception as e:
logger.error(f"搜索节点失败: {e}", exc_info=True)
return []

View File

383
src/webui/routers/model.py Normal file
View File

@@ -0,0 +1,383 @@
"""
模型列表获取API路由
提供从各个 AI 厂商 API 获取可用模型列表的代理接口
"""
import os
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
"openai": {
"endpoint": "/models",
"parser": "openai",
},
# Gemini 格式
"gemini": {
"endpoint": "/models",
"parser": "gemini",
},
}
def _normalize_url(url: str) -> str:
"""规范化 URL去掉尾部斜杠"""
if not url:
return ""
return url.rstrip("/")
def _parse_openai_response(data: dict) -> list[dict]:
"""
解析 OpenAI 格式的模型列表响应
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
"""
models = []
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
if isinstance(model, dict) and "id" in model:
models.append(
{
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
}
)
return models
def _parse_gemini_response(data: dict) -> list[dict]:
"""
解析 Gemini 格式的模型列表响应
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
"""
models = []
if "models" in data and isinstance(data["models"], list):
for model in data["models"]:
if isinstance(model, dict) and "name" in model:
# Gemini 的 name 格式是 "models/gemini-pro",我们只取后面部分
model_id = model["name"]
if model_id.startswith("models/"):
model_id = model_id[7:] # 去掉 "models/" 前缀
models.append(
{
"id": model_id,
"name": model.get("displayName") or model_id,
"owned_by": "google",
}
)
return models
async def _fetch_models_from_provider(
base_url: str,
api_key: str,
endpoint: str,
parser: str,
client_type: str = "openai",
) -> list[dict]:
"""
从提供商 API 获取模型列表
Args:
base_url: 提供商的基础 URL
api_key: API 密钥
endpoint: 获取模型列表的端点
parser: 响应解析器类型 ('openai' | 'gemini')
client_type: 客户端类型 ('openai' | 'gemini')
Returns:
模型列表
"""
url = f"{_normalize_url(base_url)}{endpoint}"
# 根据客户端类型设置请求头
headers = {}
params = {}
if client_type == "gemini":
# Gemini 使用 URL 参数传递 API Key
params["key"] = api_key
else:
# OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
except httpx.HTTPStatusError as e:
# 注意:使用 502 Bad Gateway 而不是原始的 401/403
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
if e.response.status_code == 401:
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
elif e.response.status_code == 403:
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
elif e.response.status_code == 404:
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
else:
raise HTTPException(
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
) from e
except Exception as e:
logger.error(f"获取模型列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
# 根据解析器类型解析响应
if parser == "openai":
return _parse_openai_response(data)
elif parser == "gemini":
return _parse_gemini_response(data)
else:
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
def _get_provider_config(provider_name: str) -> Optional[dict]:
"""
从 model_config.toml 获取指定提供商的配置
Args:
provider_name: 提供商名称
Returns:
提供商配置,如果未找到则返回 None
"""
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
providers = config_data.get("api_providers", [])
for provider in providers:
if provider.get("name") == provider_name:
return dict(provider)
return None
except Exception as e:
logger.error(f"读取提供商配置失败: {e}")
return None
@router.get("/list")
async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
通过提供商名称查找配置,然后请求对应的模型列表端点
"""
# 获取提供商配置
provider_config = _get_provider_config(provider_name)
if not provider_config:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider_config.get("base_url")
api_key = provider_config.get("api_key")
client_type = provider_config.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
if not api_key:
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
# 获取模型列表
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"provider": provider_name,
"count": len(models),
}
@router.get("/list-by-url")
async def get_models_by_url(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: str = Query(..., description="API Key"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表(用于自定义提供商)
"""
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"count": len(models),
}
@router.get("/test-connection")
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
分两步测试:
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
2. API Key 验证(可选):如果提供了 api_key尝试获取模型列表验证 Key 是否有效
返回:
- network_ok: 网络是否连通
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
- latency_ms: 响应延迟(毫秒)
- error: 错误信息(如果有)
"""
import time
base_url = _normalize_url(base_url)
if not base_url:
raise HTTPException(status_code=400, detail="base_url 不能为空")
result = {
"network_ok": False,
"api_key_valid": None,
"latency_ms": None,
"error": None,
"http_status": None,
}
# 第一步:测试网络连通性
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
# 尝试 GET 请求 base_url不需要 API Key
response = await client.get(base_url)
latency = (time.time() - start_time) * 1000
result["network_ok"] = True
result["latency_ms"] = round(latency, 2)
result["http_status"] = response.status_code
except httpx.ConnectError as e:
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
return result
except httpx.TimeoutException:
result["error"] = "连接超时:服务器响应时间过长"
return result
except httpx.RequestError as e:
result["error"] = f"请求错误:{str(e)}"
return result
except Exception as e:
result["error"] = f"未知错误:{str(e)}"
return result
# 第二步:如果提供了 API Key验证其有效性
if api_key:
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# 尝试获取模型列表
models_url = f"{base_url}/models"
response = await client.get(models_url, headers=headers)
if response.status_code == 200:
result["api_key_valid"] = True
elif response.status_code in (401, 403):
result["api_key_valid"] = False
result["error"] = "API Key 无效或已过期"
else:
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
result["api_key_valid"] = None
except Exception as e:
# API Key 验证失败不影响网络连通性结果
logger.warning(f"API Key 验证失败: {e}")
result["api_key_valid"] = None
return result
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接(从配置文件读取信息)
"""
# 读取配置文件
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(model_config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(model_config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 查找提供商
providers = config.get("api_providers", [])
provider = None
for p in providers:
if p.get("name") == provider_name:
provider = p
break
if not provider:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)

416
src/webui/routers/person.py Normal file
View File

@@ -0,0 +1,416 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from src.webui.core import verify_auth_token_from_cookie_or_header
import json
import time
logger = get_logger("webui.person")
# 创建路由器
router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
"""人物信息响应"""
id: int
is_known: bool
person_id: str
person_name: Optional[str]
name_reason: Optional[str]
platform: str
user_id: str
nickname: Optional[str]
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
memory_points: Optional[str]
know_times: Optional[float]
know_since: Optional[float]
last_know: Optional[float]
class PersonListResponse(BaseModel):
"""人物列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[PersonInfoResponse]
class PersonDetailResponse(BaseModel):
"""人物详情响应"""
success: bool
data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
"""人物信息更新请求"""
person_name: Optional[str] = None
name_reason: Optional[str] = None
nickname: Optional[str] = None
memory_points: Optional[str] = None
is_known: Optional[bool] = None
class PersonUpdateResponse(BaseModel):
"""人物信息更新响应"""
success: bool
message: str
data: Optional[PersonInfoResponse] = None
class PersonDeleteResponse(BaseModel):
"""人物删除响应"""
success: bool
message: str
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
person_ids: List[str]
class BatchDeleteResponse(BaseModel):
"""批量删除响应"""
success: bool
message: str
deleted_count: int
failed_count: int
failed_ids: List[str] = []
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
"""解析群昵称 JSON 字符串"""
if not group_nick_name_str:
return None
try:
return json.loads(group_nick_name_str)
except (json.JSONDecodeError, TypeError):
return None
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
"""将 PersonInfo 模型转换为响应对象"""
return PersonInfoResponse(
id=person.id,
is_known=person.is_known,
person_id=person.person_id,
person_name=person.person_name,
name_reason=person.name_reason,
platform=person.platform,
user_id=person.user_id,
nickname=person.nickname,
group_nick_name=parse_group_nick_name(person.group_nick_name),
memory_points=person.memory_points,
know_times=person.know_times,
know_since=person.know_since,
last_know=person.last_know,
)
@router.get("/list", response_model=PersonListResponse)
async def get_person_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
Args:
page: 页码 (从 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 person_name, nickname, user_id)
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
query = PersonInfo.select()
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物列表失败: {str(e)}") from e
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(success=True, data=person_to_response(person))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物详情失败: {str(e)}") from e
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息(只更新提供的字段)
Args:
person_id: 人物唯一 ID
request: 更新请求(只包含需要更新的字段)
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data["last_know"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"更新人物信息失败: {str(e)}") from e
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
# 执行删除
person.delete_instance()
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"删除人物信息失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息
Args:
request: 包含person_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.person_ids:
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
deleted_count = 0
failed_count = 0
failed_ids = []
for person_id in request.person_ids:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
person.delete_instance()
deleted_count += 1
logger.info(f"批量删除: {person_id}")
else:
failed_count += 1
failed_ids.append(person_id)
except Exception as e:
logger.error(f"删除 {person_id} 失败: {e}")
failed_count += 1
failed_ids.append(person_id)
message = f"成功删除 {deleted_count} 个人物"
if failed_count > 0:
message += f"{failed_count} 个失败"
return BatchDeleteResponse(
success=True,
message=message,
deleted_count=deleted_count,
failed_count=failed_count,
failed_ids=failed_ids,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e

2059
src/webui/routers/plugin.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,319 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
total_requests: int = Field(0, description="总请求数")
total_cost: float = Field(0.0, description="总花费")
total_tokens: int = Field(0, description="总token数")
online_time: float = Field(0.0, description="在线时间(秒)")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
avg_response_time: float = Field(0.0, description="平均响应时间")
cost_per_hour: float = Field(0.0, description="每小时花费")
tokens_per_hour: float = Field(0.0, description="每小时token数")
class ModelStatistics(BaseModel):
"""模型统计"""
model_name: str
request_count: int
total_cost: float
total_tokens: int
avg_response_time: float
class TimeSeriesData(BaseModel):
"""时间序列数据"""
timestamp: str
requests: int = 0
cost: float = 0.0
tokens: int = 0
class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
daily_data: List[TimeSeriesData]
recent_activity: List[Dict[str, Any]]
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
Args:
hours: 统计时间范围小时默认24小时
Returns:
仪表盘数据
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
# 获取摘要数据
summary = await _get_summary_statistics(start_time, now)
# 获取模型统计
model_stats = await _get_model_statistics(start_time)
# 获取小时级时间序列数据
hourly_data = await _get_hourly_statistics(start_time, now)
# 获取日级时间序列数据最近7天
daily_start = now - timedelta(days=7)
daily_data = await _get_daily_statistics(daily_start, now)
# 获取最近活动
recent_activity = await _get_recent_activity(limit=10)
return DashboardData(
summary=summary,
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity,
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
"""获取摘要统计数据(优化:使用数据库聚合)"""
summary = StatisticsSummary()
# 使用聚合查询替代全量加载
query = LLMUsage.select(
fn.COUNT(LLMUsage.id).alias("total_requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
result = query.dicts().get()
summary.total_requests = result["total_requests"]
summary.total_cost = result["total_cost"]
summary.total_tokens = result["total_tokens"]
summary.avg_response_time = result["avg_response_time"] or 0.0
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
online_records = list(
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
)
for record in online_records:
start = max(record.start_timestamp, start_time)
end = min(record.end_timestamp, end_time)
if end > start:
summary.online_time += (end - start).total_seconds()
# 查询消息数量 - 使用聚合优化
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
)
summary.total_messages = messages_query.scalar() or 0
# 统计回复数量
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp())
& (Messages.time <= end_time.timestamp())
& (Messages.reply_to.is_null(False))
)
summary.total_replies = replies_query.scalar() or 0
# 计算派生指标
if summary.online_time > 0:
online_hours = summary.online_time / 3600.0
summary.cost_per_hour = summary.total_cost / online_hours
summary.tokens_per_hour = summary.total_tokens / online_hours
return summary
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
# 使用GROUP BY聚合避免全量加载
query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
fn.COUNT(LLMUsage.id).alias("request_count"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
)
.where(LLMUsage.timestamp >= start_time)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10) # 只取前10个
)
result = []
for row in query.dicts():
result.append(
ModelStatistics(
model_name=row["model_name"],
request_count=row["request_count"],
total_cost=row["total_cost"],
total_tokens=row["total_tokens"],
avg_response_time=row["avg_response_time"] or 0.0,
)
)
return result
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取小时级统计数据(优化:使用数据库聚合)"""
# SQLite的日期时间函数进行小时分组
# 使用strftime将timestamp格式化为小时级别
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
)
# 转换为字典以快速查找
data_dict = {row["hour"]: row for row in query.dicts()}
# 填充所有小时(包括没有数据的)
result = []
current = start_time.replace(minute=0, second=0, microsecond=0)
while current <= end_time:
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
if hour_str in data_dict:
row = data_dict[hour_str]
result.append(
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
)
else:
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
current += timedelta(hours=1)
return result
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取日级统计数据(优化:使用数据库聚合)"""
# 使用strftime按日期分组
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
)
# 转换为字典
data_dict = {row["day"]: row for row in query.dicts()}
# 填充所有天
result = []
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
while current <= end_time:
day_str = current.strftime("%Y-%m-%dT00:00:00")
if day_str in data_dict:
row = data_dict[day_str]
result.append(
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
)
else:
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
current += timedelta(days=1)
return result
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"""获取最近活动"""
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
activities = []
for record in records:
activities.append(
{
"timestamp": record.timestamp.isoformat(),
"model": record.model_assign_name or record.model_name,
"request_type": record.request_type,
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
"cost": record.cost or 0.0,
"time_cost": record.time_cost or 0.0,
"status": record.status,
}
)
return activities
@router.get("/summary")
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取统计摘要
Args:
hours: 统计时间范围(小时)
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
summary = await _get_summary_statistics(start_time, now)
return summary
except Exception as e:
logger.error(f"获取统计摘要失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/models")
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取模型统计
Args:
hours: 统计时间范围(小时)
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
stats = await _get_model_statistics(start_time)
return stats
except Exception as e:
logger.error(f"获取模型统计失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")

View File

@@ -0,0 +1,9 @@
from .logs import router as logs_router
from .plugin_progress import get_progress_router
from .auth import router as ws_auth_router
__all__ = [
"logs_router",
"get_progress_router",
"ws_auth_router",
]

View File

@@ -0,0 +1,114 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
"""
from fastapi import APIRouter, Cookie, Header
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.core import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
"""清理过期的临时 token"""
now = time.time()
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
for t in expired:
del _ws_temp_tokens[t]
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
_cleanup_expired_ws_tokens()
temp_token = secrets.token_urlsafe(32)
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
return temp_token
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
_cleanup_expired_ws_tokens()
if temp_token not in _ws_temp_tokens:
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
return False
expire_time, session_token = _ws_temp_tokens[temp_token]
if time.time() > expire_time:
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
return False
# 验证原始 session token 仍然有效
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
return False
# 消费 token一次性使用
del _ws_temp_tokens[temp_token]
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
return True
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie 或 Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证。
临时 token 有效期 60 秒,且只能使用一次。
注意:在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面。
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
# 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
# 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}

View File

@@ -0,0 +1,177 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
# 从文件末尾开始读取
for line in reversed(lines):
if len(logs) >= limit:
break
try:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
"level": log_entry.get("level", "INFO").upper(),
"module": log_entry.get("logger_name", ""),
"message": log_entry.get("event", ""),
}
logs.append(formatted_log)
log_counter += 1
except (json.JSONDecodeError, KeyError):
continue
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")

View File

@@ -0,0 +1,164 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
# 创建路由器
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
# 当前加载进度状态
current_progress: Dict[str, Any] = {
"operation": "idle", # idle, fetch, install, uninstall, update
"stage": "idle", # idle, loading, success, error
"progress": 0, # 0-100
"message": "",
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0,
}
async def broadcast_progress(progress_data: Dict[str, Any]):
"""广播进度更新到所有连接的客户端"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
# 移除断开的连接
for websocket in disconnected:
active_connections.discard(websocket)
async def update_progress(
stage: str,
progress: int,
message: str,
operation: str = "fetch",
error: str = None,
plugin_id: str = None,
total_plugins: int = 0,
loaded_plugins: int = 0,
):
"""更新并广播进度
Args:
stage: 阶段 (idle, loading, success, error)
progress: 进度百分比 (0-100)
message: 当前消息
operation: 操作类型 (fetch, install, uninstall, update)
error: 错误信息(可选)
plugin_id: 当前操作的插件 ID
total_plugins: 总插件数
loaded_plugins: 已加载插件数
"""
progress_data = {
"operation": operation,
"stage": stage,
"progress": progress,
"message": message,
"error": error,
"plugin_id": plugin_id,
"total_plugins": total_plugins,
"loaded_plugins": loaded_plugins,
"timestamp": asyncio.get_event_loop().time(),
}
await broadcast_progress(progress_data)
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
支持三种认证方式(按优先级):
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
# 保持连接并处理客户端消息
while True:
try:
data = await websocket.receive_text()
# 处理客户端心跳
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取插件进度 WebSocket 路由器"""
return router