remove webui
This commit is contained in:
@@ -1 +0,0 @@
|
||||
"""WebUI 模块"""
|
||||
@@ -1,938 +0,0 @@
|
||||
"""麦麦 2025 年度总结 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import (
|
||||
LLMUsage,
|
||||
OnlineTime,
|
||||
Messages,
|
||||
ChatStreams,
|
||||
PersonInfo,
|
||||
Emoji,
|
||||
Expression,
|
||||
ActionRecords,
|
||||
Jargon,
|
||||
)
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.annual_report")
|
||||
|
||||
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ==================== Pydantic 模型定义 ====================
|
||||
|
||||
|
||||
class TimeFootprintData(BaseModel):
|
||||
"""时光足迹数据"""
|
||||
|
||||
total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
|
||||
first_message_time: Optional[str] = Field(None, description="初次消息时间")
|
||||
first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
|
||||
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
|
||||
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
|
||||
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
|
||||
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
|
||||
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
|
||||
is_night_owl: bool = Field(False, description="是否是夜猫子")
|
||||
|
||||
|
||||
class SocialNetworkData(BaseModel):
|
||||
"""社交网络数据"""
|
||||
|
||||
total_groups: int = Field(0, description="加入的群组总数")
|
||||
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
|
||||
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
|
||||
at_count: int = Field(0, description="被@次数")
|
||||
mentioned_count: int = Field(0, description="被提及次数")
|
||||
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
|
||||
longest_companion_days: int = Field(0, description="陪伴天数")
|
||||
|
||||
|
||||
class BrainPowerData(BaseModel):
|
||||
"""最强大脑数据"""
|
||||
|
||||
total_tokens: int = Field(0, description="年度消耗Token总量")
|
||||
total_cost: float = Field(0.0, description="年度总花费")
|
||||
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
|
||||
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
|
||||
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
|
||||
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
|
||||
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
|
||||
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
|
||||
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
|
||||
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
|
||||
total_actions: int = Field(0, description="总动作数")
|
||||
no_reply_count: int = Field(0, description="选择沉默的次数")
|
||||
avg_interest_value: float = Field(0.0, description="平均兴趣值")
|
||||
max_interest_value: float = Field(0.0, description="最高兴趣值")
|
||||
max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
|
||||
avg_reasoning_length: float = Field(0.0, description="平均思考长度")
|
||||
max_reasoning_length: int = Field(0, description="最长思考长度")
|
||||
max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
|
||||
|
||||
|
||||
class ExpressionVibeData(BaseModel):
|
||||
"""个性与表达数据"""
|
||||
|
||||
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
|
||||
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
|
||||
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
|
||||
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
|
||||
checked_expression_count: int = Field(0, description="已检查的表达次数")
|
||||
total_expressions: int = Field(0, description="表达总数")
|
||||
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
|
||||
image_processed_count: int = Field(0, description="处理的图片数量")
|
||||
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
|
||||
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
|
||||
|
||||
|
||||
class AchievementData(BaseModel):
|
||||
"""趣味成就数据"""
|
||||
|
||||
new_jargon_count: int = Field(0, description="新学到的黑话数量")
|
||||
sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
|
||||
total_messages: int = Field(0, description="总消息数")
|
||||
total_replies: int = Field(0, description="总回复数")
|
||||
|
||||
|
||||
class AnnualReportData(BaseModel):
|
||||
"""年度报告完整数据"""
|
||||
|
||||
year: int = Field(2025, description="报告年份")
|
||||
bot_name: str = Field("麦麦", description="Bot名称")
|
||||
generated_at: str = Field(..., description="报告生成时间")
|
||||
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
|
||||
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
|
||||
brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
|
||||
expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
|
||||
achievements: AchievementData = Field(default_factory=AchievementData)
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
|
||||
"""获取指定年份的时间戳范围"""
|
||||
start = datetime(year, 1, 1, 0, 0, 0).timestamp()
|
||||
end = datetime(year, 12, 31, 23, 59, 59).timestamp()
|
||||
return start, end
|
||||
|
||||
|
||||
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
|
||||
"""获取指定年份的 datetime 范围"""
|
||||
start = datetime(year, 1, 1, 0, 0, 0)
|
||||
end = datetime(year, 12, 31, 23, 59, 59)
|
||||
return start, end
|
||||
|
||||
|
||||
# ==================== 维度一:时光足迹 ====================
|
||||
|
||||
|
||||
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
||||
"""获取时光足迹数据"""
|
||||
data = TimeFootprintData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
start_dt, end_dt = get_year_datetime_range(year)
|
||||
|
||||
try:
|
||||
# 1. 年度在线时长
|
||||
online_records = list(
|
||||
OnlineTime.select().where(
|
||||
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
|
||||
)
|
||||
)
|
||||
total_seconds = 0
|
||||
for record in online_records:
|
||||
try:
|
||||
start = max(record.start_timestamp, start_dt)
|
||||
end = min(record.end_timestamp, end_dt)
|
||||
if end > start:
|
||||
total_seconds += (end - start).total_seconds()
|
||||
except Exception:
|
||||
continue
|
||||
data.total_online_hours = round(total_seconds / 3600, 2)
|
||||
|
||||
# 2. 初次相遇 - 年度第一条消息
|
||||
first_msg = (
|
||||
Messages.select()
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.order_by(Messages.time.asc())
|
||||
.first()
|
||||
)
|
||||
if first_msg:
|
||||
data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
|
||||
content = first_msg.processed_plain_text or first_msg.display_message or ""
|
||||
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
|
||||
|
||||
# 3. 最忙碌的一天
|
||||
# 使用 SQLite 的 date 函数按日期分组
|
||||
busiest_query = (
|
||||
Messages.select(
|
||||
fn.date(Messages.time, "unixepoch").alias("day"),
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.group_by(fn.date(Messages.time, "unixepoch"))
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
.limit(1)
|
||||
)
|
||||
busiest_result = list(busiest_query.dicts())
|
||||
if busiest_result:
|
||||
data.busiest_day = busiest_result[0].get("day")
|
||||
data.busiest_day_count = busiest_result[0].get("count", 0)
|
||||
|
||||
# 4. 昼夜节律 - 24小时活跃分布
|
||||
hourly_query = (
|
||||
Messages.select(
|
||||
fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.group_by(fn.strftime("%H", Messages.time, "unixepoch"))
|
||||
)
|
||||
hourly_distribution = [0] * 24
|
||||
for row in hourly_query.dicts():
|
||||
try:
|
||||
hour = int(row.get("hour", 0))
|
||||
if 0 <= hour < 24:
|
||||
hourly_distribution[hour] = row.get("count", 0)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
data.hourly_distribution = hourly_distribution
|
||||
|
||||
# 5. 深夜食堂 (0-4点)
|
||||
data.midnight_chat_count = sum(hourly_distribution[0:5])
|
||||
|
||||
# 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点)
|
||||
night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
|
||||
morning_activity = sum(hourly_distribution[6:13])
|
||||
data.is_night_owl = night_activity > morning_activity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取时光足迹数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度二:社交网络 ====================
|
||||
|
||||
|
||||
async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||
"""获取社交网络数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
data = SocialNetworkData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于过滤
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
try:
|
||||
# 1. 加入的群组总数
|
||||
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
|
||||
|
||||
# 2. 话痨群组 TOP3
|
||||
top_groups_query = (
|
||||
Messages.select(
|
||||
Messages.chat_info_group_id,
|
||||
Messages.chat_info_group_name,
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.chat_info_group_id.is_null(False))
|
||||
)
|
||||
.group_by(Messages.chat_info_group_id)
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_groups = [
|
||||
{
|
||||
"group_id": row["chat_info_group_id"],
|
||||
"group_name": row["chat_info_group_name"] or "未知群组",
|
||||
"message_count": row["count"],
|
||||
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
|
||||
}
|
||||
for row in top_groups_query.dicts()
|
||||
]
|
||||
|
||||
# 3. 互动最多的用户 TOP5(过滤 bot 自身)
|
||||
top_users_query = (
|
||||
Messages.select(
|
||||
Messages.user_id,
|
||||
Messages.user_nickname,
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.user_id.is_null(False))
|
||||
& (Messages.user_id != bot_qq) # 过滤 bot 自身
|
||||
)
|
||||
.group_by(Messages.user_id)
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_users = [
|
||||
{
|
||||
"user_id": row["user_id"],
|
||||
"user_nickname": row["user_nickname"] or "未知用户",
|
||||
"message_count": row["count"],
|
||||
"is_webui": str(row["user_id"]).startswith("webui_"),
|
||||
}
|
||||
for row in top_users_query.dicts()
|
||||
]
|
||||
|
||||
# 4. 被@次数
|
||||
data.at_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_at == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 5. 被提及次数
|
||||
data.mentioned_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_mentioned == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 6. 最长情陪伴的用户(过滤 bot 自身)
|
||||
companion_query = (
|
||||
ChatStreams.select(
|
||||
ChatStreams.user_id,
|
||||
ChatStreams.user_nickname,
|
||||
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
|
||||
)
|
||||
.where(
|
||||
(ChatStreams.user_id.is_null(False))
|
||||
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
|
||||
)
|
||||
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
|
||||
.limit(1)
|
||||
)
|
||||
companion_result = list(companion_query.dicts())
|
||||
if companion_result:
|
||||
data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
|
||||
duration = companion_result[0].get("duration", 0) or 0
|
||||
data.longest_companion_days = int(duration / 86400) # 转换为天
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取社交网络数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度三:最强大脑 ====================
|
||||
|
||||
|
||||
async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
||||
"""获取最强大脑数据"""
|
||||
data = BrainPowerData()
|
||||
start_dt, end_dt = get_year_datetime_range(year)
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
try:
|
||||
# 1. 年度消耗 Token 总量和总花费
|
||||
token_query = LLMUsage.select(
|
||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
||||
result = token_query.dicts().get()
|
||||
data.total_tokens = int(result.get("total_tokens", 0) or 0)
|
||||
data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
|
||||
|
||||
# 2. 最爱用的模型
|
||||
model_query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
|
||||
fn.COUNT(LLMUsage.id).alias("count"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
model_results = list(model_query.dicts())
|
||||
if model_results:
|
||||
data.favorite_model = model_results[0].get("model")
|
||||
data.favorite_model_count = model_results[0].get("count", 0)
|
||||
data.model_distribution = [
|
||||
{
|
||||
"model": row["model"],
|
||||
"count": row["count"],
|
||||
"tokens": row["tokens"],
|
||||
"cost": round(row["cost"], 4),
|
||||
}
|
||||
for row in model_results
|
||||
]
|
||||
|
||||
# 3. 最昂贵的一次思考
|
||||
expensive_query = (
|
||||
LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
|
||||
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
||||
.order_by(LLMUsage.cost.desc())
|
||||
.limit(1)
|
||||
)
|
||||
expensive_result = expensive_query.first()
|
||||
if expensive_result:
|
||||
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
|
||||
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
|
||||
consumer_query = (
|
||||
LLMUsage.select(
|
||||
LLMUsage.user_id,
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where(
|
||||
(LLMUsage.timestamp >= start_dt)
|
||||
& (LLMUsage.timestamp <= end_dt)
|
||||
& (LLMUsage.user_id != "system") # 过滤 system 用户
|
||||
& (LLMUsage.user_id.is_null(False))
|
||||
)
|
||||
.group_by(LLMUsage.user_id)
|
||||
.order_by(fn.SUM(LLMUsage.cost).desc())
|
||||
.limit(3)
|
||||
)
|
||||
data.top_token_consumers = [
|
||||
{
|
||||
"user_id": row["user_id"],
|
||||
"cost": round(row["cost"], 4),
|
||||
"tokens": row["tokens"],
|
||||
}
|
||||
for row in consumer_query.dicts()
|
||||
]
|
||||
|
||||
# 5. 最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用)
|
||||
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
|
||||
reply_model_query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
|
||||
fn.COUNT(LLMUsage.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(LLMUsage.timestamp >= start_dt)
|
||||
& (LLMUsage.timestamp <= end_dt)
|
||||
& (
|
||||
LLMUsage.model_assign_name.contains("replyer")
|
||||
| LLMUsage.model_assign_name.contains("回复")
|
||||
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
|
||||
)
|
||||
)
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_reply_models = [
|
||||
{"model": row["model"], "count": row["count"]}
|
||||
for row in reply_model_query.dicts()
|
||||
]
|
||||
|
||||
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
|
||||
total_actions = ActionRecords.select().where(
|
||||
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
|
||||
).count()
|
||||
no_reply_count = ActionRecords.select().where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "no_reply")
|
||||
).count()
|
||||
data.total_actions = total_actions
|
||||
data.no_reply_count = no_reply_count
|
||||
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
|
||||
|
||||
# 6. 情绪波动 (兴趣值)
|
||||
interest_query = Messages.select(
|
||||
fn.AVG(Messages.interest_value).alias("avg_interest"),
|
||||
fn.MAX(Messages.interest_value).alias("max_interest"),
|
||||
).where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.interest_value.is_null(False))
|
||||
)
|
||||
interest_result = interest_query.dicts().get()
|
||||
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
|
||||
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
|
||||
|
||||
# 找到最高兴趣值的时间
|
||||
if data.max_interest_value > 0:
|
||||
max_interest_msg = (
|
||||
Messages.select(Messages.time)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.interest_value == data.max_interest_value)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if max_interest_msg:
|
||||
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
# 7. 思考深度 (基于 action_reasoning 长度)
|
||||
reasoning_records = (
|
||||
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_reasoning.is_null(False))
|
||||
& (ActionRecords.action_reasoning != "")
|
||||
)
|
||||
)
|
||||
reasoning_lengths = []
|
||||
max_len = 0
|
||||
max_len_time = None
|
||||
for record in reasoning_records:
|
||||
if record.action_reasoning:
|
||||
length = len(record.action_reasoning)
|
||||
reasoning_lengths.append(length)
|
||||
if length > max_len:
|
||||
max_len = length
|
||||
max_len_time = record.time
|
||||
|
||||
if reasoning_lengths:
|
||||
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
|
||||
data.max_reasoning_length = max_len
|
||||
if max_len_time:
|
||||
data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最强大脑数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度四:个性与表达 ====================
|
||||
|
||||
|
||||
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||
"""获取个性与表达数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
data = ExpressionVibeData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
try:
|
||||
# 1. 表情包之王 - 使用次数最多的表情包
|
||||
top_emoji_query = (
|
||||
Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
|
||||
.where(Emoji.is_registered == True)
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_emojis = list(top_emoji_query.dicts())
|
||||
if top_emojis:
|
||||
data.top_emoji = {
|
||||
"id": top_emojis[0].get("id"),
|
||||
"path": top_emojis[0].get("full_path"),
|
||||
"description": top_emojis[0].get("description"),
|
||||
"usage_count": top_emojis[0].get("usage_count", 0),
|
||||
"hash": top_emojis[0].get("emoji_hash"),
|
||||
}
|
||||
data.top_emojis = [
|
||||
{
|
||||
"id": e.get("id"),
|
||||
"path": e.get("full_path"),
|
||||
"description": e.get("description"),
|
||||
"usage_count": e.get("usage_count", 0),
|
||||
"hash": e.get("emoji_hash"),
|
||||
}
|
||||
for e in top_emojis
|
||||
]
|
||||
|
||||
# 2. 百变麦麦 - 最常用的表达风格
|
||||
expression_query = (
|
||||
Expression.select(
|
||||
Expression.style,
|
||||
fn.SUM(Expression.count).alias("total_count"),
|
||||
)
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
)
|
||||
.group_by(Expression.style)
|
||||
.order_by(fn.SUM(Expression.count).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_expressions = [
|
||||
{"style": row["style"], "count": row["total_count"]}
|
||||
for row in expression_query.dicts()
|
||||
]
|
||||
|
||||
# 3. 被拒绝的表达
|
||||
data.rejected_expression_count = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
& (Expression.rejected == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 4. 已检查的表达
|
||||
data.checked_expression_count = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
& (Expression.checked == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 5. 表达总数
|
||||
data.total_expressions = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 6. 动作类型分布 (过滤无意义的动作)
|
||||
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
|
||||
excluded_actions = [
|
||||
"reply", "no_reply", "no_reply_until_call", "make_question",
|
||||
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
|
||||
]
|
||||
action_query = (
|
||||
ActionRecords.select(
|
||||
ActionRecords.action_name,
|
||||
fn.COUNT(ActionRecords.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name.not_in(excluded_actions))
|
||||
)
|
||||
.group_by(ActionRecords.action_name)
|
||||
.order_by(fn.COUNT(ActionRecords.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
data.action_types = [
|
||||
{"action": row["action_name"], "count": row["count"]}
|
||||
for row in action_query.dicts()
|
||||
]
|
||||
|
||||
# 7. 处理的图片数量
|
||||
data.image_processed_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_picid == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
|
||||
import random
|
||||
import re
|
||||
|
||||
def clean_message_content(content: str) -> str:
|
||||
"""清理消息内容,移除回复引用等标记"""
|
||||
if not content:
|
||||
return ""
|
||||
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
|
||||
content = re.sub(r'\[回复<[^>]+>\s*的消息[::][^\]]*\]', '', content)
|
||||
# 移除 [图片] [表情] 等标记
|
||||
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
|
||||
# 移除多余的空白
|
||||
content = re.sub(r'\s+', ' ', content).strip()
|
||||
return content
|
||||
|
||||
# 使用 user_id 判断是否是 bot 发送的消息
|
||||
late_night_messages = list(
|
||||
Messages.select(
|
||||
Messages.time,
|
||||
Messages.processed_plain_text,
|
||||
Messages.display_message,
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.user_id == bot_qq) # bot 发送的消息
|
||||
)
|
||||
.order_by(Messages.time.desc())
|
||||
)
|
||||
# 筛选出0-6点的消息
|
||||
late_night_filtered = []
|
||||
for msg in late_night_messages:
|
||||
msg_dt = datetime.fromtimestamp(msg.time)
|
||||
hour = msg_dt.hour
|
||||
if 0 <= hour < 6: # 0点到6点
|
||||
raw_content = msg.processed_plain_text or msg.display_message or ""
|
||||
cleaned_content = clean_message_content(raw_content)
|
||||
# 只保留有意义的内容
|
||||
if cleaned_content and len(cleaned_content) > 2:
|
||||
late_night_filtered.append({
|
||||
"time": msg.time,
|
||||
"hour": hour,
|
||||
"minute": msg_dt.minute,
|
||||
"content": cleaned_content,
|
||||
"datetime_str": msg_dt.strftime("%H:%M"),
|
||||
})
|
||||
if len(late_night_filtered) >= 10:
|
||||
break
|
||||
|
||||
if late_night_filtered:
|
||||
selected = random.choice(late_night_filtered)
|
||||
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
|
||||
data.late_night_reply = {
|
||||
"time": selected["datetime_str"],
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
|
||||
from collections import Counter
|
||||
import json as json_lib
|
||||
|
||||
reply_records = (
|
||||
ActionRecords.select(ActionRecords.action_data)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "reply")
|
||||
& (ActionRecords.action_data.is_null(False))
|
||||
& (ActionRecords.action_data != "")
|
||||
)
|
||||
)
|
||||
|
||||
reply_contents = []
|
||||
for record in reply_records:
|
||||
try:
|
||||
action_data = record.action_data
|
||||
if action_data:
|
||||
content = None
|
||||
# 尝试解析 JSON 格式
|
||||
try:
|
||||
parsed = json_lib.loads(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
# 优先使用 reply_text,其次使用 content
|
||||
content = parsed.get("reply_text") or parsed.get("content")
|
||||
elif isinstance(parsed, str):
|
||||
content = parsed
|
||||
except (json_lib.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
|
||||
# 例如: "{'reply_text': '墨白灵不知道哦'}"
|
||||
if content is None:
|
||||
import ast
|
||||
try:
|
||||
parsed = ast.literal_eval(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
content = parsed.get("reply_text") or parsed.get("content")
|
||||
elif isinstance(parsed, str):
|
||||
content = parsed
|
||||
except (ValueError, SyntaxError):
|
||||
# 无法解析,使用原始字符串
|
||||
content = action_data
|
||||
|
||||
# 只统计有意义的回复(长度大于2)
|
||||
if content and len(content) > 2:
|
||||
reply_contents.append(content)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if reply_contents:
|
||||
content_counter = Counter(reply_contents)
|
||||
most_common = content_counter.most_common(1)
|
||||
if most_common:
|
||||
fav_content, fav_count = most_common[0]
|
||||
# 截断过长的内容
|
||||
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
|
||||
data.favorite_reply = {
|
||||
"content": display_content,
|
||||
"count": fav_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取个性与表达数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度五:趣味成就 ====================
|
||||
|
||||
|
||||
async def get_achievements(year: int = 2025) -> AchievementData:
|
||||
"""获取趣味成就数据"""
|
||||
data = AchievementData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
try:
|
||||
# 1. 新学到的黑话数量
|
||||
# Jargon 表没有时间字段,统计全部已确认的黑话
|
||||
data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
|
||||
|
||||
# 2. 代表性黑话示例
|
||||
jargon_samples = (
|
||||
Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
|
||||
.where(Jargon.is_jargon == True)
|
||||
.order_by(Jargon.count.desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.sample_jargons = [
|
||||
{
|
||||
"content": j.content,
|
||||
"meaning": j.meaning,
|
||||
"count": j.count,
|
||||
}
|
||||
for j in jargon_samples
|
||||
]
|
||||
|
||||
# 3. 总消息数
|
||||
data.total_messages = (
|
||||
Messages.select()
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 4. 总回复数 (有 reply_to 的消息)
|
||||
data.total_replies = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.reply_to.is_null(False))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取趣味成就数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== API 路由 ====================
|
||||
|
||||
|
||||
@router.get("/full", response_model=AnnualReportData)
|
||||
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取完整年度报告数据
|
||||
|
||||
Args:
|
||||
year: 报告年份,默认2025
|
||||
|
||||
Returns:
|
||||
完整的年度报告数据
|
||||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
logger.info(f"开始生成 {year} 年度报告...")
|
||||
|
||||
# 获取 bot 名称
|
||||
bot_name = global_config.bot.nickname or "麦麦"
|
||||
|
||||
# 并行获取各维度数据
|
||||
time_footprint = await get_time_footprint(year)
|
||||
social_network = await get_social_network(year)
|
||||
brain_power = await get_brain_power(year)
|
||||
expression_vibe = await get_expression_vibe(year)
|
||||
achievements = await get_achievements(year)
|
||||
|
||||
report = AnnualReportData(
|
||||
year=year,
|
||||
bot_name=bot_name,
|
||||
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
time_footprint=time_footprint,
|
||||
social_network=social_network,
|
||||
brain_power=brain_power,
|
||||
expression_vibe=expression_vibe,
|
||||
achievements=achievements,
|
||||
)
|
||||
|
||||
logger.info(f"{year} 年度报告生成完成")
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成年度报告失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/time-footprint", response_model=TimeFootprintData)
|
||||
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取时光足迹数据"""
|
||||
try:
|
||||
return await get_time_footprint(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取时光足迹数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/social-network", response_model=SocialNetworkData)
|
||||
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取社交网络数据"""
|
||||
try:
|
||||
return await get_social_network(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取社交网络数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/brain-power", response_model=BrainPowerData)
|
||||
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取最强大脑数据"""
|
||||
try:
|
||||
return await get_brain_power(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取最强大脑数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/expression-vibe", response_model=ExpressionVibeData)
|
||||
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取个性与表达数据"""
|
||||
try:
|
||||
return await get_expression_vibe(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取个性与表达数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/achievements", response_model=AchievementData)
|
||||
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取趣味成就数据"""
|
||||
try:
|
||||
return await get_achievements(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取趣味成就数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -1,795 +0,0 @@
|
||||
"""
|
||||
WebUI 防爬虫模块
|
||||
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||
"""
|
||||
|
||||
import time
|
||||
import ipaddress
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.anti_crawler")
|
||||
|
||||
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
|
||||
CRAWLER_USER_AGENTS = {
|
||||
# 搜索引擎爬虫(精确匹配)
|
||||
"googlebot",
|
||||
"bingbot",
|
||||
"baiduspider",
|
||||
"yandexbot",
|
||||
"slurp", # Yahoo
|
||||
"duckduckbot",
|
||||
"sogou",
|
||||
"exabot",
|
||||
"facebot",
|
||||
"ia_archiver", # Internet Archive
|
||||
# 通用爬虫(移除过于宽泛的关键词)
|
||||
"crawler",
|
||||
"spider",
|
||||
"scraper",
|
||||
"wget", # 保留wget,因为通常用于自动化脚本
|
||||
"scrapy", # 保留scrapy,因为这是爬虫框架
|
||||
# 安全扫描工具(这些是明确的扫描工具)
|
||||
"masscan",
|
||||
"nmap",
|
||||
"nikto",
|
||||
"sqlmap",
|
||||
# 注意:移除了以下过于宽泛的关键词以避免误报:
|
||||
# - "bot" (会误匹配GitHub-Robot等)
|
||||
# - "curl" (正常工具)
|
||||
# - "python-requests" (正常库)
|
||||
# - "httpx" (正常库)
|
||||
# - "aiohttp" (正常库)
|
||||
}
|
||||
|
||||
# 资产测绘工具 User-Agent 标识
|
||||
ASSET_SCANNER_USER_AGENTS = {
|
||||
# 知名资产测绘平台
|
||||
"shodan",
|
||||
"censys",
|
||||
"zoomeye",
|
||||
"fofa",
|
||||
"quake",
|
||||
"hunter",
|
||||
"binaryedge",
|
||||
"onyphe",
|
||||
"securitytrails",
|
||||
"virustotal",
|
||||
"passivetotal",
|
||||
# 安全扫描工具
|
||||
"acunetix",
|
||||
"appscan",
|
||||
"burpsuite",
|
||||
"nessus",
|
||||
"openvas",
|
||||
"qualys",
|
||||
"rapid7",
|
||||
"tenable",
|
||||
"veracode",
|
||||
"zap",
|
||||
"awvs", # Acunetix Web Vulnerability Scanner
|
||||
"netsparker",
|
||||
"skipfish",
|
||||
"w3af",
|
||||
"arachni",
|
||||
# 其他扫描工具
|
||||
"masscan",
|
||||
"zmap",
|
||||
"nmap",
|
||||
"whatweb",
|
||||
"wpscan",
|
||||
"joomscan",
|
||||
"dnsenum",
|
||||
"subfinder",
|
||||
"amass",
|
||||
"sublist3r",
|
||||
"theharvester",
|
||||
}
|
||||
|
||||
# 资产测绘工具常用的HTTP头标识
|
||||
ASSET_SCANNER_HEADERS = {
|
||||
# 常见的扫描工具自定义头
|
||||
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
|
||||
"x-scanner": {"nmap", "masscan", "zmap"},
|
||||
"x-probe": {"masscan", "zmap"},
|
||||
# 其他可疑头(移除反向代理标准头)
|
||||
"x-originating-ip": set(),
|
||||
"x-remote-ip": set(),
|
||||
"x-remote-addr": set(),
|
||||
# 注意:移除了以下反向代理标准头以避免误报:
|
||||
# - "x-forwarded-proto" (反向代理标准头)
|
||||
# - "x-real-ip" (反向代理标准头,已在_get_client_ip中使用)
|
||||
}
|
||||
|
||||
# 仅检查特定HTTP头中的可疑模式(收紧匹配范围)
|
||||
# 只检查这些特定头,不检查所有头
|
||||
SCANNER_SPECIFIC_HEADERS = {
|
||||
"x-scan",
|
||||
"x-scanner",
|
||||
"x-probe",
|
||||
"x-originating-ip",
|
||||
"x-remote-ip",
|
||||
"x-remote-addr",
|
||||
}
|
||||
|
||||
# 防爬虫模式配置
|
||||
# false: 禁用
|
||||
# strict: 严格模式(更严格的检测,更低的频率限制)
|
||||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||
# - IPv6:::1, 2001:db8::/32
|
||||
def _parse_allowed_ips(ip_string: str) -> list:
|
||||
"""
|
||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||
|
||||
Args:
|
||||
ip_string: 逗号分隔的IP字符串
|
||||
|
||||
Returns:
|
||||
IP白名单列表,每个元素可能是:
|
||||
- ipaddress.IPv4Network/IPv6Network对象(CIDR格式)
|
||||
- ipaddress.IPv4Address/IPv6Address对象(精确IP)
|
||||
- str(通配符模式,已转换为正则表达式)
|
||||
"""
|
||||
allowed = []
|
||||
if not ip_string:
|
||||
return allowed
|
||||
|
||||
for ip_entry in ip_string.split(","):
|
||||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
# 跳过注释行(以#开头)
|
||||
if ip_entry.startswith("#"):
|
||||
continue
|
||||
|
||||
# 检查通配符格式(包含*)
|
||||
if "*" in ip_entry:
|
||||
# 处理通配符
|
||||
pattern = _convert_wildcard_to_regex(ip_entry)
|
||||
if pattern:
|
||||
allowed.append(pattern)
|
||||
else:
|
||||
logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 尝试解析为CIDR格式(包含/)
|
||||
if "/" in ip_entry:
|
||||
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
|
||||
else:
|
||||
# 精确IP地址
|
||||
allowed.append(ipaddress.ip_address(ip_entry))
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})")
|
||||
|
||||
return allowed
|
||||
|
||||
|
||||
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||
"""
|
||||
将通配符IP模式转换为正则表达式
|
||||
|
||||
支持的格式:
|
||||
- 192.168.*.* 或 192.168.*
|
||||
- 10.*.*.* 或 10.*
|
||||
- *.*.*.* 或 *
|
||||
|
||||
Args:
|
||||
wildcard_pattern: 通配符模式字符串
|
||||
|
||||
Returns:
|
||||
正则表达式字符串,如果格式无效则返回None
|
||||
"""
|
||||
# 去除空格
|
||||
pattern = wildcard_pattern.strip()
|
||||
|
||||
# 处理单个*(匹配所有)
|
||||
if pattern == "*":
|
||||
return r".*"
|
||||
|
||||
# 处理IPv4通配符格式
|
||||
# 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
|
||||
parts = pattern.split(".")
|
||||
|
||||
if len(parts) > 4:
|
||||
return None # IPv4最多4段
|
||||
|
||||
# 构建正则表达式
|
||||
regex_parts = []
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part == "*":
|
||||
regex_parts.append(r"\d+") # 匹配任意数字
|
||||
elif part.isdigit():
|
||||
# 验证数字范围(0-255)
|
||||
num = int(part)
|
||||
if 0 <= num <= 255:
|
||||
regex_parts.append(re.escape(part))
|
||||
else:
|
||||
return None # 无效的数字
|
||||
else:
|
||||
return None # 无效的格式
|
||||
|
||||
# 如果部分少于4段,补充.*
|
||||
while len(regex_parts) < 4:
|
||||
regex_parts.append(r"\d+")
|
||||
|
||||
# 组合成正则表达式
|
||||
regex = r"^" + r"\.".join(regex_parts) + r"$"
|
||||
return regex
|
||||
|
||||
|
||||
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
|
||||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
}
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
"""
|
||||
根据模式获取配置参数
|
||||
|
||||
Args:
|
||||
mode: 防爬虫模式 (false/strict/loose/basic)
|
||||
|
||||
Returns:
|
||||
配置字典,包含所有相关参数
|
||||
"""
|
||||
mode = mode.lower()
|
||||
|
||||
if mode == "false":
|
||||
return {
|
||||
"enabled": False,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
|
||||
"max_tracked_ips": 0,
|
||||
"check_user_agent": False,
|
||||
"check_asset_scanner": False,
|
||||
"check_rate_limit": False,
|
||||
"block_on_detect": False, # 不阻止
|
||||
}
|
||||
elif mode == "strict":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
|
||||
"max_tracked_ips": 20000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
elif mode == "loose":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
|
||||
"max_tracked_ips": 5000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
else: # basic (默认模式)
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 不限制请求数
|
||||
"max_tracked_ips": 0, # 不跟踪IP
|
||||
"check_user_agent": True, # 检测但不阻止
|
||||
"check_asset_scanner": True, # 检测但不阻止
|
||||
"check_rate_limit": False, # 不限制请求频率
|
||||
"block_on_detect": False, # 只记录,不阻止
|
||||
}
|
||||
|
||||
|
||||
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||
"""防爬虫中间件"""
|
||||
|
||||
def __init__(self, app, mode: str = "standard"):
|
||||
"""
|
||||
初始化防爬虫中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI 应用实例
|
||||
mode: 防爬虫模式 (false/strict/loose/standard)
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.mode = mode.lower()
|
||||
# 根据模式获取配置
|
||||
config = _get_mode_config(self.mode)
|
||||
self.enabled = config["enabled"]
|
||||
self.rate_limit_window = config["rate_limit_window"]
|
||||
self.rate_limit_max_requests = config["rate_limit_max_requests"]
|
||||
self.max_tracked_ips = config["max_tracked_ips"]
|
||||
self.check_user_agent = config["check_user_agent"]
|
||||
self.check_asset_scanner = config["check_asset_scanner"]
|
||||
self.check_rate_limit = config["check_rate_limit"]
|
||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||
|
||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||
self.request_times: dict[str, deque] = {}
|
||||
# 上次清理时间
|
||||
self.last_cleanup = time.time()
|
||||
# 将关键词列表转换为集合以提高查找性能
|
||||
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
|
||||
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
|
||||
|
||||
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
|
||||
"""
|
||||
检测是否为爬虫 User-Agent
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent 字符串
|
||||
|
||||
Returns:
|
||||
如果是爬虫则返回 True
|
||||
"""
|
||||
if not user_agent:
|
||||
# 没有 User-Agent 的请求记录日志但不直接阻止
|
||||
# 改为只记录,让频率限制来处理
|
||||
logger.debug("请求缺少User-Agent")
|
||||
return False # 不再直接阻止无User-Agent的请求
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# 使用集合查找提高性能(检查是否包含爬虫关键词)
|
||||
for crawler_keyword in self.crawler_keywords_set:
|
||||
if crawler_keyword in user_agent_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_asset_scanner_header(self, request: Request) -> bool:
|
||||
"""
|
||||
检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
如果检测到资产测绘工具头则返回 True
|
||||
"""
|
||||
# 只检查特定的扫描工具头,不检查所有头
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
header_value_lower = header_value.lower() if header_value else ""
|
||||
|
||||
# 检查已知的扫描工具头
|
||||
if header_name_lower in ASSET_SCANNER_HEADERS:
|
||||
# 如果该头有特定的工具集合,检查值是否匹配
|
||||
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
|
||||
if expected_tools:
|
||||
for tool in expected_tools:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
else:
|
||||
# 如果没有特定工具集合,只要存在该头就视为可疑
|
||||
if header_value_lower:
|
||||
return True
|
||||
|
||||
# 只检查特定头中的可疑模式(收紧匹配)
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
# 检查头值中是否包含已知扫描工具名称
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检测资产测绘工具
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
(是否检测到, 检测到的工具名称)
|
||||
"""
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检查 User-Agent(使用集合查找提高性能)
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for scanner_keyword in self.scanner_keywords_set:
|
||||
if scanner_keyword in user_agent_lower:
|
||||
return True, scanner_keyword
|
||||
|
||||
# 检查HTTP头
|
||||
if self._is_asset_scanner_header(request):
|
||||
# 尝试从User-Agent或头中提取工具名称
|
||||
detected_tool = None
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in user_agent_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
|
||||
# 检查HTTP头中的工具标识(只检查特定头)
|
||||
if not detected_tool:
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
header_value_lower = (header_value or "").lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
if detected_tool:
|
||||
break
|
||||
|
||||
return True, detected_tool or "unknown_scanner"
|
||||
|
||||
return False, None
|
||||
|
||||
def _check_rate_limit(self, client_ip: str) -> bool:
|
||||
"""
|
||||
检查请求频率限制
|
||||
|
||||
Args:
|
||||
client_ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果超过限制则返回 True(需要阻止)
|
||||
"""
|
||||
# 检查IP白名单
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 定期清理过期的请求记录(每5分钟清理一次)
|
||||
if current_time - self.last_cleanup > 300:
|
||||
self._cleanup_old_requests(current_time)
|
||||
self.last_cleanup = current_time
|
||||
|
||||
# 限制跟踪的IP数量,防止内存泄漏
|
||||
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
|
||||
# 清理最旧的记录(删除最久未访问的IP)
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
# 获取或创建该IP的请求时间deque(不使用maxlen,避免限流变松)
|
||||
if client_ip not in self.request_times:
|
||||
self.request_times[client_ip] = deque()
|
||||
|
||||
request_times = self.request_times[client_ip]
|
||||
|
||||
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
|
||||
while request_times and current_time - request_times[0] >= self.rate_limit_window:
|
||||
request_times.popleft()
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(request_times) >= self.rate_limit_max_requests:
|
||||
return True
|
||||
|
||||
# 记录当前请求时间
|
||||
request_times.append(current_time)
|
||||
return False
|
||||
|
||||
def _cleanup_old_requests(self, current_time: float):
|
||||
"""清理过期的请求记录(只清理当前需要检查的IP,不全量遍历)"""
|
||||
# 这个方法现在主要用于定期清理,实际清理在_check_rate_limit中按需进行
|
||||
# 清理最久未访问的IP记录
|
||||
if len(self.request_times) > self.max_tracked_ips * 0.8:
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
def _cleanup_oldest_ips(self):
|
||||
"""清理最久未访问的IP记录(全量遍历找真正的oldest)"""
|
||||
if not self.request_times:
|
||||
return
|
||||
|
||||
# 先收集空deque的IP(优先删除)
|
||||
empty_ips = []
|
||||
# 找到最久未访问的IP(最旧时间戳)
|
||||
oldest_ip = None
|
||||
oldest_time = float("inf")
|
||||
|
||||
# 全量遍历找真正的oldest(超限时性能可接受)
|
||||
for ip, times in self.request_times.items():
|
||||
if not times:
|
||||
# 空deque,记录待删除
|
||||
empty_ips.append(ip)
|
||||
else:
|
||||
# 找到最旧的时间戳
|
||||
if times[0] < oldest_time:
|
||||
oldest_time = times[0]
|
||||
oldest_ip = ip
|
||||
|
||||
# 先删除空deque的IP
|
||||
for ip in empty_ips:
|
||||
del self.request_times[ip]
|
||||
|
||||
# 如果没有空deque可删除,且仍需要清理,删除最旧的一个IP
|
||||
if not empty_ips and oldest_ip:
|
||||
del self.request_times[oldest_ip]
|
||||
|
||||
def _is_trusted_proxy(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在信任的代理列表中
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果是信任的代理则返回 True
|
||||
"""
|
||||
if not TRUSTED_PROXIES or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查代理列表中的每个条目
|
||||
for trusted_entry in TRUSTED_PROXIES:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(trusted_entry, str):
|
||||
try:
|
||||
if re.match(trusted_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址(带基本验证和代理信任检查)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
客户端IP地址
|
||||
"""
|
||||
# 获取直接连接的客户端IP(用于验证代理)
|
||||
direct_client_ip = None
|
||||
if request.client:
|
||||
direct_client_ip = request.client.host
|
||||
|
||||
# 检查是否信任X-Forwarded-For头
|
||||
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
|
||||
use_xff = False
|
||||
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
|
||||
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
|
||||
use_xff = self._is_trusted_proxy(direct_client_ip)
|
||||
|
||||
# 如果信任代理,优先从 X-Forwarded-For 获取
|
||||
if use_xff:
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 可能包含多个IP,取第一个
|
||||
ip = forwarded_for.split(",")[0].strip()
|
||||
# 基本验证IP格式
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 从 X-Real-IP 获取(如果信任代理)
|
||||
if use_xff:
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
ip = real_ip.strip()
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 使用直接连接的客户端IP
|
||||
if direct_client_ip and self._validate_ip(direct_client_ip):
|
||||
return direct_client_ip
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _validate_ip(self, ip: str) -> bool:
|
||||
"""
|
||||
验证IP地址格式
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果格式有效则返回 True
|
||||
"""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
def _is_ip_allowed(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在白名单中(支持精确IP、CIDR格式和通配符)
|
||||
|
||||
Args:
|
||||
ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果IP在白名单中则返回 True
|
||||
"""
|
||||
if not ALLOWED_IPS or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查白名单中的每个条目
|
||||
for allowed_entry in ALLOWED_IPS:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(allowed_entry, str):
|
||||
try:
|
||||
if re.match(allowed_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
# 正则表达式错误,跳过
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
处理请求
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个中间件或路由处理函数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
# 如果未启用,直接通过
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问 robots.txt(由专门的路由处理)
|
||||
if request.url.path == "/robots.txt":
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问静态资源(CSS、JS、图片等)
|
||||
# 注意:.json 已移除,避免 API 路径绕过防护
|
||||
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/)
|
||||
static_extensions = {
|
||||
".css",
|
||||
".js",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".svg",
|
||||
".ico",
|
||||
".woff",
|
||||
".woff2",
|
||||
".ttf",
|
||||
".eot",
|
||||
}
|
||||
static_prefixes = {"/static/", "/assets/", "/dist/"}
|
||||
|
||||
# 检查是否是静态资源路径(特定前缀下的静态文件)
|
||||
path = request.url.path
|
||||
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
|
||||
path.endswith(ext) for ext in static_extensions
|
||||
)
|
||||
|
||||
# 也允许根路径下的静态文件(如 /favicon.ico)
|
||||
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
|
||||
|
||||
if is_static_path or is_root_static:
|
||||
return await call_next(request)
|
||||
|
||||
# 获取客户端IP(只获取一次,避免重复调用)
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# 检查IP白名单(优先检查,白名单IP直接通过)
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取 User-Agent
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检测资产测绘工具(优先检测,因为更危险)
|
||||
if self.check_asset_scanner:
|
||||
is_scanner, scanner_name = self._detect_asset_scanner(request)
|
||||
if is_scanner:
|
||||
logger.warning(
|
||||
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
|
||||
f"User-Agent: {user_agent}, Path: {request.url.path}"
|
||||
)
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Asset scanning tools are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检测爬虫 User-Agent
|
||||
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
|
||||
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Crawlers are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检查请求频率限制
|
||||
if self.check_rate_limit and self._check_rate_limit(client_ip):
|
||||
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
return PlainTextResponse(
|
||||
"Too Many Requests: Rate limit exceeded",
|
||||
status_code=429,
|
||||
)
|
||||
|
||||
# 正常请求,继续处理
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def create_robots_txt_response() -> PlainTextResponse:
|
||||
"""
|
||||
创建 robots.txt 响应
|
||||
|
||||
Returns:
|
||||
robots.txt 响应对象
|
||||
"""
|
||||
robots_content = """User-agent: *
|
||||
Disallow: /
|
||||
|
||||
# 禁止所有爬虫访问
|
||||
"""
|
||||
return PlainTextResponse(
|
||||
content=robots_content,
|
||||
media_type="text/plain",
|
||||
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
|
||||
)
|
||||
@@ -1,301 +0,0 @@
|
||||
"""
|
||||
规划器监控API
|
||||
提供规划器日志数据的查询接口
|
||||
|
||||
性能优化:
|
||||
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/planner", tags=["planner"])
|
||||
|
||||
# 规划器日志目录
|
||||
PLAN_LOG_DIR = Path("logs/plan")
|
||||
|
||||
|
||||
class ChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
chat_id: str
|
||||
plan_count: int
|
||||
latest_timestamp: float
|
||||
latest_filename: str
|
||||
|
||||
|
||||
class PlanLogSummary(BaseModel):
|
||||
"""规划日志摘要"""
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
action_count: int
|
||||
action_types: List[str] # 动作类型列表
|
||||
total_plan_ms: float
|
||||
llm_duration_ms: float
|
||||
reasoning_preview: str
|
||||
|
||||
|
||||
class PlanLogDetail(BaseModel):
|
||||
"""规划日志详情"""
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
prompt: str
|
||||
reasoning: str
|
||||
raw_output: str
|
||||
actions: List[Dict]
|
||||
timing: Dict
|
||||
extra: Optional[Dict] = None
|
||||
|
||||
|
||||
class PlannerOverview(BaseModel):
|
||||
"""规划器总览 - 轻量级统计"""
|
||||
total_chats: int
|
||||
total_plans: int
|
||||
chats: List[ChatSummary]
|
||||
|
||||
|
||||
class PaginatedChatLogs(BaseModel):
|
||||
"""分页的聊天日志列表"""
|
||||
data: List[PlanLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
chat_id: str
|
||||
|
||||
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
@router.get("/overview", response_model=PlannerOverview)
|
||||
async def get_planner_overview():
|
||||
"""
|
||||
获取规划器总览 - 轻量级接口
|
||||
只统计文件数量,不读取文件内容
|
||||
"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
|
||||
|
||||
chats = []
|
||||
total_plans = 0
|
||||
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
plan_count = len(json_files)
|
||||
total_plans += plan_count
|
||||
|
||||
if plan_count == 0:
|
||||
continue
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return PlannerOverview(
|
||||
total_chats=len(chats),
|
||||
total_plans=total_plans,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||
async def get_chat_plan_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
):
|
||||
"""
|
||||
获取指定聊天的规划日志列表(分页)
|
||||
需要读取文件内容获取摘要信息
|
||||
支持搜索提示词内容
|
||||
"""
|
||||
chat_dir = PLAN_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedChatLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
actions = data.get('actions', [])
|
||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
||||
))
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedChatLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||
async def get_log_detail(chat_id: str, filename: str):
|
||||
"""获取规划日志详情 - 按需加载完整内容"""
|
||||
log_file = PLAN_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return PlanLogDetail(**data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
||||
|
||||
# ========== 兼容旧接口 ==========
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_planner_stats():
|
||||
"""获取规划器统计信息 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
|
||||
# 获取最近10条计划的摘要
|
||||
recent_plans = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
try:
|
||||
chat_logs = await get_chat_plan_logs(chat.chat_id, page=1, page_size=2)
|
||||
recent_plans.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_plans = recent_plans[:10]
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_plans": overview.total_plans,
|
||||
"avg_plan_time_ms": 0,
|
||||
"avg_llm_time_ms": 0,
|
||||
"recent_plans": recent_plans
|
||||
}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def get_chat_list():
|
||||
"""获取所有聊天ID列表 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
|
||||
|
||||
@router.get("/all-logs")
|
||||
async def get_all_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100)
|
||||
):
|
||||
"""获取所有规划日志 - 兼容旧接口"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
# 收集所有文件
|
||||
all_files = []
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if chat_dir.is_dir():
|
||||
for log_file in chat_dir.glob("*.json"):
|
||||
all_files.append((chat_dir.name, log_file))
|
||||
|
||||
# 按时间戳排序
|
||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||
|
||||
total = len(all_files)
|
||||
offset = (page - 1) * page_size
|
||||
page_files = all_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for chat_id, log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
logs.append({
|
||||
"chat_id": data.get('chat_id', chat_id),
|
||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get('actions', [])),
|
||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
@@ -1,269 +0,0 @@
|
||||
"""
|
||||
回复器监控API
|
||||
提供回复器日志数据的查询接口
|
||||
|
||||
性能优化:
|
||||
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/replier", tags=["replier"])
|
||||
|
||||
# 回复器日志目录
|
||||
REPLY_LOG_DIR = Path("logs/reply")
|
||||
|
||||
|
||||
class ReplierChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
chat_id: str
|
||||
reply_count: int
|
||||
latest_timestamp: float
|
||||
latest_filename: str
|
||||
|
||||
|
||||
class ReplyLogSummary(BaseModel):
|
||||
"""回复日志摘要"""
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
model: str
|
||||
success: bool
|
||||
llm_ms: float
|
||||
overall_ms: float
|
||||
output_preview: str
|
||||
|
||||
|
||||
class ReplyLogDetail(BaseModel):
|
||||
"""回复日志详情"""
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
prompt: str
|
||||
output: str
|
||||
processed_output: List[str]
|
||||
model: str
|
||||
reasoning: str
|
||||
think_level: int
|
||||
timing: Dict
|
||||
error: Optional[str] = None
|
||||
success: bool
|
||||
|
||||
|
||||
class ReplierOverview(BaseModel):
|
||||
"""回复器总览 - 轻量级统计"""
|
||||
total_chats: int
|
||||
total_replies: int
|
||||
chats: List[ReplierChatSummary]
|
||||
|
||||
|
||||
class PaginatedReplyLogs(BaseModel):
|
||||
"""分页的回复日志列表"""
|
||||
data: List[ReplyLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
chat_id: str
|
||||
|
||||
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
@router.get("/overview", response_model=ReplierOverview)
|
||||
async def get_replier_overview():
|
||||
"""
|
||||
获取回复器总览 - 轻量级接口
|
||||
只统计文件数量,不读取文件内容
|
||||
"""
|
||||
if not REPLY_LOG_DIR.exists():
|
||||
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
|
||||
|
||||
chats = []
|
||||
total_replies = 0
|
||||
|
||||
for chat_dir in REPLY_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
reply_count = len(json_files)
|
||||
total_replies += reply_count
|
||||
|
||||
if reply_count == 0:
|
||||
continue
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return ReplierOverview(
|
||||
total_chats=len(chats),
|
||||
total_replies=total_replies,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||
async def get_chat_reply_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
):
|
||||
"""
|
||||
获取指定聊天的回复日志列表(分页)
|
||||
需要读取文件内容获取摘要信息
|
||||
支持搜索提示词内容
|
||||
"""
|
||||
chat_dir = REPLY_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedReplyLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
output = data.get('output', '')
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get('model', ''),
|
||||
success=data.get('success', True),
|
||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
||||
output_preview=output[:100] if output else ''
|
||||
))
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model='',
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedReplyLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||
async def get_reply_log_detail(chat_id: str, filename: str):
|
||||
"""获取回复日志详情 - 按需加载完整内容"""
|
||||
log_file = REPLY_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return ReplyLogDetail(
|
||||
type=data.get('type', 'reply'),
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', 0),
|
||||
prompt=data.get('prompt', ''),
|
||||
output=data.get('output', ''),
|
||||
processed_output=data.get('processed_output', []),
|
||||
model=data.get('model', ''),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
think_level=data.get('think_level', 0),
|
||||
timing=data.get('timing', {}),
|
||||
error=data.get('error'),
|
||||
success=data.get('success', True)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
||||
|
||||
# ========== 兼容接口 ==========
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_replier_stats():
|
||||
"""获取回复器统计信息"""
|
||||
overview = await get_replier_overview()
|
||||
|
||||
# 获取最近10条回复的摘要
|
||||
recent_replies = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
try:
|
||||
chat_logs = await get_chat_reply_logs(chat.chat_id, page=1, page_size=2)
|
||||
recent_replies.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_replies = recent_replies[:10]
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_replies": overview.total_replies,
|
||||
"recent_replies": recent_replies
|
||||
}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def get_replier_chat_list():
|
||||
"""获取所有聊天ID列表"""
|
||||
overview = await get_replier_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
@@ -1,183 +0,0 @@
|
||||
"""
|
||||
WebUI 认证模块
|
||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
|
||||
# Cookie 配置
|
||||
COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
# 从配置读取
|
||||
if global_config.webui.secure_cookie:
|
||||
logger.info("配置中启用了 secure_cookie")
|
||||
return True
|
||||
|
||||
# 检查是否是生产环境
|
||||
if global_config.webui.mode == "production":
|
||||
logger.info("WebUI运行在生产模式,启用 secure cookie")
|
||||
return True
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
logger.debug("WebUI运行在开发模式,禁用 secure cookie")
|
||||
return False
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
request: FastAPI Request 对象(可选,用于检测协议)
|
||||
"""
|
||||
# 根据环境和实际请求协议决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
# 如果提供了 request,检测实际使用的协议
|
||||
if request:
|
||||
# 检查 X-Forwarded-Proto header(代理/负载均衡器)
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
|
||||
if forwarded_proto:
|
||||
is_https = forwarded_proto == "https"
|
||||
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
|
||||
else:
|
||||
# 检查 request.url.scheme
|
||||
is_https = request.url.scheme == "https"
|
||||
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
|
||||
|
||||
# 如果是 HTTP 连接,强制禁用 secure 标志
|
||||
if not is_https and is_secure:
|
||||
logger.warning("=" * 80)
|
||||
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
|
||||
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
|
||||
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
|
||||
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
|
||||
logger.warning("=" * 80)
|
||||
is_secure = False
|
||||
|
||||
# 设置 Cookie
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
|
||||
secure=is_secure, # 根据实际协议决定
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
|
||||
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
|
||||
logger.debug(f"完整 token 前缀: {token[:20]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="strict" if is_secure else "lax",
|
||||
secure=is_secure,
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
|
||||
|
||||
def verify_auth_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
@@ -1,782 +0,0 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话
|
||||
|
||||
支持两种模式:
|
||||
1. WebUI 模式:使用 WebUI 平台独立身份聊天
|
||||
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
||||
# 虚拟身份模式的群 ID 前缀
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
# 固定的 WebUI 用户 ID 前缀
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置"""
|
||||
|
||||
enabled: bool = False # 是否启用虚拟身份模式
|
||||
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
|
||||
person_id: Optional[str] = None # PersonInfo 的 person_id
|
||||
user_id: Optional[str] = None # 原始平台用户 ID
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
group_id: Optional[str] = None # 虚拟群 ID(自动生成或用户指定)
|
||||
group_name: Optional[str] = None # 虚拟群名(用户自定义)
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
|
||||
id: str
|
||||
type: str # 'user' | 'bot' | 'system'
|
||||
content: str
|
||||
timestamp: float
|
||||
sender_name: str
|
||||
sender_id: Optional[str] = None
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||
|
||||
def __init__(self, max_messages: int = 200):
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
# 虚拟群:user_id 等于机器人 QQ 账号的是机器人消息
|
||||
bot_qq = str(global_config.bot.qq_account)
|
||||
is_bot = user_id == bot_qq
|
||||
else:
|
||||
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
|
||||
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"timestamp": msg.time,
|
||||
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""从数据库获取最近的历史记录
|
||||
|
||||
Args:
|
||||
limit: 获取的消息数量
|
||||
group_id: 群 ID,默认为 WEBUI_CHAT_GROUP_ID
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
# 查询指定群的消息,按时间排序
|
||||
messages = (
|
||||
Messages.select()
|
||||
.where(Messages.chat_info_group_id == target_group_id)
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# 转换为列表并反转(使最旧的消息在前)
|
||||
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
result.reverse()
|
||||
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 全局聊天历史管理器
|
||||
chat_history = ChatHistoryManager()
|
||||
|
||||
|
||||
# 存储 WebSocket 连接
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str):
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict):
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""广播消息给所有连接"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建符合麦麦消息格式的消息数据
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
user_id: 用户 ID
|
||||
user_name: 用户昵称
|
||||
message_id: 消息 ID(可选,自动生成)
|
||||
is_at_bot: 是否 @ 机器人
|
||||
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# 确定使用的平台、群信息和用户信息
|
||||
if virtual_config and virtual_config.enabled:
|
||||
# 虚拟身份模式:使用真实平台身份
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
# 标准 WebUI 模式
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||
如果指定了 group_id,则获取该虚拟群的历史记录
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"messages": history,
|
||||
"total": len(history),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
"""
|
||||
try:
|
||||
from peewee import fn
|
||||
|
||||
# 查询所有不同的平台
|
||||
platforms = (
|
||||
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
|
||||
.group_by(PersonInfo.platform)
|
||||
.order_by(fn.COUNT(PersonInfo.id).desc())
|
||||
)
|
||||
|
||||
result = []
|
||||
for p in platforms:
|
||||
if p.platform: # 排除空平台
|
||||
result.append({"platform": p.platform, "count": p.count})
|
||||
|
||||
return {"success": True, "platforms": result}
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "platforms": []}
|
||||
|
||||
|
||||
@router.get("/persons")
|
||||
async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
Args:
|
||||
platform: 平台名称(如 qq, discord 等)
|
||||
search: 搜索关键词(匹配昵称、用户名、user_id)
|
||||
limit: 返回数量限制
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = PersonInfo.select().where(PersonInfo.platform == platform)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 按最后交互时间排序,优先显示活跃用户
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
query = query.limit(limit)
|
||||
|
||||
result = []
|
||||
for person in query:
|
||||
result.append(
|
||||
{
|
||||
"person_id": person.person_id,
|
||||
"user_id": person.user_id,
|
||||
"person_name": person.person_name,
|
||||
"nickname": person.nickname,
|
||||
"is_known": person.is_known,
|
||||
"platform": person.platform,
|
||||
"display_name": person.person_name or person.nickname or person.user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return {"success": True, "persons": result, "total": len(result)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "persons": []}
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 可选,指定要清空的群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清空 {deleted} 条聊天记录",
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
token: Optional[str] = Query(default=None), # 认证 token
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
Args:
|
||||
user_id: 用户唯一标识(由前端生成并持久化)
|
||||
user_name: 用户显示昵称(可修改)
|
||||
platform: 虚拟身份模式的平台(可选)
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# 如果没有提供 user_id,生成一个新的
|
||||
if not user_id:
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
# 确保 user_id 有正确的前缀
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
# 当前会话的虚拟身份配置(可通过消息动态更新)
|
||||
current_virtual_config: Optional[VirtualIdentityConfig] = None
|
||||
|
||||
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
||||
if platform and person_id:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
await chat_manager.connect(websocket, session_id, user_id)
|
||||
|
||||
try:
|
||||
# 构建会话信息
|
||||
session_info_data = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
|
||||
# 如果有虚拟身份配置,添加到会话信息中
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
session_info_data["virtual_mode"] = True
|
||||
session_info_data["group_id"] = current_virtual_config.group_id
|
||||
session_info_data["virtual_identity"] = {
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
}
|
||||
|
||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||
await chat_manager.send_message(session_id, session_info_data)
|
||||
|
||||
# 发送历史记录(根据模式选择不同的群)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
else:
|
||||
history = chat_history.get_history(50)
|
||||
if history:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送欢迎消息(不保存到历史)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
|
||||
else:
|
||||
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": welcome_msg,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "message":
|
||||
content = data.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 用户可以更新昵称
|
||||
current_user_name = data.get("user_name", user_name)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
|
||||
# 确定发送者信息(根据是否使用虚拟身份)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
sender_name = current_virtual_config.user_nickname or current_user_name
|
||||
sender_user_id = current_virtual_config.user_id or user_id
|
||||
else:
|
||||
sender_name = current_user_name
|
||||
sender_user_id = user_id
|
||||
|
||||
# 广播用户消息给所有连接(包括发送者)
|
||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": sender_name,
|
||||
"user_id": sender_user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建麦麦消息格式
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
user_name=current_user_name,
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 显示正在输入状态
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 调用麦麦的消息处理
|
||||
await chat_bot.message_process(message_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif data.get("type") == "ping":
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "update_nickname":
|
||||
# 允许用户更新昵称
|
||||
if new_name := data.get("user_name", "").strip():
|
||||
current_user_name = new_name
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "set_virtual_identity":
|
||||
# 设置或更新虚拟身份配置
|
||||
virtual_data = data.get("config", {})
|
||||
if virtual_data.get("enabled"):
|
||||
# 验证必要字段
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 获取用户信息
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
|
||||
if not person:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"找不到用户: {virtual_data.get('person_id')}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成虚拟群 ID
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
if custom_group_id:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
else:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
|
||||
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=group_id,
|
||||
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
||||
)
|
||||
|
||||
# 发送虚拟身份已激活的消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 加载虚拟群的历史记录
|
||||
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": virtual_history,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送系统消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"设置虚拟身份失败: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 禁用虚拟身份模式
|
||||
current_virtual_config = None
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {"enabled": False},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 重新加载默认聊天室历史
|
||||
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": default_history,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
},
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": "已切换回 WebUI 独立用户模式",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||
|
||||
Returns:
|
||||
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||
"""
|
||||
return (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
@@ -1,597 +0,0 @@
|
||||
"""
|
||||
配置管理API路由
|
||||
"""
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
RelationshipConfig,
|
||||
ChatConfig,
|
||||
MessageReceiveConfig,
|
||||
EmojiConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
ExperimentalConfig,
|
||||
MaimMessageConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
ToolConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
from src.config.api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
ModelInfo,
|
||||
APIProvider,
|
||||
)
|
||||
from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置架构"""
|
||||
try:
|
||||
# Config 类包含所有子配置
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(Config)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 子配置架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/section/{section_name}")
|
||||
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取指定配置节的架构
|
||||
|
||||
支持的section_name:
|
||||
- bot: BotConfig
|
||||
- personality: PersonalityConfig
|
||||
- relationship: RelationshipConfig
|
||||
- chat: ChatConfig
|
||||
- message_receive: MessageReceiveConfig
|
||||
- emoji: EmojiConfig
|
||||
- expression: ExpressionConfig
|
||||
- keyword_reaction: KeywordReactionConfig
|
||||
- chinese_typo: ChineseTypoConfig
|
||||
- response_post_process: ResponsePostProcessConfig
|
||||
- response_splitter: ResponseSplitterConfig
|
||||
- telemetry: TelemetryConfig
|
||||
- experimental: ExperimentalConfig
|
||||
- maim_message: MaimMessageConfig
|
||||
- lpmm_knowledge: LPMMKnowledgeConfig
|
||||
- tool: ToolConfig
|
||||
- memory: MemoryConfig
|
||||
- debug: DebugConfig
|
||||
- voice: VoiceConfig
|
||||
- jargon: JargonConfig
|
||||
- model_task_config: ModelTaskConfig
|
||||
- api_provider: APIProvider
|
||||
- model_info: ModelInfo
|
||||
"""
|
||||
section_map = {
|
||||
"bot": BotConfig,
|
||||
"personality": PersonalityConfig,
|
||||
"relationship": RelationshipConfig,
|
||||
"chat": ChatConfig,
|
||||
"message_receive": MessageReceiveConfig,
|
||||
"emoji": EmojiConfig,
|
||||
"expression": ExpressionConfig,
|
||||
"keyword_reaction": KeywordReactionConfig,
|
||||
"chinese_typo": ChineseTypoConfig,
|
||||
"response_post_process": ResponsePostProcessConfig,
|
||||
"response_splitter": ResponseSplitterConfig,
|
||||
"telemetry": TelemetryConfig,
|
||||
"experimental": ExperimentalConfig,
|
||||
"maim_message": MaimMessageConfig,
|
||||
"lpmm_knowledge": LPMMKnowledgeConfig,
|
||||
"tool": ToolConfig,
|
||||
"memory": MemoryConfig,
|
||||
"debug": DebugConfig,
|
||||
"voice": VoiceConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
"api_provider": APIProvider,
|
||||
"model_info": ModelInfo,
|
||||
}
|
||||
|
||||
if section_name not in section_map:
|
||||
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
|
||||
|
||||
try:
|
||||
config_class = section_map[section_name]
|
||||
schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置节架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置读取接口 =====
|
||||
|
||||
|
||||
@router.get("/bot")
|
||||
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
return {"success": True, "config": config_data}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
return {"success": True, "config": config_data}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("麦麦主程序配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("模型配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置节更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 更新指定节
|
||||
if section_name not in config_data:
|
||||
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
|
||||
|
||||
# 使用递归合并保留注释(对于字典类型)
|
||||
# 对于数组类型(如 platforms, aliases),直接替换
|
||||
if isinstance(section_data, list):
|
||||
# 列表直接替换
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 原始 TOML 文件操作接口 =====
|
||||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
raw_content = f.read()
|
||||
|
||||
return {"success": True, "content": raw_content}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
config_data = tomlkit.loads(raw_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 验证配置数据结构
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(raw_content)
|
||||
|
||||
logger.info("麦麦主程序配置已更新(原始模式)")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(
|
||||
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||
):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 更新指定节
|
||||
if section_name not in config_data:
|
||||
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
|
||||
|
||||
# 使用递归合并保留注释(对于字典类型)
|
||||
# 对于数组表(如 [[models]], [[api_providers]]),直接替换
|
||||
if isinstance(section_data, list):
|
||||
# 列表直接替换
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 适配器配置管理接口 =====
|
||||
|
||||
|
||||
def _normalize_adapter_path(path: str) -> str:
|
||||
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
# 如果已经是绝对路径,直接返回
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
|
||||
# 相对路径,转换为相对于项目根目录的绝对路径
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
|
||||
def _to_relative_path(path: str) -> str:
|
||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||
if not path or not os.path.isabs(path):
|
||||
return path
|
||||
|
||||
try:
|
||||
# 尝试获取相对路径
|
||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||
if not rel_path.startswith(".."):
|
||||
return rel_path
|
||||
except (ValueError, TypeError):
|
||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||
pass
|
||||
|
||||
# 无法转换为相对路径,返回绝对路径
|
||||
return path
|
||||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
webui_data_path = os.path.join("data", "webui.json")
|
||||
if not os.path.exists(webui_data_path):
|
||||
return {"success": True, "path": None}
|
||||
|
||||
import json
|
||||
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
|
||||
adapter_config_path = webui_data.get("adapter_config_path")
|
||||
if not adapter_config_path:
|
||||
return {"success": True, "path": None}
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(abs_path):
|
||||
import datetime
|
||||
|
||||
mtime = os.path.getmtime(abs_path)
|
||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||
# 返回相对路径(如果可能)
|
||||
display_path = _to_relative_path(abs_path)
|
||||
return {"success": True, "path": display_path, "lastModified": last_modified}
|
||||
else:
|
||||
# 文件不存在,返回原路径
|
||||
return {"success": True, "path": adapter_config_path, "lastModified": None}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
|
||||
# 保存到 data/webui.json
|
||||
webui_data_path = os.path.join("data", "webui.json")
|
||||
import json
|
||||
|
||||
# 读取现有数据
|
||||
if os.path.exists(webui_data_path):
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
else:
|
||||
webui_data = {}
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||
save_path = _to_relative_path(abs_path)
|
||||
|
||||
# 更新路径
|
||||
webui_data["adapter_config_path"] = save_path
|
||||
|
||||
# 保存
|
||||
os.makedirs("data", exist_ok=True)
|
||||
with open(webui_data_path, "w", encoding="utf-8") as f:
|
||||
json.dump(webui_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"适配器配置路径已保存: {save_path}(绝对路径: {abs_path})")
|
||||
return {"success": True, "message": "路径已保存"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径参数不能为空")
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(abs_path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 读取文件内容
|
||||
with open(abs_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
logger.info(f"已读取适配器配置: {path} (绝对路径: {abs_path})")
|
||||
return {"success": True, "content": content}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
content = data.get("content")
|
||||
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
if content is None:
|
||||
raise HTTPException(status_code=400, detail="配置内容不能为空")
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
tomlkit.loads(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = os.path.dirname(abs_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
# 保存文件
|
||||
with open(abs_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info(f"适配器配置已保存: {path} (绝对路径: {abs_path})")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||
@@ -1,335 +0,0 @@
|
||||
"""
|
||||
配置架构生成器 - 自动从配置类生成前端表单架构
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from dataclasses import fields, MISSING
|
||||
from typing import Any, get_origin, get_args, Literal, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
|
||||
class FieldType(str, Enum):
|
||||
"""字段类型枚举"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
TEXTAREA = "textarea"
|
||||
|
||||
|
||||
class FieldSchema:
|
||||
"""字段架构"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
type: FieldType,
|
||||
label: str,
|
||||
description: str = "",
|
||||
default: Any = None,
|
||||
required: bool = True,
|
||||
options: Optional[list[str]] = None,
|
||||
min_value: Optional[float] = None,
|
||||
max_value: Optional[float] = None,
|
||||
items: Optional[dict] = None,
|
||||
properties: Optional[dict] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.label = label
|
||||
self.description = description
|
||||
self.default = default
|
||||
self.required = required
|
||||
self.options = options
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
self.items = items
|
||||
self.properties = properties
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"name": self.name,
|
||||
"type": self.type.value,
|
||||
"label": self.label,
|
||||
"description": self.description,
|
||||
"required": self.required,
|
||||
}
|
||||
|
||||
if self.default is not None:
|
||||
result["default"] = self.default
|
||||
|
||||
if self.options is not None:
|
||||
result["options"] = self.options
|
||||
|
||||
if self.min_value is not None:
|
||||
result["minValue"] = self.min_value
|
||||
|
||||
if self.max_value is not None:
|
||||
result["maxValue"] = self.max_value
|
||||
|
||||
if self.items is not None:
|
||||
result["items"] = self.items
|
||||
|
||||
if self.properties is not None:
|
||||
result["properties"] = self.properties
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ConfigSchemaGenerator:
|
||||
"""配置架构生成器"""
|
||||
|
||||
@staticmethod
|
||||
def _extract_field_description(config_class: type, field_name: str) -> str:
|
||||
"""
|
||||
从类定义中提取字段的文档字符串描述
|
||||
|
||||
Args:
|
||||
config_class: 配置类
|
||||
field_name: 字段名
|
||||
|
||||
Returns:
|
||||
str: 字段描述
|
||||
"""
|
||||
try:
|
||||
# 获取源代码
|
||||
source = inspect.getsource(config_class)
|
||||
lines = source.split("\n")
|
||||
|
||||
# 查找字段定义
|
||||
field_found = False
|
||||
description_lines = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 匹配字段定义行,例如: platform: str
|
||||
if f"{field_name}:" in line and "=" in line:
|
||||
field_found = True
|
||||
# 查找下一行的文档字符串
|
||||
if i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
# 单行文档字符串
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
# 多行文档字符串
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
for j in range(i + 2, len(lines)):
|
||||
if quote in lines[j]:
|
||||
description_lines.append(lines[j].split(quote)[0].strip())
|
||||
break
|
||||
description_lines.append(lines[j].strip())
|
||||
break
|
||||
elif f"{field_name}:" in line and "=" not in line:
|
||||
# 没有默认值的字段
|
||||
field_found = True
|
||||
if i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
for j in range(i + 2, len(lines)):
|
||||
if quote in lines[j]:
|
||||
description_lines.append(lines[j].split(quote)[0].strip())
|
||||
break
|
||||
description_lines.append(lines[j].strip())
|
||||
break
|
||||
|
||||
if field_found and description_lines:
|
||||
return " ".join(description_lines)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _get_field_type_and_options(field_type: type) -> tuple[FieldType, Optional[list[str]], Optional[dict]]:
|
||||
"""
|
||||
获取字段类型和选项
|
||||
|
||||
Args:
|
||||
field_type: 字段类型
|
||||
|
||||
Returns:
|
||||
tuple: (FieldType, options, items)
|
||||
"""
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
# 处理 Literal 类型(枚举选项)
|
||||
if origin is Literal:
|
||||
return FieldType.SELECT, [str(arg) for arg in args], None
|
||||
|
||||
# 处理 list 类型
|
||||
if origin is list:
|
||||
item_type = args[0] if args else str
|
||||
if item_type is str:
|
||||
items = {"type": "string"}
|
||||
elif item_type is int:
|
||||
items = {"type": "integer"}
|
||||
elif item_type is float:
|
||||
items = {"type": "number"}
|
||||
elif item_type is bool:
|
||||
items = {"type": "boolean"}
|
||||
elif item_type is dict:
|
||||
items = {"type": "object"}
|
||||
else:
|
||||
items = {"type": "string"}
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理 set 类型(与 list 类似)
|
||||
if origin is set:
|
||||
item_type = args[0] if args else str
|
||||
if item_type is str:
|
||||
items = {"type": "string"}
|
||||
else:
|
||||
items = {"type": "string"}
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理基本类型
|
||||
if field_type is bool:
|
||||
return FieldType.BOOLEAN, None, None
|
||||
elif field_type is int:
|
||||
return FieldType.INTEGER, None, None
|
||||
elif field_type is float:
|
||||
return FieldType.NUMBER, None, None
|
||||
elif field_type is str:
|
||||
return FieldType.STRING, None, None
|
||||
elif field_type is dict or origin is dict:
|
||||
return FieldType.OBJECT, None, None
|
||||
|
||||
# 默认为字符串
|
||||
return FieldType.STRING, None, None
|
||||
|
||||
@staticmethod
|
||||
def _format_field_name(name: str) -> str:
|
||||
"""
|
||||
格式化字段名为可读的标签
|
||||
|
||||
Args:
|
||||
name: 原始字段名
|
||||
|
||||
Returns:
|
||||
str: 格式化后的标签
|
||||
"""
|
||||
# 将下划线替换为空格,并首字母大写
|
||||
return " ".join(word.capitalize() for word in name.split("_"))
|
||||
|
||||
@staticmethod
|
||||
def generate_schema(config_class: type[ConfigBase], include_nested: bool = True) -> dict:
|
||||
"""
|
||||
从配置类生成前端表单架构
|
||||
|
||||
Args:
|
||||
config_class: 配置类(必须继承自 ConfigBase)
|
||||
include_nested: 是否包含嵌套的配置对象
|
||||
|
||||
Returns:
|
||||
dict: 前端表单架构
|
||||
"""
|
||||
if not issubclass(config_class, ConfigBase):
|
||||
raise ValueError(f"{config_class.__name__} 必须继承自 ConfigBase")
|
||||
|
||||
schema_fields = []
|
||||
nested_schemas = {}
|
||||
|
||||
for field in fields(config_class):
|
||||
# 跳过私有字段和内部字段
|
||||
if field.name.startswith("_") or field.name in ["MMC_VERSION"]:
|
||||
continue
|
||||
|
||||
# 提取字段描述
|
||||
description = ConfigSchemaGenerator._extract_field_description(config_class, field.name)
|
||||
|
||||
# 判断是否必填
|
||||
required = field.default is MISSING and field.default_factory is MISSING
|
||||
|
||||
# 获取默认值
|
||||
default_value = None
|
||||
if field.default is not MISSING:
|
||||
default_value = field.default
|
||||
elif field.default_factory is not MISSING:
|
||||
try:
|
||||
default_value = field.default_factory()
|
||||
except Exception:
|
||||
default_value = None
|
||||
|
||||
# 检查是否为嵌套的 ConfigBase
|
||||
if isinstance(field.type, type) and issubclass(field.type, ConfigBase):
|
||||
if include_nested:
|
||||
# 递归生成嵌套配置的架构
|
||||
nested_schema = ConfigSchemaGenerator.generate_schema(field.type, include_nested=True)
|
||||
nested_schemas[field.name] = nested_schema
|
||||
|
||||
field_schema = FieldSchema(
|
||||
name=field.name,
|
||||
type=FieldType.OBJECT,
|
||||
label=ConfigSchemaGenerator._format_field_name(field.name),
|
||||
description=description or field.type.__doc__ or "",
|
||||
default=default_value,
|
||||
required=required,
|
||||
properties=nested_schema,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
# 获取字段类型和选项
|
||||
field_type, options, items = ConfigSchemaGenerator._get_field_type_and_options(field.type)
|
||||
|
||||
# 特殊处理:长文本使用 textarea
|
||||
if field_type == FieldType.STRING and field.name in [
|
||||
"personality",
|
||||
"reply_style",
|
||||
"interest",
|
||||
"plan_style",
|
||||
"visual_style",
|
||||
"private_plan_style",
|
||||
"reaction",
|
||||
"filtration_prompt",
|
||||
]:
|
||||
field_type = FieldType.TEXTAREA
|
||||
|
||||
field_schema = FieldSchema(
|
||||
name=field.name,
|
||||
type=field_type,
|
||||
label=ConfigSchemaGenerator._format_field_name(field.name),
|
||||
description=description,
|
||||
default=default_value,
|
||||
required=required,
|
||||
options=options,
|
||||
items=items,
|
||||
)
|
||||
|
||||
schema_fields.append(field_schema.to_dict())
|
||||
|
||||
return {
|
||||
"className": config_class.__name__,
|
||||
"classDoc": config_class.__doc__ or "",
|
||||
"fields": schema_fields,
|
||||
"nested": nested_schemas if nested_schemas else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def generate_config_schema(config_class: type[ConfigBase]) -> dict:
|
||||
"""
|
||||
生成完整的配置架构(包含所有嵌套的子配置)
|
||||
|
||||
Args:
|
||||
config_class: 配置类
|
||||
|
||||
Returns:
|
||||
dict: 完整的配置架构
|
||||
"""
|
||||
return ConfigSchemaGenerator.generate_schema(config_class, include_nested=True)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,790 +0,0 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/expression", tags=["Expression"])
|
||||
|
||||
|
||||
class ExpressionResponse(BaseModel):
|
||||
"""表达方式响应"""
|
||||
|
||||
id: int
|
||||
situation: str
|
||||
style: str
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
create_date: Optional[float]
|
||||
checked: bool
|
||||
rejected: bool
|
||||
modified_by: Optional[str] = None # 'ai' 或 'user' 或 None
|
||||
|
||||
|
||||
class ExpressionListResponse(BaseModel):
|
||||
"""表达方式列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[ExpressionResponse]
|
||||
|
||||
|
||||
class ExpressionDetailResponse(BaseModel):
|
||||
"""表达方式详情响应"""
|
||||
|
||||
success: bool
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
class ExpressionCreateRequest(BaseModel):
|
||||
"""表达方式创建请求"""
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
chat_id: str
|
||||
|
||||
|
||||
class ExpressionUpdateRequest(BaseModel):
|
||||
"""表达方式更新请求"""
|
||||
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
checked: Optional[bool] = None
|
||||
rejected: Optional[bool] = None
|
||||
require_unchecked: Optional[bool] = False # 用于人工审核时的冲突检测
|
||||
|
||||
|
||||
class ExpressionUpdateResponse(BaseModel):
|
||||
"""表达方式更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[ExpressionResponse] = None
|
||||
|
||||
|
||||
class ExpressionDeleteResponse(BaseModel):
|
||||
"""表达方式删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ExpressionCreateResponse(BaseModel):
|
||||
"""表达方式创建响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""将 Expression 模型转换为响应对象"""
|
||||
return ExpressionResponse(
|
||||
id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=expression.last_active_time,
|
||||
chat_id=expression.chat_id,
|
||||
create_date=expression.create_date,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream:
|
||||
# 优先使用群聊名称,否则使用用户昵称
|
||||
if chat_stream.group_name:
|
||||
return chat_stream.group_name
|
||||
elif chat_stream.user_nickname:
|
||||
return chat_stream.user_nickname
|
||||
return chat_id # 找不到时返回原始ID
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称"""
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
|
||||
for cs in chat_streams:
|
||||
if cs.group_name:
|
||||
result[cs.stream_id] = cs.group_name
|
||||
elif cs.user_nickname:
|
||||
result[cs.stream_id] = cs.user_nickname
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
|
||||
|
||||
class ChatInfo(BaseModel):
|
||||
"""聊天信息"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool
|
||||
data: List[ChatInfo]
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=cs.stream_id,
|
||||
chat_name=chat_name,
|
||||
platform=cs.platform,
|
||||
is_group=bool(cs.group_id),
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/list", response_model=ExpressionListResponse)
|
||||
async def get_expression_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 situation, style)
|
||||
chat_id: 聊天ID筛选
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search))
|
||||
| (Expression.style.contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
|
||||
# 转换为响应对象
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表达方式列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表达方式详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
Args:
|
||||
request: 创建请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 创建表达方式
|
||||
expression = Expression.create(
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
chat_id=request.chat_id,
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
|
||||
return ExpressionCreateResponse(
|
||||
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"创建表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 冲突检测:如果要求未检查状态,但已经被检查了
|
||||
if request.require_unchecked and expression.checked:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表"
|
||||
)
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
# 移除 require_unchecked,它不是数据库字段
|
||||
update_data.pop('require_unchecked', None)
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 如果更新了 checked 或 rejected,标记为用户修改
|
||||
if 'checked' in update_data or 'rejected' in update_data:
|
||||
update_data['modified_by'] = 'user'
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data["last_active_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(expression, field, value)
|
||||
|
||||
expression.save()
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return ExpressionUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"更新表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 记录删除信息
|
||||
situation = expression.situation
|
||||
|
||||
# 执行删除
|
||||
expression.delete_instance()
|
||||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"删除表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int]
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
Args:
|
||||
request: 包含要删除的ID列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
|
||||
# 查找所有要删除的表达方式
|
||||
expressions = Expression.select().where(Expression.id.in_(request.ids))
|
||||
found_ids = [expr.id for expr in expressions]
|
||||
|
||||
# 检查是否有未找到的ID
|
||||
not_found_ids = set(request.ids) - set(found_ids)
|
||||
if not_found_ids:
|
||||
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
||||
|
||||
# 执行批量删除
|
||||
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
|
||||
|
||||
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
# 按 chat_id 统计
|
||||
chat_stats = {}
|
||||
for expr in Expression.select(Expression.chat_id):
|
||||
chat_id = expr.chat_id
|
||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||
|
||||
# 获取最近创建的记录数(7天内)
|
||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||
recent = (
|
||||
Expression.select()
|
||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"recent_7days": recent,
|
||||
"chat_count": len(chat_stats),
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||
},
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取统计数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ============ 审核相关接口 ============
|
||||
|
||||
class ReviewStatsResponse(BaseModel):
|
||||
"""审核统计响应"""
|
||||
total: int
|
||||
unchecked: int
|
||||
passed: int
|
||||
rejected: int
|
||||
ai_checked: int
|
||||
user_checked: int
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取审核统计数据
|
||||
|
||||
Returns:
|
||||
审核统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
unchecked = Expression.select().where(Expression.checked == False).count()
|
||||
passed = Expression.select().where(
|
||||
(Expression.checked == True) & (Expression.rejected == False)
|
||||
).count()
|
||||
rejected = Expression.select().where(
|
||||
(Expression.checked == True) & (Expression.rejected == True)
|
||||
).count()
|
||||
ai_checked = Expression.select().where(Expression.modified_by == 'ai').count()
|
||||
user_checked = Expression.select().where(Expression.modified_by == 'user').count()
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
unchecked=unchecked,
|
||||
passed=passed,
|
||||
rejected=rejected,
|
||||
ai_checked=ai_checked,
|
||||
user_checked=user_checked
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取审核统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
class ReviewListResponse(BaseModel):
|
||||
"""审核列表响应"""
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[ExpressionResponse]
|
||||
|
||||
|
||||
@router.get("/review/list", response_model=ReviewListResponse)
|
||||
async def get_review_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取待审核/已审核的表达方式列表
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
filter_type: 筛选类型 (unchecked/passed/rejected/all)
|
||||
search: 搜索关键词
|
||||
chat_id: 聊天ID筛选
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
query = Expression.select()
|
||||
|
||||
# 根据筛选类型过滤
|
||||
if filter_type == "unchecked":
|
||||
query = query.where(Expression.checked == False)
|
||||
elif filter_type == "passed":
|
||||
query = query.where((Expression.checked == True) & (Expression.rejected == False))
|
||||
elif filter_type == "rejected":
|
||||
query = query.where((Expression.checked == True) & (Expression.rejected == True))
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search)) | (Expression.style.contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
# 排序:创建时间倒序
|
||||
from peewee import Case
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.create_date.is_null(), 1)], 0),
|
||||
Expression.create_date.desc()
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=[expression_to_response(expr) for expr in expressions]
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取审核列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
class BatchReviewItem(BaseModel):
|
||||
"""批量审核项"""
|
||||
id: int
|
||||
rejected: bool
|
||||
require_unchecked: bool = True # 默认要求未检查状态
|
||||
|
||||
|
||||
class BatchReviewRequest(BaseModel):
|
||||
"""批量审核请求"""
|
||||
items: List[BatchReviewItem]
|
||||
|
||||
|
||||
class BatchReviewResultItem(BaseModel):
|
||||
"""批量审核结果项"""
|
||||
id: int
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchReviewResponse(BaseModel):
|
||||
"""批量审核响应"""
|
||||
success: bool
|
||||
total: int
|
||||
succeeded: int
|
||||
failed: int
|
||||
results: List[BatchReviewResultItem]
|
||||
|
||||
|
||||
@router.post("/review/batch", response_model=BatchReviewResponse)
|
||||
async def batch_review_expressions(
|
||||
request: BatchReviewRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量审核表达方式
|
||||
|
||||
Args:
|
||||
request: 批量审核请求
|
||||
|
||||
Returns:
|
||||
批量审核结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.items:
|
||||
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")
|
||||
|
||||
results = []
|
||||
succeeded = 0
|
||||
failed = 0
|
||||
|
||||
for item in request.items:
|
||||
try:
|
||||
expression = Expression.get_or_none(Expression.id == item.id)
|
||||
|
||||
if not expression:
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=f"未找到 ID 为 {item.id} 的表达方式"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked and expression.checked:
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 更新状态
|
||||
expression.checked = True
|
||||
expression.rejected = item.rejected
|
||||
expression.modified_by = 'user'
|
||||
expression.last_active_time = time.time()
|
||||
expression.save()
|
||||
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=True,
|
||||
message="通过" if not item.rejected else "拒绝"
|
||||
))
|
||||
succeeded += 1
|
||||
|
||||
except Exception as e:
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=str(e)
|
||||
))
|
||||
failed += 1
|
||||
|
||||
logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
|
||||
|
||||
return BatchReviewResponse(
|
||||
success=True,
|
||||
total=len(request.items),
|
||||
succeeded=succeeded,
|
||||
failed=failed,
|
||||
results=results
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量审核失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e
|
||||
@@ -1,662 +0,0 @@
|
||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.git_mirror")
|
||||
|
||||
# 导入进度更新函数(避免循环导入)
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
_update_progress = callback
|
||||
|
||||
|
||||
class MirrorType(str, Enum):
|
||||
"""镜像源类型"""
|
||||
|
||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
|
||||
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
|
||||
GITHUB = "github" # GitHub 官方源(兜底)
|
||||
CUSTOM = "custom" # 自定义镜像源
|
||||
|
||||
|
||||
class GitMirrorConfig:
|
||||
"""Git 镜像源配置管理"""
|
||||
|
||||
# 配置文件路径
|
||||
CONFIG_FILE = Path("data/webui.json")
|
||||
|
||||
# 默认镜像源配置
|
||||
DEFAULT_MIRRORS = [
|
||||
{
|
||||
"id": "gh-proxy",
|
||||
"name": "gh-proxy 镜像",
|
||||
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 1,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "hk-gh-proxy",
|
||||
"name": "gh-proxy 香港节点",
|
||||
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 2,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "cdn-gh-proxy",
|
||||
"name": "gh-proxy CDN 节点",
|
||||
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 3,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "edgeone-gh-proxy",
|
||||
"name": "gh-proxy EdgeOne 节点",
|
||||
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 4,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "meyzh-github",
|
||||
"name": "Meyzh GitHub 镜像",
|
||||
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 5,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "github",
|
||||
"name": "GitHub 官方源(兜底)",
|
||||
"raw_prefix": "https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 999,
|
||||
"created_at": None,
|
||||
},
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""初始化配置管理器"""
|
||||
self.config_file = self.CONFIG_FILE
|
||||
self.mirrors: List[Dict[str, Any]] = []
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
if "git_mirrors" not in data or not data["git_mirrors"]:
|
||||
logger.info("配置文件中未找到镜像源配置,使用默认配置")
|
||||
self._init_default_mirrors()
|
||||
else:
|
||||
self.mirrors = data["git_mirrors"]
|
||||
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
|
||||
else:
|
||||
logger.info("配置文件不存在,创建默认配置")
|
||||
self._init_default_mirrors()
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
self._init_default_mirrors()
|
||||
|
||||
def _init_default_mirrors(self) -> None:
|
||||
"""初始化默认镜像源"""
|
||||
current_time = datetime.now().isoformat()
|
||||
self.mirrors = []
|
||||
|
||||
for mirror in self.DEFAULT_MIRRORS:
|
||||
mirror_copy = mirror.copy()
|
||||
mirror_copy["created_at"] = current_time
|
||||
self.mirrors.append(mirror_copy)
|
||||
|
||||
self._save_config()
|
||||
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
|
||||
|
||||
def _save_config(self) -> None:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
|
||||
def get_all_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有镜像源"""
|
||||
return self.mirrors.copy()
|
||||
|
||||
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有启用的镜像源,按优先级排序"""
|
||||
enabled = [m for m in self.mirrors if m.get("enabled", False)]
|
||||
return sorted(enabled, key=lambda x: x.get("priority", 999))
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: str,
|
||||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
||||
Returns:
|
||||
添加的镜像源配置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果镜像源 ID 已存在
|
||||
"""
|
||||
# 检查 ID 是否已存在
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
priority = max_priority + 1
|
||||
|
||||
new_mirror = {
|
||||
"id": mirror_id,
|
||||
"name": name,
|
||||
"raw_prefix": raw_prefix,
|
||||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已添加镜像源: {mirror_id} - {name}")
|
||||
return new_mirror.copy()
|
||||
|
||||
def update_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: Optional[str] = None,
|
||||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
||||
Returns:
|
||||
更新后的镜像源配置,如果不存在则返回 None
|
||||
"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
if name is not None:
|
||||
mirror["name"] = name
|
||||
if raw_prefix is not None:
|
||||
mirror["raw_prefix"] = raw_prefix
|
||||
if clone_prefix is not None:
|
||||
mirror["clone_prefix"] = clone_prefix
|
||||
if enabled is not None:
|
||||
mirror["enabled"] = enabled
|
||||
if priority is not None:
|
||||
mirror["priority"] = priority
|
||||
|
||||
mirror["updated_at"] = datetime.now().isoformat()
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已更新镜像源: {mirror_id}")
|
||||
return mirror.copy()
|
||||
|
||||
return None
|
||||
|
||||
def delete_mirror(self, mirror_id: str) -> bool:
|
||||
"""
|
||||
删除镜像源
|
||||
|
||||
Returns:
|
||||
True 如果删除成功,False 如果镜像源不存在
|
||||
"""
|
||||
for i, mirror in enumerate(self.mirrors):
|
||||
if mirror.get("id") == mirror_id:
|
||||
self.mirrors.pop(i)
|
||||
self._save_config()
|
||||
logger.info(f"已删除镜像源: {mirror_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_default_priority_list(self) -> List[str]:
|
||||
"""获取默认优先级列表(仅启用的镜像源 ID)"""
|
||||
enabled = self.get_enabled_mirrors()
|
||||
return [m["id"] for m in enabled]
|
||||
|
||||
|
||||
class GitMirrorService:
|
||||
"""Git 镜像源服务"""
|
||||
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||
"""
|
||||
初始化 Git 镜像源服务
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
timeout: 请求超时时间(秒)
|
||||
config: 镜像源配置管理器(可选,默认创建新实例)
|
||||
"""
|
||||
self.max_retries = max_retries
|
||||
self.timeout = timeout
|
||||
self.config = config or GitMirrorConfig()
|
||||
logger.info(f"Git镜像源服务初始化完成,已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
|
||||
|
||||
def get_mirror_config(self) -> GitMirrorConfig:
|
||||
"""获取镜像源配置管理器"""
|
||||
return self.config
|
||||
|
||||
@staticmethod
|
||||
def check_git_installed() -> Dict[str, Any]:
|
||||
"""
|
||||
检查本机是否安装了 Git
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- installed: bool - 是否已安装 Git
|
||||
- version: str - Git 版本号(如果已安装)
|
||||
- path: str - Git 可执行文件路径(如果已安装)
|
||||
- error: str - 错误信息(如果未安装或检测失败)
|
||||
"""
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
try:
|
||||
# 查找 git 可执行文件路径
|
||||
git_path = shutil.which("git")
|
||||
|
||||
if not git_path:
|
||||
logger.warning("未找到 Git 可执行文件")
|
||||
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||
|
||||
# 获取 Git 版本
|
||||
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||
return {"installed": True, "version": version, "path": git_path}
|
||||
else:
|
||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Git 版本检测超时")
|
||||
return {"installed": False, "error": "Git 版本检测超时"}
|
||||
except Exception as e:
|
||||
logger.error(f"检测 Git 时发生错误: {e}")
|
||||
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||
|
||||
async def fetch_raw_file(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
branch: 分支名称
|
||||
file_path: 文件路径
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义完整 URL(如果提供,将忽略其他参数)
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
- data: str - 文件内容(成功时)
|
||||
- error: str - 错误信息(失败时)
|
||||
- mirror_used: str - 使用的镜像源
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._fetch_with_url(custom_url, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
total_mirrors = len(mirrors_to_try)
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for index, mirror in enumerate(mirrors_to_try, 1):
|
||||
# 推送进度:正在尝试第 N 个镜像源
|
||||
if _update_progress:
|
||||
try:
|
||||
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=progress,
|
||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||
|
||||
if result["success"]:
|
||||
# 成功,推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=70,
|
||||
message=f"成功从 {mirror['name']} 获取数据",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
return result
|
||||
|
||||
# 失败,记录日志并推送失败信息
|
||||
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
|
||||
|
||||
if _update_progress and index < total_mirrors:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=30 + int(index / total_mirrors * 40),
|
||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _fetch_raw_from_mirror(
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
raw_prefix = mirror["raw_prefix"]
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
return await self._fetch_with_url(url, mirror["id"])
|
||||
|
||||
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
|
||||
"""使用指定 URL 获取文件,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
try:
|
||||
logger.debug(f"尝试 #{attempt + 1}: {url}")
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
logger.info(f"成功获取文件: {url}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": response.text,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
}
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
except httpx.TimeoutException as e:
|
||||
last_error = f"请求超时: {e}"
|
||||
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
async def clone_repository(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str] = None,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
depth: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
克隆 GitHub 仓库
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
target_path: 目标路径
|
||||
branch: 分支名称(可选)
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义克隆 URL
|
||||
depth: 克隆深度(浅克隆)
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
- path: str - 克隆路径(成功时)
|
||||
- error: str - 错误信息(失败时)
|
||||
- mirror_used: str - 使用的镜像源
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for mirror in mirrors_to_try:
|
||||
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||
if result["success"]:
|
||||
return result
|
||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _clone_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
clone_prefix = mirror["clone_prefix"]
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
async def _clone_with_url(
|
||||
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用指定 URL 克隆仓库,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# 确保目标路径不存在
|
||||
if target_path.exists():
|
||||
logger.warning(f"目标路径已存在,删除: {target_path}")
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
# 构建 git clone 命令
|
||||
cmd = ["git", "clone"]
|
||||
|
||||
# 添加分支参数
|
||||
if branch:
|
||||
cmd.extend(["-b", branch])
|
||||
|
||||
# 添加深度参数(浅克隆)
|
||||
if depth:
|
||||
cmd.extend(["--depth", str(depth)])
|
||||
|
||||
# 添加 URL 和目标路径
|
||||
cmd.extend([url, str(target_path)])
|
||||
|
||||
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
|
||||
|
||||
# 推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=20 + attempt * 10,
|
||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||
operation="install",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def run_git_clone(clone_cmd=cmd):
|
||||
return subprocess.run(
|
||||
clone_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
process = await loop.run_in_executor(None, run_git_clone)
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"成功克隆仓库: {url} -> {target_path}")
|
||||
return {
|
||||
"success": True,
|
||||
"path": str(target_path),
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
"branch": branch or "default",
|
||||
}
|
||||
else:
|
||||
last_error = f"Git 克隆失败: {process.stderr}"
|
||||
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
last_error = "克隆超时(超过 5 分钟)"
|
||||
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
except FileNotFoundError:
|
||||
last_error = "Git 未安装或不在 PATH 中"
|
||||
logger.error(f"Git 未找到: {last_error}")
|
||||
break # Git 不存在,不需要重试
|
||||
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
_git_mirror_service: Optional[GitMirrorService] = None
|
||||
|
||||
|
||||
def get_git_mirror_service() -> GitMirrorService:
|
||||
"""获取 Git 镜像源服务实例(单例)"""
|
||||
global _git_mirror_service
|
||||
if _git_mirror_service is None:
|
||||
_git_mirror_service = GitMirrorService()
|
||||
return _git_mirror_service
|
||||
@@ -1,532 +0,0 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon, ChatStreams
|
||||
|
||||
logger = get_logger("webui.jargon")
|
||||
|
||||
router = APIRouter(prefix="/jargon", tags=["Jargon"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
解析 chat_id 字段,提取所有 stream_id
|
||||
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
parsed = json.loads(chat_id_str)
|
||||
if isinstance(parsed, list):
|
||||
# 格式: [["stream_id", user_id], ...]
|
||||
stream_ids = []
|
||||
for item in parsed:
|
||||
if isinstance(item, list) and len(item) >= 1:
|
||||
stream_ids.append(str(item[0]))
|
||||
return stream_ids
|
||||
else:
|
||||
# 其他格式,返回原始字符串
|
||||
return [chat_id_str]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 不是有效的 JSON,可能是直接的 stream_id
|
||||
return [chat_id_str]
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream and chat_stream.group_name:
|
||||
names.append(chat_stream.group_name)
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
|
||||
class JargonResponse(BaseModel):
|
||||
"""黑话信息响应"""
|
||||
|
||||
id: int
|
||||
content: str
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: str
|
||||
stream_id: Optional[str] = None # 解析后的 stream_id,用于前端编辑时匹配
|
||||
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
|
||||
is_global: bool = False
|
||||
count: int = 0
|
||||
is_jargon: Optional[bool] = None
|
||||
is_complete: bool = False
|
||||
inference_with_context: Optional[str] = None
|
||||
inference_content_only: Optional[str] = None
|
||||
|
||||
|
||||
class JargonListResponse(BaseModel):
|
||||
"""黑话列表响应"""
|
||||
|
||||
success: bool = True
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[JargonResponse]
|
||||
|
||||
|
||||
class JargonDetailResponse(BaseModel):
|
||||
"""黑话详情响应"""
|
||||
|
||||
success: bool = True
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonCreateRequest(BaseModel):
|
||||
"""黑话创建请求"""
|
||||
|
||||
content: str = Field(..., description="黑话内容")
|
||||
raw_content: Optional[str] = Field(None, description="原始内容")
|
||||
meaning: Optional[str] = Field(None, description="含义")
|
||||
chat_id: str = Field(..., description="聊天ID")
|
||||
is_global: bool = Field(False, description="是否全局")
|
||||
|
||||
|
||||
class JargonUpdateRequest(BaseModel):
|
||||
"""黑话更新请求"""
|
||||
|
||||
content: Optional[str] = None
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
is_global: Optional[bool] = None
|
||||
is_jargon: Optional[bool] = None
|
||||
|
||||
|
||||
class JargonCreateResponse(BaseModel):
|
||||
"""黑话创建响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonUpdateResponse(BaseModel):
|
||||
"""黑话更新响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: Optional[JargonResponse] = None
|
||||
|
||||
|
||||
class JargonDeleteResponse(BaseModel):
|
||||
"""黑话删除响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
deleted_count: int = 0
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int] = Field(..., description="要删除的黑话ID列表")
|
||||
|
||||
|
||||
class JargonStatsResponse(BaseModel):
|
||||
"""黑话统计响应"""
|
||||
|
||||
success: bool = True
|
||||
data: dict
|
||||
|
||||
|
||||
class ChatInfoResponse(BaseModel):
|
||||
"""聊天信息响应"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool = True
|
||||
data: List[ChatInfoResponse]
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
# 解析 chat_id 获取显示名称和 stream_id
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
"raw_content": jargon.raw_content,
|
||||
"meaning": jargon.meaning,
|
||||
"chat_id": jargon.chat_id,
|
||||
"stream_id": stream_id,
|
||||
"chat_name": chat_name,
|
||||
"is_global": jargon.is_global,
|
||||
"count": jargon.count,
|
||||
"is_jargon": jargon.is_jargon,
|
||||
"is_complete": jargon.is_complete,
|
||||
"inference_with_context": jargon.inference_with_context,
|
||||
"inference_content_only": jargon.inference_content_only,
|
||||
}
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
|
||||
@router.get("/list", response_model=JargonListResponse)
|
||||
async def get_jargon_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = Jargon.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Jargon.content.contains(search))
|
||||
| (Jargon.meaning.contains(search))
|
||||
| (Jargon.raw_content.contains(search))
|
||||
)
|
||||
|
||||
# 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式)
|
||||
if chat_id:
|
||||
# 从传入的 chat_id 中解析出 stream_id
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
# 使用第一个 stream_id 进行模糊匹配
|
||||
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
|
||||
else:
|
||||
# 如果无法解析,使用精确匹配
|
||||
query = query.where(Jargon.chat_id == chat_id)
|
||||
|
||||
# 按是否是黑话筛选
|
||||
if is_jargon is not None:
|
||||
query = query.where(Jargon.is_jargon == is_jargon)
|
||||
|
||||
# 按是否全局筛选
|
||||
if is_global is not None:
|
||||
query = query.where(Jargon.is_global == is_global)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页和排序(按使用次数降序)
|
||||
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
|
||||
query = query.paginate(page, page_size)
|
||||
|
||||
# 转换为响应格式
|
||||
data = [jargon_to_dict(j) for j in query]
|
||||
|
||||
return JargonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
seen_stream_ids.add(stream_ids[0])
|
||||
|
||||
result = []
|
||||
for stream_id in seen_stream_ids:
|
||||
# 尝试从 ChatStreams 表获取聊天名称
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id,方便筛选匹配
|
||||
chat_name=chat_stream.group_name or stream_id,
|
||||
platform=chat_stream.platform,
|
||||
is_group=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id
|
||||
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatListResponse(success=True, data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary", response_model=JargonStatsResponse)
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
try:
|
||||
# 总数量
|
||||
total = Jargon.select().count()
|
||||
|
||||
# 已确认是黑话的数量
|
||||
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
|
||||
|
||||
# 已确认不是黑话的数量
|
||||
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
|
||||
|
||||
# 未判定的数量
|
||||
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
|
||||
|
||||
# 全局黑话数量
|
||||
global_count = Jargon.select().where(Jargon.is_global).count()
|
||||
|
||||
# 已完成推断的数量
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
|
||||
.group_by(Jargon.chat_id)
|
||||
.order_by(fn.COUNT(Jargon.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
data={
|
||||
"total": total,
|
||||
"confirmed_jargon": confirmed_jargon,
|
||||
"confirmed_not_jargon": confirmed_not_jargon,
|
||||
"pending": pending,
|
||||
"global_count": global_count,
|
||||
"complete_count": complete_count,
|
||||
"chat_count": chat_count,
|
||||
"top_chats": top_chats_dict,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/", response_model=JargonCreateResponse)
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
# 创建黑话
|
||||
jargon = Jargon.create(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning,
|
||||
chat_id=request.chat_id,
|
||||
is_global=request.is_global,
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
)
|
||||
|
||||
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
|
||||
|
||||
return JargonCreateResponse(
|
||||
success=True,
|
||||
message="创建成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
# 增量更新字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
for field, value in update_data.items():
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
jargon.save()
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message="更新成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
content = jargon.content
|
||||
jargon.delete_instance()
|
||||
|
||||
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message="删除成功",
|
||||
deleted_count=1,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=JargonDeleteResponse)
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
"""批量删除黑话"""
|
||||
try:
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
|
||||
|
||||
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除 {deleted_count} 条黑话",
|
||||
deleted_count=deleted_count,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
|
||||
async def batch_set_jargon_status(
|
||||
ids: Annotated[List[int], Query(description="黑话ID列表")],
|
||||
is_jargon: Annotated[bool, Query(description="是否是黑话")],
|
||||
):
|
||||
"""批量设置黑话状态"""
|
||||
try:
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {updated_count} 条黑话状态",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新黑话状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e
|
||||
@@ -1,298 +0,0 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
id: str
|
||||
type: str # 'entity' or 'paragraph'
|
||||
content: str
|
||||
create_time: Optional[float] = None
|
||||
|
||||
|
||||
class KnowledgeEdge(BaseModel):
|
||||
"""知识边"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
weight: float
|
||||
create_time: Optional[float] = None
|
||||
update_time: Optional[float] = None
|
||||
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""知识图谱"""
|
||||
|
||||
nodes: List[KnowledgeNode]
|
||||
edges: List[KnowledgeEdge]
|
||||
|
||||
|
||||
class KnowledgeStats(BaseModel):
|
||||
"""知识库统计信息"""
|
||||
|
||||
total_nodes: int
|
||||
total_edges: int
|
||||
entity_nodes: int
|
||||
paragraph_nodes: int
|
||||
avg_connections: float
|
||||
|
||||
|
||||
def _load_kg_manager():
|
||||
"""延迟加载 KGManager"""
|
||||
try:
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
|
||||
kg_manager = KGManager()
|
||||
kg_manager.load_from_file()
|
||||
return kg_manager
|
||||
except Exception as e:
|
||||
logger.error(f"加载 KGManager 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
"""将 DiGraph 转换为 JSON 格式"""
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
graph = kg_manager.graph
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# 转换节点
|
||||
node_list = graph.get_node_list()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
# 转换边
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
try:
|
||||
# edge_tuple 是 (source, target) 元组
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 通过 graph[source, target] 获取边的属性数据
|
||||
edge_data = graph[source, target]
|
||||
|
||||
# edge_data 支持 [] 操作符但不支持 .get()
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
return KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
@router.get("/graph", response_model=KnowledgeGraph)
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
Args:
|
||||
limit: 返回的最大节点数,默认 100,最大 10000
|
||||
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None:
|
||||
logger.warning("KGManager 未初始化,返回空图谱")
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
graph = kg_manager.graph
|
||||
all_node_list = graph.get_node_list()
|
||||
|
||||
# 按类型过滤节点
|
||||
if node_type == "entity":
|
||||
all_node_list = [
|
||||
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||
]
|
||||
elif node_type == "paragraph":
|
||||
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||
|
||||
# 限制节点数量
|
||||
total_nodes = len(all_node_list)
|
||||
if len(all_node_list) > limit:
|
||||
node_list = all_node_list[:limit]
|
||||
else:
|
||||
node_list = all_node_list
|
||||
|
||||
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
||||
|
||||
# 转换节点
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
node_ids.add(node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
# 只获取涉及当前节点集的边(保证图的完整性)
|
||||
edges = []
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
try:
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 只包含两端都在当前节点集中的边
|
||||
if source not in node_ids or target not in node_ids:
|
||||
continue
|
||||
|
||||
edge_data = graph[source, target]
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
||||
return graph_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||
"""获取知识库统计信息
|
||||
|
||||
Returns:
|
||||
KnowledgeStats: 统计信息
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
edge_list = graph.get_edge_list()
|
||||
|
||||
total_nodes = len(node_list)
|
||||
total_edges = len(edge_list)
|
||||
|
||||
# 统计节点类型
|
||||
entity_nodes = 0
|
||||
paragraph_nodes = 0
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||
if node_type == "ent":
|
||||
entity_nodes += 1
|
||||
elif node_type == "pg":
|
||||
paragraph_nodes += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 计算平均连接数
|
||||
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
||||
|
||||
return KnowledgeStats(
|
||||
total_nodes=total_nodes,
|
||||
total_edges=total_edges,
|
||||
entity_nodes=entity_nodes,
|
||||
paragraph_nodes=paragraph_nodes,
|
||||
avg_connections=round(avg_connections, 2),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||
"""搜索知识节点
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
Returns:
|
||||
List[KnowledgeNode]: 匹配的节点列表
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return []
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
|
||||
# 在节点内容中搜索
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
||||
return results[:50] # 限制返回数量
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
||||
return []
|
||||
@@ -1,177 +0,0 @@
|
||||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
|
||||
# 全局 WebSocket 连接池
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
|
||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
"""从日志文件中加载最近的日志
|
||||
|
||||
Args:
|
||||
limit: 返回的最大日志条数
|
||||
|
||||
Returns:
|
||||
日志列表
|
||||
"""
|
||||
logs = []
|
||||
log_dir = Path("logs")
|
||||
|
||||
if not log_dir.exists():
|
||||
return logs
|
||||
|
||||
# 获取所有日志文件,按修改时间排序
|
||||
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
|
||||
# 用于生成唯一 ID 的计数器
|
||||
log_counter = 0
|
||||
|
||||
# 从最新的文件开始读取
|
||||
for log_file in log_files:
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# 从文件末尾开始读取
|
||||
for line in reversed(lines):
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
try:
|
||||
log_entry = json.loads(line.strip())
|
||||
# 转换为前端期望的格式
|
||||
# 使用时间戳 + 计数器生成唯一 ID
|
||||
timestamp_id = (
|
||||
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
)
|
||||
formatted_log = {
|
||||
"id": f"{timestamp_id}_{log_counter}",
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
"level": log_entry.get("level", "INFO").upper(),
|
||||
"module": log_entry.get("logger_name", ""),
|
||||
"message": log_entry.get("event", ""),
|
||||
}
|
||||
logs.append(formatted_log)
|
||||
log_counter += 1
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"读取日志文件失败 {log_file}: {e}")
|
||||
continue
|
||||
|
||||
# 反转列表,使其按时间顺序排列(旧到新)
|
||||
return list(reversed(logs))
|
||||
|
||||
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
recent_logs = load_recent_logs(limit=100)
|
||||
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
|
||||
|
||||
for log_entry in recent_logs:
|
||||
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"发送历史日志失败: {e}")
|
||||
|
||||
try:
|
||||
# 保持连接,等待客户端消息或断开
|
||||
while True:
|
||||
# 接收客户端消息(用于心跳或控制指令)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 可以处理客户端的控制消息,例如:
|
||||
# - "ping" -> 心跳检测
|
||||
# - {"filter": "ERROR"} -> 设置日志级别过滤
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
async def broadcast_log(log_data: dict):
|
||||
"""广播日志到所有连接的 WebSocket 客户端
|
||||
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||
@@ -1,383 +0,0 @@
|
||||
"""
|
||||
模型列表获取API路由
|
||||
|
||||
提供从各个 AI 厂商 API 获取可用模型列表的代理接口
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||
from typing import Optional
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# 模型获取器配置
|
||||
MODEL_FETCHER_CONFIG = {
|
||||
# OpenAI 兼容格式的提供商
|
||||
"openai": {
|
||||
"endpoint": "/models",
|
||||
"parser": "openai",
|
||||
},
|
||||
# Gemini 格式
|
||||
"gemini": {
|
||||
"endpoint": "/models",
|
||||
"parser": "gemini",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_url(url: str) -> str:
|
||||
"""规范化 URL(去掉尾部斜杠)"""
|
||||
if not url:
|
||||
return ""
|
||||
return url.rstrip("/")
|
||||
|
||||
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 OpenAI 格式的模型列表响应
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 Gemini 格式的模型列表响应
|
||||
|
||||
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "models" in data and isinstance(data["models"], list):
|
||||
for model in data["models"]:
|
||||
if isinstance(model, dict) and "name" in model:
|
||||
# Gemini 的 name 格式是 "models/gemini-pro",我们只取后面部分
|
||||
model_id = model["name"]
|
||||
if model_id.startswith("models/"):
|
||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
async def _fetch_models_from_provider(
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
parser: str,
|
||||
client_type: str = "openai",
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||
|
||||
# 根据客户端类型设置请求头
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
if client_type == "gemini":
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.TimeoutException as e:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||
if e.response.status_code == 401:
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||
elif e.response.status_code == 403:
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||
elif e.response.status_code == 404:
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||
|
||||
# 根据解析器类型解析响应
|
||||
if parser == "openai":
|
||||
return _parse_openai_response(data)
|
||||
elif parser == "gemini":
|
||||
return _parse_gemini_response(data)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
|
||||
|
||||
|
||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
从 model_config.toml 获取指定提供商的配置
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
Returns:
|
||||
提供商配置,如果未找到则返回 None
|
||||
"""
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def get_provider_models(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
if not provider_config:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
base_url = provider_config.get("base_url")
|
||||
api_key = provider_config.get("api_key")
|
||||
client_type = provider_config.get("client_type", "openai")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
"provider": provider_name,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/list-by-url")
|
||||
async def get_models_by_url(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: str = Query(..., description="API Key"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
"""
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-connection")
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
|
||||
分两步测试:
|
||||
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||
|
||||
返回:
|
||||
- network_ok: 网络是否连通
|
||||
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||
- latency_ms: 响应延迟(毫秒)
|
||||
- error: 错误信息(如果有)
|
||||
"""
|
||||
import time
|
||||
|
||||
base_url = _normalize_url(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||
|
||||
result = {
|
||||
"network_ok": False,
|
||||
"api_key_valid": None,
|
||||
"latency_ms": None,
|
||||
"error": None,
|
||||
"http_status": None,
|
||||
}
|
||||
|
||||
# 第一步:测试网络连通性
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||
response = await client.get(base_url)
|
||||
latency = (time.time() - start_time) * 1000
|
||||
|
||||
result["network_ok"] = True
|
||||
result["latency_ms"] = round(latency, 2)
|
||||
result["http_status"] = response.status_code
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||
return result
|
||||
except httpx.TimeoutException:
|
||||
result["error"] = "连接超时:服务器响应时间过长"
|
||||
return result
|
||||
except httpx.RequestError as e:
|
||||
result["error"] = f"请求错误:{str(e)}"
|
||||
return result
|
||||
except Exception as e:
|
||||
result["error"] = f"未知错误:{str(e)}"
|
||||
return result
|
||||
|
||||
# 第二步:如果提供了 API Key,验证其有效性
|
||||
if api_key:
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
# 尝试获取模型列表
|
||||
models_url = f"{base_url}/models"
|
||||
response = await client.get(models_url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result["api_key_valid"] = True
|
||||
elif response.status_code in (401, 403):
|
||||
result["api_key_valid"] = False
|
||||
result["error"] = "API Key 无效或已过期"
|
||||
else:
|
||||
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||
result["api_key_valid"] = None
|
||||
|
||||
except Exception as e:
|
||||
# API Key 验证失败不影响网络连通性结果
|
||||
logger.warning(f"API Key 验证失败: {e}")
|
||||
result["api_key_valid"] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
"""
|
||||
# 读取配置文件
|
||||
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(model_config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
for p in providers:
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
base_url = provider.get("base_url", "")
|
||||
api_key = provider.get("api_key", "")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
@@ -1,416 +0,0 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import json
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.person")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/person", tags=["Person"])
|
||||
|
||||
|
||||
class PersonInfoResponse(BaseModel):
|
||||
"""人物信息响应"""
|
||||
|
||||
id: int
|
||||
is_known: bool
|
||||
person_id: str
|
||||
person_name: Optional[str]
|
||||
name_reason: Optional[str]
|
||||
platform: str
|
||||
user_id: str
|
||||
nickname: Optional[str]
|
||||
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
|
||||
memory_points: Optional[str]
|
||||
know_times: Optional[float]
|
||||
know_since: Optional[float]
|
||||
last_know: Optional[float]
|
||||
|
||||
|
||||
class PersonListResponse(BaseModel):
|
||||
"""人物列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[PersonInfoResponse]
|
||||
|
||||
|
||||
class PersonDetailResponse(BaseModel):
|
||||
"""人物详情响应"""
|
||||
|
||||
success: bool
|
||||
data: PersonInfoResponse
|
||||
|
||||
|
||||
class PersonUpdateRequest(BaseModel):
|
||||
"""人物信息更新请求"""
|
||||
|
||||
person_name: Optional[str] = None
|
||||
name_reason: Optional[str] = None
|
||||
nickname: Optional[str] = None
|
||||
memory_points: Optional[str] = None
|
||||
is_known: Optional[bool] = None
|
||||
|
||||
|
||||
class PersonUpdateResponse(BaseModel):
|
||||
"""人物信息更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[PersonInfoResponse] = None
|
||||
|
||||
|
||||
class PersonDeleteResponse(BaseModel):
|
||||
"""人物删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
person_ids: List[str]
|
||||
|
||||
|
||||
class BatchDeleteResponse(BaseModel):
|
||||
"""批量删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
deleted_count: int
|
||||
failed_count: int
|
||||
failed_ids: List[str] = []
|
||||
|
||||
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
|
||||
"""解析群昵称 JSON 字符串"""
|
||||
if not group_nick_name_str:
|
||||
return None
|
||||
try:
|
||||
return json.loads(group_nick_name_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
|
||||
"""将 PersonInfo 模型转换为响应对象"""
|
||||
return PersonInfoResponse(
|
||||
id=person.id,
|
||||
is_known=person.is_known,
|
||||
person_id=person.person_id,
|
||||
person_name=person.person_name,
|
||||
name_reason=person.name_reason,
|
||||
platform=person.platform,
|
||||
user_id=person.user_id,
|
||||
nickname=person.nickname,
|
||||
group_nick_name=parse_group_nick_name(person.group_nick_name),
|
||||
memory_points=person.memory_points,
|
||||
know_times=person.know_times,
|
||||
know_since=person.know_since,
|
||||
last_know=person.last_know,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", response_model=PersonListResponse)
|
||||
async def get_person_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 person_name, nickname, user_id)
|
||||
is_known: 是否已认识筛选
|
||||
platform: 平台筛选
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = PersonInfo.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 已认识状态过滤
|
||||
if is_known is not None:
|
||||
query = query.where(PersonInfo.is_known == is_known)
|
||||
|
||||
# 平台过滤
|
||||
if platform:
|
||||
query = query.where(PersonInfo.platform == platform)
|
||||
|
||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
persons = query.offset(offset).limit(page_size)
|
||||
|
||||
# 转换为响应对象
|
||||
data = [person_to_response(person) for person in persons]
|
||||
|
||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取人物列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取人物列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取人物详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取人物详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 更新最后修改时间
|
||||
update_data["last_know"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(person, field, value)
|
||||
|
||||
person.save()
|
||||
|
||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return PersonUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"更新人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新人物信息失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
# 记录删除信息
|
||||
person_name = person.person_name or person.nickname or person.user_id
|
||||
|
||||
# 执行删除
|
||||
person.delete_instance()
|
||||
|
||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||
|
||||
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"删除人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除人物信息失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = PersonInfo.select().count()
|
||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
unknown = total - known
|
||||
|
||||
# 按平台统计
|
||||
platforms = {}
|
||||
for person in PersonInfo.select(PersonInfo.platform):
|
||||
platform = person.platform
|
||||
platforms[platform] = platforms.get(platform, 0) + 1
|
||||
|
||||
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取统计数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
Args:
|
||||
request: 包含person_ids列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.person_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||
|
||||
deleted_count = 0
|
||||
failed_count = 0
|
||||
failed_ids = []
|
||||
|
||||
for person_id in request.person_ids:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
person.delete_instance()
|
||||
deleted_count += 1
|
||||
logger.info(f"批量删除: {person_id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
failed_ids.append(person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除 {person_id} 失败: {e}")
|
||||
failed_count += 1
|
||||
failed_ids.append(person_id)
|
||||
|
||||
message = f"成功删除 {deleted_count} 个人物"
|
||||
if failed_count > 0:
|
||||
message += f",{failed_count} 个失败"
|
||||
|
||||
return BatchDeleteResponse(
|
||||
success=True,
|
||||
message=message,
|
||||
deleted_count=deleted_count,
|
||||
failed_count=failed_count,
|
||||
failed_ids=failed_ids,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||
@@ -1,164 +0,0 @@
|
||||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Dict, Any, Optional
|
||||
import json
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
# 全局 WebSocket 连接池
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
# 当前加载进度状态
|
||||
current_progress: Dict[str, Any] = {
|
||||
"operation": "idle", # idle, fetch, install, uninstall, update
|
||||
"stage": "idle", # idle, loading, success, error
|
||||
"progress": 0, # 0-100
|
||||
"message": "",
|
||||
"error": None,
|
||||
"plugin_id": None, # 当前操作的插件 ID
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": 0,
|
||||
}
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
"""广播进度更新到所有连接的客户端"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
# 移除断开的连接
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
async def update_progress(
|
||||
stage: str,
|
||||
progress: int,
|
||||
message: str,
|
||||
operation: str = "fetch",
|
||||
error: str = None,
|
||||
plugin_id: str = None,
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0,
|
||||
):
|
||||
"""更新并广播进度
|
||||
|
||||
Args:
|
||||
stage: 阶段 (idle, loading, success, error)
|
||||
progress: 进度百分比 (0-100)
|
||||
message: 当前消息
|
||||
operation: 操作类型 (fetch, install, uninstall, update)
|
||||
error: 错误信息(可选)
|
||||
plugin_id: 当前操作的插件 ID
|
||||
total_plugins: 总插件数
|
||||
loaded_plugins: 已加载插件数
|
||||
"""
|
||||
progress_data = {
|
||||
"operation": operation,
|
||||
"stage": stage,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
"error": error,
|
||||
"plugin_id": plugin_id,
|
||||
"total_plugins": total_plugins,
|
||||
"loaded_plugins": loaded_plugins,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
await broadcast_progress(progress_data)
|
||||
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
|
||||
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
|
||||
|
||||
# 保持连接并处理客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 处理客户端心跳
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
def get_progress_router() -> APIRouter:
|
||||
"""获取插件进度 WebSocket 路由器"""
|
||||
return router
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,245 +0,0 @@
|
||||
"""
|
||||
WebUI 请求频率限制模块
|
||||
防止暴力破解和 API 滥用
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.rate_limiter")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
return False, None
|
||||
|
||||
def check_rate_limit(
|
||||
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
def record_failed_attempt(
|
||||
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
if key in self._requests:
|
||||
del self._requests[key]
|
||||
|
||||
|
||||
# 全局单例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取 RateLimiter 单例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
@@ -1,113 +0,0 @@
|
||||
"""
|
||||
系统控制路由
|
||||
|
||||
提供系统重启、状态查询等功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
logger = get_logger("webui_system")
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""状态响应"""
|
||||
|
||||
running: bool
|
||||
uptime: float
|
||||
version: str
|
||||
start_time: str
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
请求重启当前进程,配置更改将在重启后生效。
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# 记录重启操作
|
||||
logger.info("WebUI 触发重启操作")
|
||||
|
||||
# 定义延迟重启的异步任务
|
||||
async def delayed_restart():
|
||||
await asyncio.sleep(0.5) # 延迟0.5秒,确保响应已发送
|
||||
# 使用 os._exit(42) 退出当前进程,配合外部 runner 脚本进行重启
|
||||
# 42 是约定的重启状态码
|
||||
logger.info("WebUI 请求重启,退出代码 42")
|
||||
os._exit(42)
|
||||
|
||||
# 创建后台任务执行重启
|
||||
asyncio.create_task(delayed_restart())
|
||||
|
||||
# 立即返回成功响应
|
||||
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
返回麦麦的运行状态、运行时长和版本信息。
|
||||
"""
|
||||
try:
|
||||
uptime = time.time() - _start_time
|
||||
|
||||
# 尝试获取版本信息(需要根据实际情况调整)
|
||||
version = MMC_VERSION # 可以从配置或常量中读取
|
||||
|
||||
return StatusResponse(
|
||||
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
仅重新加载配置文件,某些配置可能需要重启才能生效。
|
||||
此功能需要在主程序中实现配置热重载逻辑。
|
||||
"""
|
||||
# 这里需要调用主程序的配置重载函数
|
||||
# 示例:await app_instance.reload_config()
|
||||
|
||||
return {"success": True, "message": "配置重载功能待实现"}
|
||||
@@ -1,456 +0,0 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import set_auth_cookie, clear_auth_cookie
|
||||
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
|
||||
from .config_routes import router as config_router
|
||||
from .statistics_routes import router as statistics_router
|
||||
from .person_routes import router as person_router
|
||||
from .expression_routes import router as expression_router
|
||||
from .jargon_routes import router as jargon_router
|
||||
from .emoji_routes import router as emoji_router
|
||||
from .plugin_routes import router as plugin_router
|
||||
from .plugin_progress_ws import get_progress_router
|
||||
from .routers.system import router as system_router
|
||||
from .model_routes import router as model_router
|
||||
from .ws_auth import router as ws_auth_router
|
||||
from .annual_report_routes import router as annual_report_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/api/webui", tags=["WebUI"])
|
||||
|
||||
# 注册配置管理路由
|
||||
router.include_router(config_router)
|
||||
# 注册统计数据路由
|
||||
router.include_router(statistics_router)
|
||||
# 注册人物信息管理路由
|
||||
router.include_router(person_router)
|
||||
# 注册表达方式管理路由
|
||||
router.include_router(expression_router)
|
||||
# 注册黑话管理路由
|
||||
router.include_router(jargon_router)
|
||||
# 注册表情包管理路由
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
router.include_router(plugin_router)
|
||||
# 注册插件进度 WebSocket 路由
|
||||
router.include_router(get_progress_router())
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
# 注册年度报告路由
|
||||
router.include_router(annual_report_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
"""Token 验证请求"""
|
||||
|
||||
token: str = Field(..., description="访问令牌")
|
||||
|
||||
|
||||
class TokenVerifyResponse(BaseModel):
|
||||
"""Token 验证响应"""
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
is_first_setup: bool = Field(False, description="是否为首次设置")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
"""Token 更新请求"""
|
||||
|
||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||
|
||||
|
||||
class TokenUpdateResponse(BaseModel):
|
||||
"""Token 更新响应"""
|
||||
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: str = Field(..., description="更新结果消息")
|
||||
|
||||
|
||||
class TokenRegenerateResponse(BaseModel):
|
||||
"""Token 重新生成响应"""
|
||||
|
||||
success: bool = Field(..., description="是否生成成功")
|
||||
token: str = Field(..., description="新生成的令牌")
|
||||
message: str = Field(..., description="生成结果消息")
|
||||
|
||||
|
||||
class FirstSetupStatusResponse(BaseModel):
|
||||
"""首次配置状态响应"""
|
||||
|
||||
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||
message: str = Field(..., description="状态消息")
|
||||
|
||||
|
||||
class CompleteSetupResponse(BaseModel):
|
||||
"""完成配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
class ResetSetupResponse(BaseModel):
|
||||
"""重置配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "service": "MaiBot WebUI"}
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(
|
||||
request_body: TokenVerifyRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||
):
|
||||
"""
|
||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
request_body: 包含 token 的验证请求
|
||||
request: FastAPI Request 对象(用于获取客户端 IP)
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
验证结果(包含首次配置状态)
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
rate_limiter = get_rate_limiter()
|
||||
|
||||
is_valid = token_manager.verify_token(request_body.token)
|
||||
|
||||
if is_valid:
|
||||
# 认证成功,重置失败计数
|
||||
rate_limiter.reset_failures(request)
|
||||
# 设置 HttpOnly Cookie(传入 request 以检测协议)
|
||||
set_auth_cookie(response, request_body.token, request)
|
||||
# 同时返回首次配置状态,避免额外请求
|
||||
is_first_setup = token_manager.is_first_setup()
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||
else:
|
||||
# 记录失败尝试
|
||||
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||
request,
|
||||
max_failures=5, # 5 次失败
|
||||
window_seconds=300, # 5 分钟窗口
|
||||
block_duration=600, # 封禁 10 分钟
|
||||
)
|
||||
|
||||
if blocked:
|
||||
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
|
||||
|
||||
message = "Token 无效或已过期"
|
||||
if remaining <= 2:
|
||||
message += f"(剩余 {remaining} 次尝试机会)"
|
||||
|
||||
return TokenVerifyResponse(valid=False, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(response: Response):
|
||||
"""
|
||||
登出并清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
登出结果
|
||||
"""
|
||||
clear_auth_cookie(response)
|
||||
return {"success": True, "message": "已成功登出"}
|
||||
|
||||
|
||||
@router.get("/auth/check")
|
||||
async def check_auth_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
|
||||
Returns:
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
# 记录请求信息用于调试
|
||||
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
logger.debug("使用 Cookie 中的 token")
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
logger.debug("使用 Header 中的 token")
|
||||
|
||||
if not token:
|
||||
logger.debug("未找到 token,返回未认证")
|
||||
return {"authenticated": False}
|
||||
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(token)
|
||||
logger.debug(f"Token 验证结果: {is_valid}")
|
||||
|
||||
if is_valid:
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False}
|
||||
except Exception as e:
|
||||
logger.error(f"认证检查失败: {e}", exc_info=True)
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
response: Response,
|
||||
req: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
# 如果更新成功,清除 Cookie,要求用户重新登录
|
||||
if success:
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 更新失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 更新失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
|
||||
async def regenerate_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
# 清除 Cookie,要求用户重新登录
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 重新生成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
||||
|
||||
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||
async def get_setup_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 检查是否为首次配置
|
||||
is_first = token_manager.is_first_setup()
|
||||
|
||||
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||
async def complete_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 标记配置完成
|
||||
success = token_manager.mark_setup_completed()
|
||||
|
||||
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"标记配置完成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||
async def reset_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 重置配置状态
|
||||
success = token_manager.reset_setup_status()
|
||||
|
||||
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置配置状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="重置配置状态失败") from e
|
||||
@@ -1,319 +0,0 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.statistics")
|
||||
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
total_requests: int = Field(0, description="总请求数")
|
||||
total_cost: float = Field(0.0, description="总花费")
|
||||
total_tokens: int = Field(0, description="总token数")
|
||||
online_time: float = Field(0.0, description="在线时间(秒)")
|
||||
total_messages: int = Field(0, description="总消息数")
|
||||
total_replies: int = Field(0, description="总回复数")
|
||||
avg_response_time: float = Field(0.0, description="平均响应时间")
|
||||
cost_per_hour: float = Field(0.0, description="每小时花费")
|
||||
tokens_per_hour: float = Field(0.0, description="每小时token数")
|
||||
|
||||
|
||||
class ModelStatistics(BaseModel):
|
||||
"""模型统计"""
|
||||
|
||||
model_name: str
|
||||
request_count: int
|
||||
total_cost: float
|
||||
total_tokens: int
|
||||
avg_response_time: float
|
||||
|
||||
|
||||
class TimeSeriesData(BaseModel):
|
||||
"""时间序列数据"""
|
||||
|
||||
timestamp: str
|
||||
requests: int = 0
|
||||
cost: float = 0.0
|
||||
tokens: int = 0
|
||||
|
||||
|
||||
class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
daily_data: List[TimeSeriesData]
|
||||
recent_activity: List[Dict[str, Any]]
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时),默认24小时
|
||||
|
||||
Returns:
|
||||
仪表盘数据
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
|
||||
# 获取摘要数据
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
|
||||
# 获取模型统计
|
||||
model_stats = await _get_model_statistics(start_time)
|
||||
|
||||
# 获取小时级时间序列数据
|
||||
hourly_data = await _get_hourly_statistics(start_time, now)
|
||||
|
||||
# 获取日级时间序列数据(最近7天)
|
||||
daily_start = now - timedelta(days=7)
|
||||
daily_data = await _get_daily_statistics(daily_start, now)
|
||||
|
||||
# 获取最近活动
|
||||
recent_activity = await _get_recent_activity(limit=10)
|
||||
|
||||
return DashboardData(
|
||||
summary=summary,
|
||||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
||||
summary = StatisticsSummary()
|
||||
|
||||
# 使用聚合查询替代全量加载
|
||||
query = LLMUsage.select(
|
||||
fn.COUNT(LLMUsage.id).alias("total_requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
|
||||
result = query.dicts().get()
|
||||
summary.total_requests = result["total_requests"]
|
||||
summary.total_cost = result["total_cost"]
|
||||
summary.total_tokens = result["total_tokens"]
|
||||
summary.avg_response_time = result["avg_response_time"] or 0.0
|
||||
|
||||
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
||||
online_records = list(
|
||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||
)
|
||||
|
||||
for record in online_records:
|
||||
start = max(record.start_timestamp, start_time)
|
||||
end = min(record.end_timestamp, end_time)
|
||||
if end > start:
|
||||
summary.online_time += (end - start).total_seconds()
|
||||
|
||||
# 查询消息数量 - 使用聚合优化
|
||||
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
|
||||
)
|
||||
summary.total_messages = messages_query.scalar() or 0
|
||||
|
||||
# 统计回复数量
|
||||
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||
(Messages.time >= start_time.timestamp())
|
||||
& (Messages.time <= end_time.timestamp())
|
||||
& (Messages.reply_to.is_null(False))
|
||||
)
|
||||
summary.total_replies = replies_query.scalar() or 0
|
||||
|
||||
# 计算派生指标
|
||||
if summary.online_time > 0:
|
||||
online_hours = summary.online_time / 3600.0
|
||||
summary.cost_per_hour = summary.total_cost / online_hours
|
||||
summary.tokens_per_hour = summary.total_tokens / online_hours
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||
# 使用GROUP BY聚合,避免全量加载
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
|
||||
fn.COUNT(LLMUsage.id).alias("request_count"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||
)
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(10) # 只取前10个
|
||||
)
|
||||
|
||||
result = []
|
||||
for row in query.dicts():
|
||||
result.append(
|
||||
ModelStatistics(
|
||||
model_name=row["model_name"],
|
||||
request_count=row["request_count"],
|
||||
total_cost=row["total_cost"],
|
||||
total_tokens=row["total_tokens"],
|
||||
avg_response_time=row["avg_response_time"] or 0.0,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||
# SQLite的日期时间函数进行小时分组
|
||||
# 使用strftime将timestamp格式化为小时级别
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
|
||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
|
||||
)
|
||||
|
||||
# 转换为字典以快速查找
|
||||
data_dict = {row["hour"]: row for row in query.dicts()}
|
||||
|
||||
# 填充所有小时(包括没有数据的)
|
||||
result = []
|
||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
|
||||
if hour_str in data_dict:
|
||||
row = data_dict[hour_str]
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(hours=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||
# 使用strftime按日期分组
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
|
||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
|
||||
)
|
||||
|
||||
# 转换为字典
|
||||
data_dict = {row["day"]: row for row in query.dicts()}
|
||||
|
||||
# 填充所有天
|
||||
result = []
|
||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
day_str = current.strftime("%Y-%m-%dT00:00:00")
|
||||
if day_str in data_dict:
|
||||
row = data_dict[day_str]
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(days=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||
|
||||
activities = []
|
||||
for record in records:
|
||||
activities.append(
|
||||
{
|
||||
"timestamp": record.timestamp.isoformat(),
|
||||
"model": record.model_assign_name or record.model_name,
|
||||
"request_type": record.request_type,
|
||||
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
"cost": record.cost or 0.0,
|
||||
"time_cost": record.time_cost or 0.0,
|
||||
"status": record.status,
|
||||
}
|
||||
)
|
||||
|
||||
return activities
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计摘要失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
stats = await _get_model_statistics(start_time)
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
@@ -1,309 +0,0 @@
|
||||
"""
|
||||
WebUI Token 管理模块
|
||||
负责生成、保存、验证和更新访问令牌
|
||||
"""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Token 管理器"""
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
|
||||
"""
|
||||
if config_path is None:
|
||||
# 获取项目根目录 (src/webui -> src -> 根目录)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
config_path = project_root / "data" / "webui.json"
|
||||
|
||||
self.config_path = config_path
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 确保配置文件存在并包含有效的 token
|
||||
self._ensure_config()
|
||||
|
||||
def _ensure_config(self):
|
||||
"""确保配置文件存在且包含有效的 token"""
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"WebUI 配置文件不存在,正在创建: {self.config_path}")
|
||||
self._create_new_token()
|
||||
else:
|
||||
# 验证配置文件格式
|
||||
try:
|
||||
config = self._load_config()
|
||||
if not config.get("access_token"):
|
||||
logger.warning("WebUI 配置文件中缺少 access_token,正在重新生成")
|
||||
self._create_new_token()
|
||||
else:
|
||||
logger.info(f"WebUI Token 已加载: {config['access_token'][:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
||||
self._create_new_token()
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WebUI 配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_config(self, config: dict):
|
||||
"""保存配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"WebUI 配置已保存到: {self.config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WebUI 配置失败: {e}")
|
||||
raise
|
||||
|
||||
def _create_new_token(self) -> str:
|
||||
"""生成新的 64 位随机 token"""
|
||||
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
config = {
|
||||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp(),
|
||||
"first_setup_completed": False, # 标记首次配置未完成
|
||||
}
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
|
||||
|
||||
return token
|
||||
|
||||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""获取当前有效的 token"""
|
||||
config = self._load_config()
|
||||
return config.get("access_token", "")
|
||||
|
||||
def verify_token(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 是否有效
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: token 是否有效
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
current_token = self.get_token()
|
||||
if not current_token:
|
||||
logger.error("系统中没有有效的 token")
|
||||
return False
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
is_valid = secrets.compare_digest(token, current_token)
|
||||
|
||||
if is_valid:
|
||||
logger.debug("Token 验证成功")
|
||||
else:
|
||||
logger.warning("Token 验证失败")
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
Args:
|
||||
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否更新成功, 错误消息)
|
||||
"""
|
||||
# 验证新 token 格式
|
||||
is_valid, error_msg = self._validate_custom_token(new_token)
|
||||
if not is_valid:
|
||||
logger.error(f"Token 格式无效: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8]
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return True, "Token 更新成功"
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Token 失败: {e}")
|
||||
return False, f"更新失败: {str(e)}"
|
||||
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token(保留 first_setup_completed 状态)
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
|
||||
# 生成新的 64 位十六进制字符串
|
||||
new_token = secrets.token_hex(32)
|
||||
|
||||
# 加载现有配置,保留 first_setup_completed 状态
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无"
|
||||
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return new_token
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token)
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: 格式是否正确
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
# 必须是 64 位十六进制字符串
|
||||
if len(token) != 64:
|
||||
return False
|
||||
|
||||
# 验证是否为有效的十六进制字符串
|
||||
try:
|
||||
int(token, 16)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
要求:
|
||||
- 最少 10 位
|
||||
- 包含大写字母
|
||||
- 包含小写字母
|
||||
- 包含特殊符号
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否有效, 错误消息)
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False, "Token 不能为空"
|
||||
|
||||
# 检查长度
|
||||
if len(token) < 10:
|
||||
return False, "Token 长度至少为 10 位"
|
||||
|
||||
# 检查是否包含大写字母
|
||||
has_upper = any(c.isupper() for c in token)
|
||||
if not has_upper:
|
||||
return False, "Token 必须包含大写字母"
|
||||
|
||||
# 检查是否包含小写字母
|
||||
has_lower = any(c.islower() for c in token)
|
||||
if not has_lower:
|
||||
return False, "Token 必须包含小写字母"
|
||||
|
||||
# 检查是否包含特殊符号
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
|
||||
has_special = any(c in special_chars for c in token)
|
||||
if not has_special:
|
||||
return False, f"Token 必须包含特殊符号 ({special_chars})"
|
||||
|
||||
return True, "Token 格式正确"
|
||||
|
||||
def is_first_setup(self) -> bool:
|
||||
"""
|
||||
检查是否为首次配置
|
||||
|
||||
Returns:
|
||||
bool: 是否为首次配置
|
||||
"""
|
||||
config = self._load_config()
|
||||
return not config.get("first_setup_completed", False)
|
||||
|
||||
def mark_setup_completed(self) -> bool:
|
||||
"""
|
||||
标记首次配置已完成
|
||||
|
||||
Returns:
|
||||
bool: 是否标记成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = True
|
||||
config["setup_completed_at"] = self._get_current_timestamp()
|
||||
self._save_config(config)
|
||||
logger.info("首次配置已标记为完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"标记首次配置完成失败: {e}")
|
||||
return False
|
||||
|
||||
def reset_setup_status(self) -> bool:
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Returns:
|
||||
bool: 是否重置成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = False
|
||||
if "setup_completed_at" in config:
|
||||
del config["setup_completed_at"]
|
||||
self._save_config(config)
|
||||
logger.info("首次配置状态已重置")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"重置首次配置状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_token_manager_instance: Optional[TokenManager] = None
|
||||
|
||||
|
||||
def get_token_manager() -> TokenManager:
|
||||
"""获取 TokenManager 单例"""
|
||||
global _token_manager_instance
|
||||
if _token_manager_instance is None:
|
||||
_token_manager_instance = TokenManager()
|
||||
return _token_manager_instance
|
||||
@@ -1,295 +0,0 @@
|
||||
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui_server")
|
||||
|
||||
|
||||
class WebUIServer:
|
||||
"""独立的 WebUI 服务器"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
# 配置防爬虫中间件(需要在CORS之前注册)
|
||||
self._setup_anti_crawler()
|
||||
|
||||
# 配置 CORS(支持开发环境跨域请求)
|
||||
self._setup_cors()
|
||||
|
||||
# 显示 Access Token
|
||||
self._show_access_token()
|
||||
|
||||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
# 注册robots.txt路由
|
||||
self._setup_robots_txt()
|
||||
|
||||
def _setup_cors(self):
|
||||
"""配置 CORS 中间件"""
|
||||
# 开发环境需要允许前端开发服务器的跨域请求
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:7999", # 前端开发服务器备用端口
|
||||
"http://127.0.0.1:7999",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"X-Requested-With",
|
||||
], # 明确指定允许的头
|
||||
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
|
||||
)
|
||||
logger.debug("✅ CORS 中间件已配置")
|
||||
|
||||
def _show_access_token(self):
|
||||
"""显示 WebUI Access Token"""
|
||||
try:
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取 Access Token 失败: {e}")
|
||||
|
||||
def _setup_static_files(self):
|
||||
"""设置静态文件服务"""
|
||||
# 确保正确的 MIME 类型映射
|
||||
mimetypes.init()
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
static_path = base_dir / "webui" / "dist"
|
||||
|
||||
if not static_path.exists():
|
||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||
return
|
||||
|
||||
if not (static_path / "index.html").exists():
|
||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||
logger.warning("💡 请确认前端已正确构建")
|
||||
return
|
||||
|
||||
# 处理 SPA 路由 - 注意:这个路由优先级最低
|
||||
@self.app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用 - 只处理非 API 请求"""
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
response = FileResponse(file_path, media_type=media_type)
|
||||
# HTML 文件添加防索引头
|
||||
if str(file_path).endswith(".html"):
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 其他路径返回 index.html(SPA 路由)
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||
|
||||
def _setup_anti_crawler(self):
|
||||
"""配置防爬虫中间件"""
|
||||
try:
|
||||
from src.webui.anti_crawler import AntiCrawlerMiddleware
|
||||
from src.config.config import global_config
|
||||
|
||||
# 从配置读取防爬虫模式
|
||||
anti_crawler_mode = global_config.webui.anti_crawler_mode
|
||||
|
||||
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
|
||||
# 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行
|
||||
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||
|
||||
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
|
||||
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
|
||||
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
|
||||
|
||||
def _setup_robots_txt(self):
|
||||
"""设置robots.txt路由"""
|
||||
try:
|
||||
from src.webui.anti_crawler import create_robots_txt_response
|
||||
|
||||
@self.app.get("/robots.txt", include_in_schema=False)
|
||||
async def robots_txt():
|
||||
"""返回robots.txt,禁止所有爬虫"""
|
||||
return create_robots_txt_response()
|
||||
|
||||
logger.debug("✅ robots.txt 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
|
||||
|
||||
def _register_api_routes(self):
|
||||
"""注册所有 WebUI API 路由"""
|
||||
try:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
# 导入规划器监控路由
|
||||
from src.webui.api.planner import router as planner_router
|
||||
|
||||
# 导入回复器监控路由
|
||||
from src.webui.api.replier import router as replier_router
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
self.app.include_router(knowledge_router)
|
||||
self.app.include_router(chat_router)
|
||||
self.app.include_router(planner_router)
|
||||
self.app.include_router(replier_router)
|
||||
|
||||
logger.info("✅ WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True)
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
# 预先检查端口是否可用
|
||||
if not self._check_port_available():
|
||||
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
|
||||
logger.error(error_msg)
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
||||
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
|
||||
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
|
||||
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
|
||||
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
)
|
||||
self._server = UvicornServer(config=config)
|
||||
|
||||
logger.info("🌐 WebUI 服务器启动中...")
|
||||
|
||||
# 根据地址类型显示正确的访问地址
|
||||
if ':' in self.host:
|
||||
# IPv6 地址需要用方括号包裹
|
||||
logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
|
||||
if self.host == "::":
|
||||
logger.info(f"💡 IPv6 本机访问: http://[::1]:{self.port}")
|
||||
logger.info(f"💡 IPv4 本机访问: http://127.0.0.1:{self.port}")
|
||||
elif self.host == "::1":
|
||||
logger.info("💡 仅支持 IPv6 本地访问")
|
||||
else:
|
||||
# IPv4 地址
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
if self.host == "0.0.0.0":
|
||||
logger.info(f"💡 本机访问: http://localhost:{self.port} 或 http://127.0.0.1:{self.port}")
|
||||
|
||||
try:
|
||||
await self._server.serve()
|
||||
except OSError as e:
|
||||
# 处理端口绑定相关的错误
|
||||
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
|
||||
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
||||
else:
|
||||
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _check_port_available(self) -> bool:
|
||||
"""检查端口是否可用(支持 IPv4 和 IPv6)"""
|
||||
import socket
|
||||
|
||||
# 判断使用 IPv4 还是 IPv6
|
||||
if ':' in self.host:
|
||||
# IPv6 地址
|
||||
family = socket.AF_INET6
|
||||
test_host = self.host if self.host != "::" else "::1"
|
||||
else:
|
||||
# IPv4 地址
|
||||
family = socket.AF_INET
|
||||
test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
|
||||
|
||||
try:
|
||||
with socket.socket(family, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
# 尝试绑定端口
|
||||
s.bind((test_host, self.port))
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
if self._server:
|
||||
logger.info("正在关闭 WebUI 服务器...")
|
||||
self._server.should_exit = True
|
||||
try:
|
||||
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
|
||||
logger.info("✅ WebUI 服务器已关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ WebUI 服务器关闭超时")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器关闭失败: {e}")
|
||||
finally:
|
||||
self._server = None
|
||||
|
||||
|
||||
# 全局 WebUI 服务器实例
|
||||
_webui_server = None
|
||||
|
||||
|
||||
def get_webui_server() -> WebUIServer:
|
||||
"""获取全局 WebUI 服务器实例"""
|
||||
global _webui_server
|
||||
if _webui_server is None:
|
||||
# 从环境变量读取
|
||||
import os
|
||||
host = os.getenv("WEBUI_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("WEBUI_PORT", "8001"))
|
||||
_webui_server = WebUIServer(host=host, port=port)
|
||||
return _webui_server
|
||||
@@ -1,114 +0,0 @@
|
||||
"""WebSocket 认证模块
|
||||
|
||||
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.ws_auth")
|
||||
router = APIRouter()
|
||||
|
||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||
|
||||
|
||||
def _cleanup_expired_ws_tokens():
|
||||
"""清理过期的临时 token"""
|
||||
now = time.time()
|
||||
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||
for t in expired:
|
||||
del _ws_temp_tokens[t]
|
||||
|
||||
|
||||
def generate_ws_token(session_token: str) -> str:
|
||||
"""生成 WebSocket 临时 token
|
||||
|
||||
Args:
|
||||
session_token: 原始的 session token
|
||||
|
||||
Returns:
|
||||
临时 token 字符串
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
temp_token = secrets.token_urlsafe(32)
|
||||
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||
return temp_token
|
||||
|
||||
|
||||
def verify_ws_token(temp_token: str) -> bool:
|
||||
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||
|
||||
Args:
|
||||
temp_token: 临时 token
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
if temp_token not in _ws_temp_tokens:
|
||||
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||
return False
|
||||
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||
if time.time() > expire_time:
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||
return False
|
||||
# 验证原始 session token 仍然有效
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||
return False
|
||||
# 消费 token(一次性使用)
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/ws-token")
|
||||
async def get_ws_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
|
||||
注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
session_token = None
|
||||
if maibot_session:
|
||||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not session_token:
|
||||
# 返回 200 但 success=False,避免前端因 401 刷新页面
|
||||
# 这在登录页面是正常情况,不应该触发错误处理
|
||||
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
|
||||
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
# 同样返回 200 但 success=False,避免前端刷新
|
||||
logger.debug("ws-token 请求:认证已过期")
|
||||
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
Reference in New Issue
Block a user