feat: Enhance authentication mechanism to support token retrieval from both Cookie and Header
- Added a new auth module to manage authentication-related functions. - Updated existing routes in expression_routes, person_routes, plugin_routes, and routes to utilize the new authentication methods. - Implemented CORS middleware in webui_server for development environment support. - Introduced functions to set and clear authentication cookies. - Enhanced token verification to prioritize Cookie over Header for improved security and flexibility.
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Annotated
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Emoji
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
import os
|
||||
import hashlib
|
||||
@@ -101,18 +102,12 @@ class BatchDeleteResponse(BaseModel):
|
||||
failed_ids: List[int] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||
@@ -144,6 +139,7 @@ async def get_emoji_list(
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
sort_by: Optional[str] = Query("usage_count", description="排序字段"),
|
||||
sort_order: Optional[str] = Query("desc", description="排序方向"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -164,7 +160,7 @@ async def get_emoji_list(
|
||||
表情包列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Emoji.select()
|
||||
@@ -222,7 +218,7 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
@@ -234,7 +230,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
|
||||
表情包详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -251,7 +247,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
@@ -264,7 +260,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -303,7 +299,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
@@ -315,7 +311,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -340,7 +336,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包统计数据
|
||||
|
||||
@@ -351,7 +347,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Emoji.select().count()
|
||||
registered = Emoji.select().where(Emoji.is_registered).count()
|
||||
@@ -395,7 +391,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
@@ -407,7 +403,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -435,7 +431,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
@@ -447,7 +443,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -474,6 +470,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_thumbnail(
|
||||
emoji_id: int,
|
||||
token: Optional[str] = Query(None, description="访问令牌"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -481,21 +478,31 @@ async def get_emoji_thumbnail(
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
token: 访问令牌(通过 query parameter)
|
||||
token: 访问令牌(通过 query parameter,用于向后兼容)
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表情包图片文件
|
||||
"""
|
||||
try:
|
||||
# 优先使用 query parameter 中的 token(用于 img 标签)
|
||||
if token:
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
else:
|
||||
# 如果没有 query token,则验证 Authorization header
|
||||
verify_auth_token(authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_valid = False
|
||||
|
||||
# 1. 优先使用 Cookie
|
||||
if maibot_session and token_manager.verify_token(maibot_session):
|
||||
is_valid = True
|
||||
# 2. 其次使用 query parameter(用于向后兼容 img 标签)
|
||||
elif token and token_manager.verify_token(token):
|
||||
is_valid = True
|
||||
# 3. 最后使用 Authorization header
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
auth_token = authorization.replace("Bearer ", "")
|
||||
if token_manager.verify_token(auth_token):
|
||||
is_valid = True
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -528,7 +535,7 @@ async def get_emoji_thumbnail(
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除表情包
|
||||
|
||||
@@ -540,7 +547,7 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.emoji_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
|
||||
@@ -601,6 +608,7 @@ async def upload_emoji(
|
||||
description: DescriptionForm = "",
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -617,7 +625,7 @@ async def upload_emoji(
|
||||
上传结果和表情包信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 验证文件类型
|
||||
if not file.content_type:
|
||||
@@ -721,6 +729,7 @@ async def batch_upload_emoji(
|
||||
files: EmojiFiles,
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -736,7 +745,7 @@ async def batch_upload_emoji(
|
||||
批量上传结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
results = {
|
||||
"success": True,
|
||||
|
||||
Reference in New Issue
Block a user