Ruff fix
This commit is contained in:
@@ -51,7 +51,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,保留 target 中的注释和格式
|
||||
将 source 的值更新到 target 中(仅更新已存在的键)
|
||||
|
||||
|
||||
Args:
|
||||
target: 目标字典(tomlkit 对象,包含注释)
|
||||
source: 源字典(普通 dict 或 list)
|
||||
@@ -59,7 +59,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
|
||||
if isinstance(source, list):
|
||||
return # 调用者需要直接赋值
|
||||
|
||||
|
||||
# 如果都是字典,递归合并
|
||||
if isinstance(source, dict) and isinstance(target, dict):
|
||||
for key, value in source.items():
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
@@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
class EmojiResponse(BaseModel):
|
||||
"""表情包响应"""
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
@@ -35,6 +37,7 @@ class EmojiResponse(BaseModel):
|
||||
|
||||
class EmojiListResponse(BaseModel):
|
||||
"""表情包列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
@@ -44,12 +47,14 @@ class EmojiListResponse(BaseModel):
|
||||
|
||||
class EmojiDetailResponse(BaseModel):
|
||||
"""表情包详情响应"""
|
||||
|
||||
success: bool
|
||||
data: EmojiResponse
|
||||
|
||||
|
||||
class EmojiUpdateRequest(BaseModel):
|
||||
"""表情包更新请求"""
|
||||
|
||||
description: Optional[str] = None
|
||||
is_registered: Optional[bool] = None
|
||||
is_banned: Optional[bool] = None
|
||||
@@ -58,6 +63,7 @@ class EmojiUpdateRequest(BaseModel):
|
||||
|
||||
class EmojiUpdateResponse(BaseModel):
|
||||
"""表情包更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
@@ -65,6 +71,7 @@ class EmojiUpdateResponse(BaseModel):
|
||||
|
||||
class EmojiDeleteResponse(BaseModel):
|
||||
"""表情包删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
@@ -73,13 +80,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -120,11 +127,11 @@ async def get_emoji_list(
|
||||
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
||||
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表情包列表
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
@@ -133,61 +140,51 @@ async def get_emoji_list(
|
||||
is_banned: 是否被禁用筛选
|
||||
format: 格式筛选
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表情包列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = Emoji.select()
|
||||
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Emoji.description.contains(search)) |
|
||||
(Emoji.emoji_hash.contains(search))
|
||||
)
|
||||
|
||||
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
|
||||
|
||||
# 注册状态过滤
|
||||
if is_registered is not None:
|
||||
query = query.where(Emoji.is_registered == is_registered)
|
||||
|
||||
|
||||
# 禁用状态过滤
|
||||
if is_banned is not None:
|
||||
query = query.where(Emoji.is_banned == is_banned)
|
||||
|
||||
|
||||
# 格式过滤
|
||||
if format:
|
||||
query = query.where(Emoji.format == format)
|
||||
|
||||
|
||||
# 排序:使用次数倒序,然后按记录时间倒序
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Emoji.usage_count.desc(),
|
||||
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
|
||||
Emoji.record_time.desc()
|
||||
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
|
||||
)
|
||||
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
emojis = query.offset(offset).limit(page_size)
|
||||
|
||||
|
||||
# 转换为响应对象
|
||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||
|
||||
return EmojiListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -196,33 +193,27 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表情包详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
return EmojiDetailResponse(
|
||||
success=True,
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -231,61 +222,55 @@ async def get_emoji_detail(
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(
|
||||
emoji_id: int,
|
||||
request: EmojiUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
|
||||
# 处理情感标签(转换为 JSON)
|
||||
if 'emotion' in update_data:
|
||||
if update_data['emotion'] is None:
|
||||
update_data['emotion'] = None
|
||||
if "emotion" in update_data:
|
||||
if update_data["emotion"] is None:
|
||||
update_data["emotion"] = None
|
||||
else:
|
||||
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
|
||||
|
||||
update_data["emotion"] = json.dumps(update_data["emotion"], ensure_ascii=False)
|
||||
|
||||
# 如果注册状态从 False 变为 True,记录注册时间
|
||||
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
|
||||
update_data['register_time'] = time.time()
|
||||
|
||||
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||
update_data["register_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(emoji, field, value)
|
||||
|
||||
|
||||
emoji.save()
|
||||
|
||||
|
||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=emoji_to_response(emoji)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -294,41 +279,35 @@ async def update_emoji(
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 记录删除信息
|
||||
emoji_hash = emoji.emoji_hash
|
||||
|
||||
|
||||
# 执行删除
|
||||
emoji.delete_instance()
|
||||
|
||||
|
||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除表情包: {emoji_hash}"
|
||||
)
|
||||
|
||||
|
||||
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -337,31 +316,29 @@ async def delete_emoji(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包统计数据
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
total = Emoji.select().count()
|
||||
registered = Emoji.select().where(Emoji.is_registered).count()
|
||||
banned = Emoji.select().where(Emoji.is_banned).count()
|
||||
|
||||
|
||||
# 按格式统计
|
||||
formats = {}
|
||||
for emoji in Emoji.select(Emoji.format):
|
||||
fmt = emoji.format
|
||||
formats[fmt] = formats.get(fmt, 0) + 1
|
||||
|
||||
|
||||
# 获取最常用的表情包(前10)
|
||||
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
|
||||
top_used_list = [
|
||||
@@ -369,11 +346,11 @@ async def get_emoji_stats(
|
||||
"id": emoji.id,
|
||||
"emoji_hash": emoji.emoji_hash,
|
||||
"description": emoji.description,
|
||||
"usage_count": emoji.usage_count
|
||||
"usage_count": emoji.usage_count,
|
||||
}
|
||||
for emoji in top_used
|
||||
]
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
@@ -382,10 +359,10 @@ async def get_emoji_stats(
|
||||
"banned": banned,
|
||||
"unregistered": total - registered,
|
||||
"formats": formats,
|
||||
"top_used": top_used_list
|
||||
}
|
||||
"top_used": top_used_list,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -394,47 +371,40 @@ async def get_emoji_stats(
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
if emoji.is_registered:
|
||||
raise HTTPException(status_code=400, detail="该表情包已经注册")
|
||||
|
||||
|
||||
if emoji.is_banned:
|
||||
raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册")
|
||||
|
||||
|
||||
# 注册表情包
|
||||
emoji.is_registered = True
|
||||
emoji.register_time = time.time()
|
||||
emoji.save()
|
||||
|
||||
|
||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message="表情包注册成功",
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -443,41 +413,34 @@ async def register_emoji(
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(
|
||||
emoji_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 禁用表情包(同时取消注册)
|
||||
emoji.is_banned = True
|
||||
emoji.is_registered = False
|
||||
emoji.save()
|
||||
|
||||
|
||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True,
|
||||
message="表情包禁用成功",
|
||||
data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
|
||||
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -489,16 +452,16 @@ async def ban_emoji(
|
||||
async def get_emoji_thumbnail(
|
||||
emoji_id: int,
|
||||
token: Optional[str] = Query(None, description="访问令牌"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表情包缩略图
|
||||
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
token: 访问令牌(通过 query parameter)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表情包图片文件
|
||||
"""
|
||||
@@ -511,37 +474,32 @@ async def get_emoji_thumbnail(
|
||||
else:
|
||||
# 如果没有 query token,则验证 Authorization header
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(emoji.full_path):
|
||||
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||
|
||||
|
||||
# 根据格式设置 MIME 类型
|
||||
mime_types = {
|
||||
'png': 'image/png',
|
||||
'jpg': 'image/jpeg',
|
||||
'jpeg': 'image/jpeg',
|
||||
'gif': 'image/gif',
|
||||
'webp': 'image/webp',
|
||||
'bmp': 'image/bmp'
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
|
||||
media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream')
|
||||
|
||||
return FileResponse(
|
||||
path=emoji.full_path,
|
||||
media_type=media_type,
|
||||
filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
)
|
||||
|
||||
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
|
||||
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表情包缩略图失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
|
||||
|
||||
class ExpressionResponse(BaseModel):
|
||||
"""表达方式响应"""
|
||||
|
||||
id: int
|
||||
situation: str
|
||||
style: str
|
||||
@@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
||||
|
||||
class ExpressionListResponse(BaseModel):
|
||||
"""表达方式列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
@@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
|
||||
|
||||
class ExpressionDetailResponse(BaseModel):
|
||||
"""表达方式详情响应"""
|
||||
|
||||
success: bool
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
class ExpressionCreateRequest(BaseModel):
|
||||
"""表达方式创建请求"""
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str] = None
|
||||
@@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
|
||||
|
||||
class ExpressionUpdateRequest(BaseModel):
|
||||
"""表达方式更新请求"""
|
||||
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
@@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
|
||||
|
||||
class ExpressionUpdateResponse(BaseModel):
|
||||
"""表达方式更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[ExpressionResponse] = None
|
||||
@@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
|
||||
|
||||
class ExpressionDeleteResponse(BaseModel):
|
||||
"""表达方式删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ExpressionCreateResponse(BaseModel):
|
||||
"""表达方式创建响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: ExpressionResponse
|
||||
@@ -82,13 +91,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -112,64 +121,58 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 situation, style, context)
|
||||
chat_id: 聊天ID筛选
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search)) |
|
||||
(Expression.style.contains(search)) |
|
||||
(Expression.context.contains(search))
|
||||
(Expression.situation.contains(search))
|
||||
| (Expression.style.contains(search))
|
||||
| (Expression.context.contains(search))
|
||||
)
|
||||
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
|
||||
Expression.last_active_time.desc()
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||
)
|
||||
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
|
||||
|
||||
# 转换为响应对象
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -178,33 +181,27 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(
|
||||
expression_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
return ExpressionDetailResponse(
|
||||
success=True,
|
||||
data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -213,25 +210,22 @@ async def get_expression_detail(
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
|
||||
Args:
|
||||
request: 创建请求
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
# 创建表达方式
|
||||
expression = Expression.create(
|
||||
situation=request.situation,
|
||||
@@ -242,15 +236,13 @@ async def create_expression(
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
|
||||
|
||||
return ExpressionCreateResponse(
|
||||
success=True,
|
||||
message="表达方式创建成功",
|
||||
data=expression_to_response(expression)
|
||||
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -260,52 +252,48 @@ async def create_expression(
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data['last_active_time'] = time.time()
|
||||
|
||||
update_data["last_active_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(expression, field, value)
|
||||
|
||||
|
||||
expression.save()
|
||||
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
|
||||
return ExpressionUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=expression_to_response(expression)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -314,41 +302,35 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(
|
||||
expression_id: int,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
|
||||
# 记录删除信息
|
||||
situation = expression.situation
|
||||
|
||||
|
||||
# 执行删除
|
||||
expression.delete_instance()
|
||||
|
||||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
|
||||
return ExpressionDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除表达方式: {situation}"
|
||||
)
|
||||
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -357,46 +339,45 @@ async def delete_expression(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
|
||||
# 按 chat_id 统计
|
||||
chat_stats = {}
|
||||
for expr in Expression.select(Expression.chat_id):
|
||||
chat_id = expr.chat_id
|
||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||
|
||||
|
||||
# 获取最近创建的记录数(7天内)
|
||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||
recent = Expression.select().where(
|
||||
(Expression.create_date.is_null(False)) &
|
||||
(Expression.create_date >= seven_days_ago)
|
||||
).count()
|
||||
|
||||
recent = (
|
||||
Expression.select()
|
||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"recent_7days": recent,
|
||||
"chat_count": len(chat_stats),
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
|
||||
}
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
@@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
|
||||
# 导入进度更新函数(避免循环导入)
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
@@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
|
||||
|
||||
class MirrorType(str, Enum):
|
||||
"""镜像源类型"""
|
||||
|
||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||
@@ -34,10 +37,10 @@ class MirrorType(str, Enum):
|
||||
|
||||
class GitMirrorConfig:
|
||||
"""Git 镜像源配置管理"""
|
||||
|
||||
|
||||
# 配置文件路径
|
||||
CONFIG_FILE = Path("data/webui.json")
|
||||
|
||||
|
||||
# 默认镜像源配置
|
||||
DEFAULT_MIRRORS = [
|
||||
{
|
||||
@@ -47,7 +50,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 1,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "hk-gh-proxy",
|
||||
@@ -56,7 +59,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 2,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "cdn-gh-proxy",
|
||||
@@ -65,7 +68,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 3,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "edgeone-gh-proxy",
|
||||
@@ -74,7 +77,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 4,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "meyzh-github",
|
||||
@@ -83,7 +86,7 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 5,
|
||||
"created_at": None
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "github",
|
||||
@@ -92,23 +95,23 @@ class GitMirrorConfig:
|
||||
"clone_prefix": "https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 999,
|
||||
"created_at": None
|
||||
}
|
||||
"created_at": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""初始化配置管理器"""
|
||||
self.config_file = self.CONFIG_FILE
|
||||
self.mirrors: List[Dict[str, Any]] = []
|
||||
self._load_config()
|
||||
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
if "git_mirrors" not in data or not data["git_mirrors"]:
|
||||
logger.info("配置文件中未找到镜像源配置,使用默认配置")
|
||||
@@ -122,59 +125,59 @@ class GitMirrorConfig:
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
self._init_default_mirrors()
|
||||
|
||||
|
||||
def _init_default_mirrors(self) -> None:
|
||||
"""初始化默认镜像源"""
|
||||
current_time = datetime.now().isoformat()
|
||||
self.mirrors = []
|
||||
|
||||
|
||||
for mirror in self.DEFAULT_MIRRORS:
|
||||
mirror_copy = mirror.copy()
|
||||
mirror_copy["created_at"] = current_time
|
||||
self.mirrors.append(mirror_copy)
|
||||
|
||||
|
||||
self._save_config()
|
||||
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
|
||||
|
||||
|
||||
def _save_config(self) -> None:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, 'r', encoding='utf-8') as f:
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, 'w', encoding='utf-8') as f:
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
|
||||
|
||||
def get_all_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有镜像源"""
|
||||
return self.mirrors.copy()
|
||||
|
||||
|
||||
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有启用的镜像源,按优先级排序"""
|
||||
enabled = [m for m in self.mirrors if m.get("enabled", False)]
|
||||
return sorted(enabled, key=lambda x: x.get("priority", 999))
|
||||
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
@@ -182,26 +185,26 @@ class GitMirrorConfig:
|
||||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
||||
|
||||
Returns:
|
||||
添加的镜像源配置
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: 如果镜像源 ID 已存在
|
||||
"""
|
||||
# 检查 ID 是否已存在
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
priority = max_priority + 1
|
||||
|
||||
|
||||
new_mirror = {
|
||||
"id": mirror_id,
|
||||
"name": name,
|
||||
@@ -209,15 +212,15 @@ class GitMirrorConfig:
|
||||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat()
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
self._save_config()
|
||||
|
||||
|
||||
logger.info(f"已添加镜像源: {mirror_id} - {name}")
|
||||
return new_mirror.copy()
|
||||
|
||||
|
||||
def update_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
@@ -225,11 +228,11 @@ class GitMirrorConfig:
|
||||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
||||
|
||||
Returns:
|
||||
更新后的镜像源配置,如果不存在则返回 None
|
||||
"""
|
||||
@@ -245,19 +248,19 @@ class GitMirrorConfig:
|
||||
mirror["enabled"] = enabled
|
||||
if priority is not None:
|
||||
mirror["priority"] = priority
|
||||
|
||||
|
||||
mirror["updated_at"] = datetime.now().isoformat()
|
||||
self._save_config()
|
||||
|
||||
|
||||
logger.info(f"已更新镜像源: {mirror_id}")
|
||||
return mirror.copy()
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def delete_mirror(self, mirror_id: str) -> bool:
|
||||
"""
|
||||
删除镜像源
|
||||
|
||||
|
||||
Returns:
|
||||
True 如果删除成功,False 如果镜像源不存在
|
||||
"""
|
||||
@@ -267,9 +270,9 @@ class GitMirrorConfig:
|
||||
self._save_config()
|
||||
logger.info(f"已删除镜像源: {mirror_id}")
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_default_priority_list(self) -> List[str]:
|
||||
"""获取默认优先级列表(仅启用的镜像源 ID)"""
|
||||
enabled = self.get_enabled_mirrors()
|
||||
@@ -278,16 +281,11 @@ class GitMirrorConfig:
|
||||
|
||||
class GitMirrorService:
|
||||
"""Git 镜像源服务"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
timeout: int = 30,
|
||||
config: Optional[GitMirrorConfig] = None
|
||||
):
|
||||
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||
"""
|
||||
初始化 Git 镜像源服务
|
||||
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
timeout: 请求超时时间(秒)
|
||||
@@ -297,16 +295,16 @@ class GitMirrorService:
|
||||
self.timeout = timeout
|
||||
self.config = config or GitMirrorConfig()
|
||||
logger.info(f"Git镜像源服务初始化完成,已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
|
||||
|
||||
|
||||
def get_mirror_config(self) -> GitMirrorConfig:
|
||||
"""获取镜像源配置管理器"""
|
||||
return self.config
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_git_installed() -> Dict[str, Any]:
|
||||
"""
|
||||
检查本机是否安装了 Git
|
||||
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- installed: bool - 是否已安装 Git
|
||||
@@ -316,54 +314,33 @@ class GitMirrorService:
|
||||
"""
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
|
||||
try:
|
||||
# 查找 git 可执行文件路径
|
||||
git_path = shutil.which("git")
|
||||
|
||||
|
||||
if not git_path:
|
||||
logger.warning("未找到 Git 可执行文件")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": "系统中未找到 Git,请先安装 Git"
|
||||
}
|
||||
|
||||
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||
|
||||
# 获取 Git 版本
|
||||
result = subprocess.run(
|
||||
["git", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
|
||||
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||
return {
|
||||
"installed": True,
|
||||
"version": version,
|
||||
"path": git_path
|
||||
}
|
||||
return {"installed": True, "version": version, "path": git_path}
|
||||
else:
|
||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": f"Git 命令执行失败: {result.stderr}"
|
||||
}
|
||||
|
||||
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Git 版本检测超时")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": "Git 版本检测超时"
|
||||
}
|
||||
return {"installed": False, "error": "Git 版本检测超时"}
|
||||
except Exception as e:
|
||||
logger.error(f"检测 Git 时发生错误: {e}")
|
||||
return {
|
||||
"installed": False,
|
||||
"error": f"检测 Git 时发生错误: {str(e)}"
|
||||
}
|
||||
|
||||
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||
|
||||
async def fetch_raw_file(
|
||||
self,
|
||||
owner: str,
|
||||
@@ -371,11 +348,11 @@ class GitMirrorService:
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None
|
||||
custom_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
@@ -383,7 +360,7 @@ class GitMirrorService:
|
||||
file_path: 文件路径
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义完整 URL(如果提供,将忽略其他参数)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
@@ -393,29 +370,24 @@ class GitMirrorService:
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
|
||||
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._fetch_with_url(custom_url, "custom")
|
||||
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到镜像源: {mirror_id}",
|
||||
"mirror_used": None,
|
||||
"attempts": 0
|
||||
}
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
|
||||
total_mirrors = len(mirrors_to_try)
|
||||
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for index, mirror in enumerate(mirrors_to_try, 1):
|
||||
# 推送进度:正在尝试第 N 个镜像源
|
||||
@@ -427,15 +399,13 @@ class GitMirrorService:
|
||||
progress=progress,
|
||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
result = await self._fetch_raw_from_mirror(
|
||||
owner, repo, branch, file_path, mirror
|
||||
)
|
||||
|
||||
|
||||
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||
|
||||
if result["success"]:
|
||||
# 成功,推送进度
|
||||
if _update_progress:
|
||||
@@ -445,15 +415,15 @@ class GitMirrorService:
|
||||
progress=70,
|
||||
message=f"成功从 {mirror['name']} 获取数据",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
return result
|
||||
|
||||
|
||||
# 失败,记录日志并推送失败信息
|
||||
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
|
||||
|
||||
|
||||
if _update_progress and index < total_mirrors:
|
||||
try:
|
||||
await _update_progress(
|
||||
@@ -461,39 +431,29 @@ class GitMirrorService:
|
||||
progress=30 + int(index / total_mirrors * 40),
|
||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {
|
||||
"success": False,
|
||||
"error": "所有镜像源均失败",
|
||||
"mirror_used": None,
|
||||
"attempts": len(mirrors_to_try)
|
||||
}
|
||||
|
||||
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _fetch_raw_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror: Dict[str, Any]
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
raw_prefix = mirror["raw_prefix"]
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
|
||||
return await self._fetch_with_url(url, mirror["id"])
|
||||
|
||||
|
||||
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
|
||||
"""使用指定 URL 获取文件,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
try:
|
||||
@@ -501,14 +461,14 @@ class GitMirrorService:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
logger.info(f"成功获取文件: {url}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": response.text,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
"url": url,
|
||||
}
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||
@@ -519,15 +479,9 @@ class GitMirrorService:
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": last_error,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
}
|
||||
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
async def clone_repository(
|
||||
self,
|
||||
owner: str,
|
||||
@@ -536,11 +490,11 @@ class GitMirrorService:
|
||||
branch: Optional[str] = None,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
depth: Optional[int] = None
|
||||
depth: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
克隆 GitHub 仓库
|
||||
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
@@ -549,7 +503,7 @@ class GitMirrorService:
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义克隆 URL
|
||||
depth: 克隆深度(浅克隆)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
@@ -559,44 +513,32 @@ class GitMirrorService:
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}")
|
||||
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
|
||||
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"未找到镜像源: {mirror_id}",
|
||||
"mirror_used": None,
|
||||
"attempts": 0
|
||||
}
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for mirror in mirrors_to_try:
|
||||
result = await self._clone_from_mirror(
|
||||
owner, repo, target_path, branch, depth, mirror
|
||||
)
|
||||
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||
if result["success"]:
|
||||
return result
|
||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {
|
||||
"success": False,
|
||||
"error": "所有镜像源克隆均失败",
|
||||
"mirror_used": None,
|
||||
"attempts": len(mirrors_to_try)
|
||||
}
|
||||
|
||||
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _clone_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
@@ -604,52 +546,47 @@ class GitMirrorService:
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror: Dict[str, Any]
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
clone_prefix = mirror["clone_prefix"]
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
|
||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
|
||||
async def _clone_with_url(
|
||||
self,
|
||||
url: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror_type: str
|
||||
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用指定 URL 克隆仓库,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
|
||||
|
||||
try:
|
||||
# 确保目标路径不存在
|
||||
if target_path.exists():
|
||||
logger.warning(f"目标路径已存在,删除: {target_path}")
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
|
||||
# 构建 git clone 命令
|
||||
cmd = ["git", "clone"]
|
||||
|
||||
|
||||
# 添加分支参数
|
||||
if branch:
|
||||
cmd.extend(["-b", branch])
|
||||
|
||||
|
||||
# 添加深度参数(浅克隆)
|
||||
if depth:
|
||||
cmd.extend(["--depth", str(depth)])
|
||||
|
||||
|
||||
# 添加 URL 和目标路径
|
||||
cmd.extend([url, str(target_path)])
|
||||
|
||||
|
||||
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
|
||||
|
||||
|
||||
# 推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
@@ -657,24 +594,24 @@ class GitMirrorService:
|
||||
stage="loading",
|
||||
progress=20 + attempt * 10,
|
||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||
operation="install"
|
||||
operation="install",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
||||
def run_git_clone():
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300 # 5分钟超时
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
|
||||
process = await loop.run_in_executor(None, run_git_clone)
|
||||
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"成功克隆仓库: {url} -> {target_path}")
|
||||
return {
|
||||
@@ -683,40 +620,34 @@ class GitMirrorService:
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
"branch": branch or "default"
|
||||
"branch": branch or "default",
|
||||
}
|
||||
else:
|
||||
last_error = f"Git 克隆失败: {process.stderr}"
|
||||
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
last_error = "克隆超时(超过 5 分钟)"
|
||||
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
|
||||
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
|
||||
except FileNotFoundError:
|
||||
last_error = "Git 未安装或不在 PATH 中"
|
||||
logger.error(f"Git 未找到: {last_error}")
|
||||
break # Git 不存在,不需要重试
|
||||
|
||||
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": last_error,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url
|
||||
}
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set
|
||||
import json
|
||||
@@ -14,30 +15,30 @@ active_connections: Set[WebSocket] = set()
|
||||
|
||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
"""从日志文件中加载最近的日志
|
||||
|
||||
|
||||
Args:
|
||||
limit: 返回的最大日志条数
|
||||
|
||||
|
||||
Returns:
|
||||
日志列表
|
||||
"""
|
||||
logs = []
|
||||
log_dir = Path("logs")
|
||||
|
||||
|
||||
if not log_dir.exists():
|
||||
return logs
|
||||
|
||||
|
||||
# 获取所有日志文件,按修改时间排序
|
||||
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
|
||||
|
||||
# 用于生成唯一 ID 的计数器
|
||||
log_counter = 0
|
||||
|
||||
|
||||
# 从最新的文件开始读取
|
||||
for log_file in log_files:
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
@@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
log_entry = json.loads(line.strip())
|
||||
# 转换为前端期望的格式
|
||||
# 使用时间戳 + 计数器生成唯一 ID
|
||||
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
timestamp_id = (
|
||||
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
)
|
||||
formatted_log = {
|
||||
"id": f"{timestamp_id}_{log_counter}",
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
@@ -64,7 +67,7 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
except Exception as e:
|
||||
logger.error(f"读取日志文件失败 {log_file}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 反转列表,使其按时间顺序排列(旧到新)
|
||||
return list(reversed(logs))
|
||||
|
||||
@@ -72,35 +75,35 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
"""
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
recent_logs = load_recent_logs(limit=100)
|
||||
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
|
||||
|
||||
|
||||
for log_entry in recent_logs:
|
||||
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"发送历史日志失败: {e}")
|
||||
|
||||
|
||||
try:
|
||||
# 保持连接,等待客户端消息或断开
|
||||
while True:
|
||||
# 接收客户端消息(用于心跳或控制指令)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
|
||||
# 可以处理客户端的控制消息,例如:
|
||||
# - "ping" -> 心跳检测
|
||||
# - {"filter": "ERROR"} -> 设置日志级别过滤
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
@@ -111,19 +114,19 @@ async def websocket_logs(websocket: WebSocket):
|
||||
|
||||
async def broadcast_log(log_data: dict):
|
||||
"""广播日志到所有连接的 WebSocket 客户端
|
||||
|
||||
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
@@ -131,7 +134,7 @@ async def broadcast_log(log_data: dict):
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
@@ -10,10 +11,10 @@ logger = get_logger("webui")
|
||||
def setup_webui(mode: str = "production") -> bool:
|
||||
"""
|
||||
设置 WebUI
|
||||
|
||||
|
||||
Args:
|
||||
mode: 运行模式,"development" 或 "production"
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功设置
|
||||
"""
|
||||
@@ -22,7 +23,7 @@ def setup_webui(mode: str = "production") -> bool:
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||
|
||||
|
||||
if mode == "development":
|
||||
return setup_dev_mode()
|
||||
else:
|
||||
@@ -33,12 +34,12 @@ def setup_dev_mode() -> bool:
|
||||
"""设置开发模式 - 仅启用 CORS,前端自行启动"""
|
||||
from src.common.server import get_global_server
|
||||
from .logs_ws import router as logs_router
|
||||
|
||||
|
||||
# 注册 WebSocket 日志路由(开发模式也需要)
|
||||
server = get_global_server()
|
||||
server.register_router(logs_router)
|
||||
logger.info("✅ WebSocket 日志推送路由已注册")
|
||||
|
||||
|
||||
logger.info("📝 WebUI 开发模式已启用")
|
||||
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
|
||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||
@@ -52,33 +53,33 @@ def setup_production_mode() -> bool:
|
||||
from starlette.responses import FileResponse
|
||||
from .logs_ws import router as logs_router
|
||||
import mimetypes
|
||||
|
||||
|
||||
# 确保正确的 MIME 类型映射
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript', '.js')
|
||||
mimetypes.add_type('application/javascript', '.mjs')
|
||||
mimetypes.add_type('text/css', '.css')
|
||||
mimetypes.add_type('application/json', '.json')
|
||||
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
server = get_global_server()
|
||||
|
||||
|
||||
# 注册 WebSocket 日志路由
|
||||
server.register_router(logs_router)
|
||||
logger.info("✅ WebSocket 日志推送路由已注册")
|
||||
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
static_path = base_dir / "webui" / "dist"
|
||||
|
||||
|
||||
if not static_path.exists():
|
||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||
return False
|
||||
|
||||
|
||||
if not (static_path / "index.html").exists():
|
||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||
logger.warning("💡 请确认前端已正确构建")
|
||||
return False
|
||||
|
||||
|
||||
# 处理 SPA 路由
|
||||
@server.app.get("/{full_path:path}")
|
||||
async def serve_spa(full_path: str):
|
||||
@@ -86,23 +87,23 @@ def setup_production_mode() -> bool:
|
||||
# API 路由不处理
|
||||
if full_path.startswith("api/"):
|
||||
return None
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
|
||||
|
||||
# 返回 index.html(SPA 路由)
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port = os.getenv("PORT", "8000")
|
||||
logger.info("✅ WebUI 生产模式已挂载")
|
||||
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"挂载 WebUI 静态文件失败: {e}")
|
||||
return False
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
@@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
|
||||
|
||||
class PersonInfoResponse(BaseModel):
|
||||
"""人物信息响应"""
|
||||
|
||||
id: int
|
||||
is_known: bool
|
||||
person_id: str
|
||||
@@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
|
||||
|
||||
class PersonListResponse(BaseModel):
|
||||
"""人物列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
@@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
|
||||
|
||||
class PersonDetailResponse(BaseModel):
|
||||
"""人物详情响应"""
|
||||
|
||||
success: bool
|
||||
data: PersonInfoResponse
|
||||
|
||||
|
||||
class PersonUpdateRequest(BaseModel):
|
||||
"""人物信息更新请求"""
|
||||
|
||||
person_name: Optional[str] = None
|
||||
name_reason: Optional[str] = None
|
||||
nickname: Optional[str] = None
|
||||
@@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
|
||||
|
||||
class PersonUpdateResponse(BaseModel):
|
||||
"""人物信息更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[PersonInfoResponse] = None
|
||||
@@ -64,6 +70,7 @@ class PersonUpdateResponse(BaseModel):
|
||||
|
||||
class PersonDeleteResponse(BaseModel):
|
||||
"""人物删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
@@ -72,13 +79,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -118,11 +125,11 @@ async def get_person_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
authorization: Optional[str] = Header(None)
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
@@ -130,58 +137,50 @@ async def get_person_list(
|
||||
is_known: 是否已认识筛选
|
||||
platform: 平台筛选
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
# 构建查询
|
||||
query = PersonInfo.select()
|
||||
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search)) |
|
||||
(PersonInfo.nickname.contains(search)) |
|
||||
(PersonInfo.user_id.contains(search))
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
|
||||
# 已认识状态过滤
|
||||
if is_known is not None:
|
||||
query = query.where(PersonInfo.is_known == is_known)
|
||||
|
||||
|
||||
# 平台过滤
|
||||
if platform:
|
||||
query = query.where(PersonInfo.platform == platform)
|
||||
|
||||
|
||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||
from peewee import Case
|
||||
query = query.order_by(
|
||||
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
|
||||
PersonInfo.last_know.desc()
|
||||
)
|
||||
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
persons = query.offset(offset).limit(page_size)
|
||||
|
||||
|
||||
# 转换为响应对象
|
||||
data = [person_to_response(person) for person in persons]
|
||||
|
||||
return PersonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data
|
||||
)
|
||||
|
||||
|
||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -190,33 +189,27 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(
|
||||
person_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
return PersonDetailResponse(
|
||||
success=True,
|
||||
data=person_to_response(person)
|
||||
)
|
||||
|
||||
|
||||
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -225,53 +218,47 @@ async def get_person_detail(
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
|
||||
# 更新最后修改时间
|
||||
update_data['last_know'] = time.time()
|
||||
|
||||
update_data["last_know"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(person, field, value)
|
||||
|
||||
|
||||
person.save()
|
||||
|
||||
|
||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
|
||||
return PersonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {len(update_data)} 个字段",
|
||||
data=person_to_response(person)
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -280,41 +267,35 @@ async def update_person(
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(
|
||||
person_id: str,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
|
||||
# 记录删除信息
|
||||
person_name = person.person_name or person.nickname or person.user_id
|
||||
|
||||
|
||||
# 执行删除
|
||||
person.delete_instance()
|
||||
|
||||
|
||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||
|
||||
return PersonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除人物信息: {person_name}"
|
||||
)
|
||||
|
||||
|
||||
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -323,41 +304,31 @@ async def delete_person(
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
|
||||
|
||||
total = PersonInfo.select().count()
|
||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
unknown = total - known
|
||||
|
||||
|
||||
# 按平台统计
|
||||
platforms = {}
|
||||
for person in PersonInfo.select(PersonInfo.platform):
|
||||
platform = person.platform
|
||||
platforms[platform] = platforms.get(platform, 0) + 1
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"known": known,
|
||||
"unknown": unknown,
|
||||
"platforms": platforms
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from typing import Set, Dict, Any
|
||||
import json
|
||||
@@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
|
||||
"error": None,
|
||||
"plugin_id": None, # 当前操作的插件 ID
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": 0
|
||||
"loaded_plugins": 0,
|
||||
}
|
||||
|
||||
|
||||
@@ -30,20 +31,20 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
"""广播进度更新到所有连接的客户端"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected = set()
|
||||
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
|
||||
# 移除断开的连接
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
@@ -57,10 +58,10 @@ async def update_progress(
|
||||
error: str = None,
|
||||
plugin_id: str = None,
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0
|
||||
loaded_plugins: int = 0,
|
||||
):
|
||||
"""更新并广播进度
|
||||
|
||||
|
||||
Args:
|
||||
stage: 阶段 (idle, loading, success, error)
|
||||
progress: 进度百分比 (0-100)
|
||||
@@ -80,9 +81,9 @@ async def update_progress(
|
||||
"plugin_id": plugin_id,
|
||||
"total_plugins": total_plugins,
|
||||
"loaded_plugins": loaded_plugins,
|
||||
"timestamp": asyncio.get_event_loop().time()
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
|
||||
await broadcast_progress(progress_data)
|
||||
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
|
||||
|
||||
@@ -90,30 +91,30 @@ async def update_progress(
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
"""
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
|
||||
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
|
||||
|
||||
|
||||
# 保持连接并处理客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
|
||||
|
||||
# 处理客户端心跳
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
break
|
||||
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@
|
||||
|
||||
提供系统重启、状态查询等功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -19,12 +20,14 @@ _start_time = time.time()
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""状态响应"""
|
||||
|
||||
running: bool
|
||||
uptime: float
|
||||
version: str
|
||||
@@ -35,74 +38,60 @@ class StatusResponse(BaseModel):
|
||||
async def restart_maibot():
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
|
||||
使用 os.execv 重启当前进程,配置更改将在重启后生效。
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
try:
|
||||
# 记录重启操作
|
||||
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||
|
||||
|
||||
# 使用 os.execv 重启当前进程
|
||||
# 这会替换当前进程,保持相同的 PID
|
||||
python = sys.executable
|
||||
args = [python] + sys.argv
|
||||
|
||||
|
||||
# 返回成功响应(实际上这个响应可能不会发送,因为进程会立即重启)
|
||||
# 但我们仍然返回它以保持 API 一致性
|
||||
os.execv(python, args)
|
||||
|
||||
return RestartResponse(
|
||||
success=True,
|
||||
message="麦麦正在重启中..."
|
||||
)
|
||||
|
||||
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"重启失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status():
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
|
||||
返回麦麦的运行状态、运行时长和版本信息。
|
||||
"""
|
||||
try:
|
||||
uptime = time.time() - _start_time
|
||||
|
||||
|
||||
# 尝试获取版本信息(需要根据实际情况调整)
|
||||
version = MMC_VERSION # 可以从配置或常量中读取
|
||||
|
||||
|
||||
return StatusResponse(
|
||||
running=True,
|
||||
uptime=uptime,
|
||||
version=version,
|
||||
start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"获取状态失败: {str(e)}"
|
||||
) from e
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config():
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
|
||||
仅重新加载配置文件,某些配置可能需要重启才能生效。
|
||||
此功能需要在主程序中实现配置热重载逻辑。
|
||||
"""
|
||||
# 这里需要调用主程序的配置重载函数
|
||||
# 示例:await app_instance.reload_config()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置重载功能待实现"
|
||||
}
|
||||
|
||||
return {"success": True, "message": "配置重载功能待实现"}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
@@ -38,28 +39,33 @@ router.include_router(system_router)
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
"""Token 验证请求"""
|
||||
|
||||
token: str = Field(..., description="访问令牌")
|
||||
|
||||
|
||||
class TokenVerifyResponse(BaseModel):
|
||||
"""Token 验证响应"""
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
"""Token 更新请求"""
|
||||
|
||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||
|
||||
|
||||
class TokenUpdateResponse(BaseModel):
|
||||
"""Token 更新响应"""
|
||||
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: str = Field(..., description="更新结果消息")
|
||||
|
||||
|
||||
class TokenRegenerateResponse(BaseModel):
|
||||
"""Token 重新生成响应"""
|
||||
|
||||
success: bool = Field(..., description="是否生成成功")
|
||||
token: str = Field(..., description="新生成的令牌")
|
||||
message: str = Field(..., description="生成结果消息")
|
||||
@@ -67,18 +73,21 @@ class TokenRegenerateResponse(BaseModel):
|
||||
|
||||
class FirstSetupStatusResponse(BaseModel):
|
||||
"""首次配置状态响应"""
|
||||
|
||||
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||
message: str = Field(..., description="状态消息")
|
||||
|
||||
|
||||
class CompleteSetupResponse(BaseModel):
|
||||
"""完成配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
class ResetSetupResponse(BaseModel):
|
||||
"""重置配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
@@ -93,44 +102,35 @@ async def health_check():
|
||||
async def verify_token(request: TokenVerifyRequest):
|
||||
"""
|
||||
验证访问令牌
|
||||
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
|
||||
|
||||
if is_valid:
|
||||
return TokenVerifyResponse(
|
||||
valid=True,
|
||||
message="Token 验证成功"
|
||||
)
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||
else:
|
||||
return TokenVerifyResponse(
|
||||
valid=False,
|
||||
message="Token 无效或已过期"
|
||||
)
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
@@ -138,20 +138,17 @@ async def update_token(
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
return TokenUpdateResponse(
|
||||
success=success,
|
||||
message=message
|
||||
)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -163,10 +160,10 @@ async def update_token(
|
||||
async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
@@ -174,21 +171,17 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
return TokenRegenerateResponse(
|
||||
success=True,
|
||||
token=new_token,
|
||||
message="Token 已重新生成"
|
||||
)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -200,10 +193,10 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
@@ -211,20 +204,17 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
|
||||
# 检查是否为首次配置
|
||||
is_first = token_manager.is_first_setup()
|
||||
|
||||
return FirstSetupStatusResponse(
|
||||
is_first_setup=is_first,
|
||||
message="首次配置" if is_first else "已完成配置"
|
||||
)
|
||||
|
||||
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -236,10 +226,10 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
@@ -247,20 +237,17 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
|
||||
# 标记配置完成
|
||||
success = token_manager.mark_setup_completed()
|
||||
|
||||
return CompleteSetupResponse(
|
||||
success=success,
|
||||
message="配置已完成" if success else "标记失败"
|
||||
)
|
||||
|
||||
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -272,10 +259,10 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
async def reset_setup(authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
|
||||
Args:
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
@@ -283,20 +270,17 @@ async def reset_setup(authorization: Optional[str] = Header(None)):
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
|
||||
# 重置配置状态
|
||||
success = token_manager.reset_setup_status()
|
||||
|
||||
return ResetSetupResponse(
|
||||
success=success,
|
||||
message="配置状态已重置" if success else "重置失败"
|
||||
)
|
||||
|
||||
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List
|
||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
total_requests: int = Field(0, description="总请求数")
|
||||
total_cost: float = Field(0.0, description="总花费")
|
||||
total_tokens: int = Field(0, description="总token数")
|
||||
@@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
|
||||
|
||||
class ModelStatistics(BaseModel):
|
||||
"""模型统计"""
|
||||
|
||||
model_name: str
|
||||
request_count: int
|
||||
total_cost: float
|
||||
@@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
|
||||
|
||||
class TimeSeriesData(BaseModel):
|
||||
"""时间序列数据"""
|
||||
|
||||
timestamp: str
|
||||
requests: int = 0
|
||||
cost: float = 0.0
|
||||
@@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
|
||||
|
||||
class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
@@ -56,39 +61,39 @@ class DashboardData(BaseModel):
|
||||
async def get_dashboard_data(hours: int = 24):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时),默认24小时
|
||||
|
||||
|
||||
Returns:
|
||||
仪表盘数据
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
|
||||
|
||||
# 获取摘要数据
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
|
||||
|
||||
# 获取模型统计
|
||||
model_stats = await _get_model_statistics(start_time)
|
||||
|
||||
|
||||
# 获取小时级时间序列数据
|
||||
hourly_data = await _get_hourly_statistics(start_time, now)
|
||||
|
||||
|
||||
# 获取日级时间序列数据(最近7天)
|
||||
daily_start = now - timedelta(days=7)
|
||||
daily_data = await _get_daily_statistics(daily_start, now)
|
||||
|
||||
|
||||
# 获取最近活动
|
||||
recent_activity = await _get_recent_activity(limit=10)
|
||||
|
||||
|
||||
return DashboardData(
|
||||
summary=summary,
|
||||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
@@ -98,100 +103,84 @@ async def get_dashboard_data(hours: int = 24):
|
||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||
"""获取摘要统计数据"""
|
||||
summary = StatisticsSummary()
|
||||
|
||||
|
||||
# 查询 LLM 使用记录
|
||||
llm_records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
|
||||
llm_records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||
|
||||
total_time_cost = 0.0
|
||||
time_cost_count = 0
|
||||
|
||||
|
||||
for record in llm_records:
|
||||
summary.total_requests += 1
|
||||
summary.total_cost += record.cost or 0.0
|
||||
summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
|
||||
if record.time_cost and record.time_cost > 0:
|
||||
total_time_cost += record.time_cost
|
||||
time_cost_count += 1
|
||||
|
||||
|
||||
# 计算平均响应时间
|
||||
if time_cost_count > 0:
|
||||
summary.avg_response_time = total_time_cost / time_cost_count
|
||||
|
||||
|
||||
# 查询在线时间
|
||||
online_records = list(
|
||||
OnlineTime.select()
|
||||
.where(
|
||||
(OnlineTime.start_timestamp >= start_time) |
|
||||
(OnlineTime.end_timestamp >= start_time)
|
||||
)
|
||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||
)
|
||||
|
||||
|
||||
for record in online_records:
|
||||
start = max(record.start_timestamp, start_time)
|
||||
end = min(record.end_timestamp, end_time)
|
||||
if end > start:
|
||||
summary.online_time += (end - start).total_seconds()
|
||||
|
||||
|
||||
# 查询消息数量
|
||||
messages = list(
|
||||
Messages.select()
|
||||
.where(Messages.time >= start_time.timestamp())
|
||||
.where(Messages.time <= end_time.timestamp())
|
||||
Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp())
|
||||
)
|
||||
|
||||
|
||||
summary.total_messages = len(messages)
|
||||
# 简单统计:如果 reply_to 不为空,则认为是回复
|
||||
summary.total_replies = len([m for m in messages if m.reply_to])
|
||||
|
||||
|
||||
# 计算派生指标
|
||||
if summary.online_time > 0:
|
||||
online_hours = summary.online_time / 3600.0
|
||||
summary.cost_per_hour = summary.total_cost / online_hours
|
||||
summary.tokens_per_hour = summary.total_tokens / online_hours
|
||||
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据"""
|
||||
model_data = defaultdict(lambda: {
|
||||
'request_count': 0,
|
||||
'total_cost': 0.0,
|
||||
'total_tokens': 0,
|
||||
'time_costs': []
|
||||
})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
)
|
||||
|
||||
model_data = defaultdict(lambda: {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []})
|
||||
|
||||
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time))
|
||||
|
||||
for record in records:
|
||||
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||
model_data[model_name]['request_count'] += 1
|
||||
model_data[model_name]['total_cost'] += record.cost or 0.0
|
||||
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
model_data[model_name]["request_count"] += 1
|
||||
model_data[model_name]["total_cost"] += record.cost or 0.0
|
||||
model_data[model_name]["total_tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
if record.time_cost and record.time_cost > 0:
|
||||
model_data[model_name]['time_costs'].append(record.time_cost)
|
||||
|
||||
model_data[model_name]["time_costs"].append(record.time_cost)
|
||||
|
||||
# 转换为列表并排序
|
||||
result = []
|
||||
for model_name, data in model_data.items():
|
||||
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
|
||||
result.append(ModelStatistics(
|
||||
model_name=model_name,
|
||||
request_count=data['request_count'],
|
||||
total_cost=data['total_cost'],
|
||||
total_tokens=data['total_tokens'],
|
||||
avg_response_time=avg_time
|
||||
))
|
||||
|
||||
avg_time = sum(data["time_costs"]) / len(data["time_costs"]) if data["time_costs"] else 0.0
|
||||
result.append(
|
||||
ModelStatistics(
|
||||
model_name=model_name,
|
||||
request_count=data["request_count"],
|
||||
total_cost=data["total_cost"],
|
||||
total_tokens=data["total_tokens"],
|
||||
avg_response_time=avg_time,
|
||||
)
|
||||
)
|
||||
|
||||
# 按请求数排序
|
||||
result.sort(key=lambda x: x.request_count, reverse=True)
|
||||
return result[:10] # 返回前10个
|
||||
@@ -200,96 +189,80 @@ async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据"""
|
||||
# 创建小时桶
|
||||
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
|
||||
hourly_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
|
||||
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||
|
||||
for record in records:
|
||||
# 获取小时键(去掉分钟和秒)
|
||||
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
|
||||
hour_str = hour_key.isoformat()
|
||||
|
||||
hourly_buckets[hour_str]['requests'] += 1
|
||||
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
|
||||
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
|
||||
hourly_buckets[hour_str]["requests"] += 1
|
||||
hourly_buckets[hour_str]["cost"] += record.cost or 0.0
|
||||
hourly_buckets[hour_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
# 填充所有小时(包括没有数据的)
|
||||
result = []
|
||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
hour_str = current.isoformat()
|
||||
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
result.append(TimeSeriesData(
|
||||
timestamp=hour_str,
|
||||
requests=data['requests'],
|
||||
cost=data['cost'],
|
||||
tokens=data['tokens']
|
||||
))
|
||||
data = hourly_buckets.get(hour_str, {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=hour_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
|
||||
)
|
||||
current += timedelta(hours=1)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据"""
|
||||
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.where(LLMUsage.timestamp <= end_time)
|
||||
)
|
||||
|
||||
daily_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
|
||||
records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
|
||||
|
||||
for record in records:
|
||||
# 获取日期键
|
||||
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
day_str = day_key.isoformat()
|
||||
|
||||
daily_buckets[day_str]['requests'] += 1
|
||||
daily_buckets[day_str]['cost'] += record.cost or 0.0
|
||||
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
|
||||
daily_buckets[day_str]["requests"] += 1
|
||||
daily_buckets[day_str]["cost"] += record.cost or 0.0
|
||||
daily_buckets[day_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
|
||||
|
||||
# 填充所有天
|
||||
result = []
|
||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
day_str = current.isoformat()
|
||||
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
|
||||
result.append(TimeSeriesData(
|
||||
timestamp=day_str,
|
||||
requests=data['requests'],
|
||||
cost=data['cost'],
|
||||
tokens=data['tokens']
|
||||
))
|
||||
data = daily_buckets.get(day_str, {"requests": 0, "cost": 0.0, "tokens": 0})
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=day_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
|
||||
)
|
||||
current += timedelta(days=1)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
records = list(
|
||||
LLMUsage.select()
|
||||
.order_by(LLMUsage.timestamp.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||
|
||||
activities = []
|
||||
for record in records:
|
||||
activities.append({
|
||||
'timestamp': record.timestamp.isoformat(),
|
||||
'model': record.model_assign_name or record.model_name,
|
||||
'request_type': record.request_type,
|
||||
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
'cost': record.cost or 0.0,
|
||||
'time_cost': record.time_cost or 0.0,
|
||||
'status': record.status
|
||||
})
|
||||
|
||||
activities.append(
|
||||
{
|
||||
"timestamp": record.timestamp.isoformat(),
|
||||
"model": record.model_assign_name or record.model_name,
|
||||
"request_type": record.request_type,
|
||||
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
"cost": record.cost or 0.0,
|
||||
"time_cost": record.time_cost or 0.0,
|
||||
"status": record.status,
|
||||
}
|
||||
)
|
||||
|
||||
return activities
|
||||
|
||||
|
||||
@@ -297,7 +270,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
async def get_summary(hours: int = 24):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
@@ -315,7 +288,7 @@ async def get_summary(hours: int = 24):
|
||||
async def get_model_stats(hours: int = 24):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,7 @@ class TokenManager:
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
|
||||
"""
|
||||
@@ -27,10 +27,10 @@ class TokenManager:
|
||||
# 获取项目根目录 (src/webui -> src -> 根目录)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
config_path = project_root / "data" / "webui.json"
|
||||
|
||||
|
||||
self.config_path = config_path
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 确保配置文件存在并包含有效的 token
|
||||
self._ensure_config()
|
||||
|
||||
@@ -75,22 +75,23 @@ class TokenManager:
|
||||
"""生成新的 64 位随机 token"""
|
||||
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
|
||||
config = {
|
||||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp(),
|
||||
"first_setup_completed": False # 标记首次配置未完成
|
||||
"first_setup_completed": False, # 标记首次配置未完成
|
||||
}
|
||||
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
|
||||
|
||||
|
||||
return token
|
||||
|
||||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
@@ -101,38 +102,38 @@ class TokenManager:
|
||||
def verify_token(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 是否有效
|
||||
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
|
||||
Returns:
|
||||
bool: token 是否有效
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
|
||||
current_token = self.get_token()
|
||||
if not current_token:
|
||||
logger.error("系统中没有有效的 token")
|
||||
return False
|
||||
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
is_valid = secrets.compare_digest(token, current_token)
|
||||
|
||||
|
||||
if is_valid:
|
||||
logger.debug("Token 验证成功")
|
||||
else:
|
||||
logger.warning("Token 验证失败")
|
||||
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
|
||||
Args:
|
||||
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否更新成功, 错误消息)
|
||||
"""
|
||||
@@ -141,17 +142,17 @@ class TokenManager:
|
||||
if not is_valid:
|
||||
logger.error(f"Token 格式无效: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
|
||||
try:
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8]
|
||||
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
|
||||
return True, "Token 更新成功"
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Token 失败: {e}")
|
||||
@@ -160,7 +161,7 @@ class TokenManager:
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token
|
||||
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
@@ -170,20 +171,20 @@ class TokenManager:
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token)
|
||||
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 格式是否正确
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
|
||||
# 必须是 64 位十六进制字符串
|
||||
if len(token) != 64:
|
||||
return False
|
||||
|
||||
|
||||
# 验证是否为有效的十六进制字符串
|
||||
try:
|
||||
int(token, 16)
|
||||
@@ -194,48 +195,48 @@ class TokenManager:
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
|
||||
要求:
|
||||
- 最少 10 位
|
||||
- 包含大写字母
|
||||
- 包含小写字母
|
||||
- 包含特殊符号
|
||||
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否有效, 错误消息)
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False, "Token 不能为空"
|
||||
|
||||
|
||||
# 检查长度
|
||||
if len(token) < 10:
|
||||
return False, "Token 长度至少为 10 位"
|
||||
|
||||
|
||||
# 检查是否包含大写字母
|
||||
has_upper = any(c.isupper() for c in token)
|
||||
if not has_upper:
|
||||
return False, "Token 必须包含大写字母"
|
||||
|
||||
|
||||
# 检查是否包含小写字母
|
||||
has_lower = any(c.islower() for c in token)
|
||||
if not has_lower:
|
||||
return False, "Token 必须包含小写字母"
|
||||
|
||||
|
||||
# 检查是否包含特殊符号
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
|
||||
has_special = any(c in special_chars for c in token)
|
||||
if not has_special:
|
||||
return False, f"Token 必须包含特殊符号 ({special_chars})"
|
||||
|
||||
|
||||
return True, "Token 格式正确"
|
||||
|
||||
def is_first_setup(self) -> bool:
|
||||
"""
|
||||
检查是否为首次配置
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否为首次配置
|
||||
"""
|
||||
@@ -245,7 +246,7 @@ class TokenManager:
|
||||
def mark_setup_completed(self) -> bool:
|
||||
"""
|
||||
标记首次配置已完成
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否标记成功
|
||||
"""
|
||||
@@ -263,7 +264,7 @@ class TokenManager:
|
||||
def reset_setup_status(self) -> bool:
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否重置成功
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user