feat: Enhance authentication mechanism to support token retrieval from both Cookie and Header

- Added a new auth module to manage authentication-related functions.
- Updated existing routes in expression_routes, person_routes, plugin_routes, and routes to utilize the new authentication methods.
- Implemented CORS middleware in webui_server for development environment support.
- Introduced functions to set and clear authentication cookies.
- Enhanced token verification to prioritize Cookie over Header for improved security and flexibility.
This commit is contained in:
墨梓柒
2025-11-30 15:53:39 +08:00
parent fdc0a87c31
commit c790dcb705
7 changed files with 429 additions and 148 deletions

View File

@@ -1,12 +1,13 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List, Annotated
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from .token_manager import get_token_manager
from .auth import verify_auth_token_from_cookie_or_header
import time
import os
import hashlib
@@ -101,18 +102,12 @@ class BatchDeleteResponse(BaseModel):
failed_ids: List[int] = []
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
@@ -144,6 +139,7 @@ async def get_emoji_list(
format: Optional[str] = Query(None, description="格式筛选"),
sort_by: Optional[str] = Query("usage_count", description="排序字段"),
sort_order: Optional[str] = Query("desc", description="排序方向"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
@@ -164,7 +160,7 @@ async def get_emoji_list(
表情包列表
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Emoji.select()
@@ -222,7 +218,7 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
async def get_emoji_detail(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取表情包详细信息
@@ -234,7 +230,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
表情包详细信息
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
@@ -251,7 +247,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
增量更新表情包(只更新提供的字段)
@@ -264,7 +260,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
更新结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
@@ -303,7 +299,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
async def delete_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
删除表情包
@@ -315,7 +311,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
删除结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
@@ -340,7 +336,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
@router.get("/stats/summary")
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取表情包统计数据
@@ -351,7 +347,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
统计数据
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
total = Emoji.select().count()
registered = Emoji.select().where(Emoji.is_registered).count()
@@ -395,7 +391,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
async def register_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
注册表情包(快捷操作)
@@ -407,7 +403,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
更新结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
@@ -435,7 +431,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
async def ban_emoji(emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
禁用表情包(快捷操作)
@@ -447,7 +443,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
更新结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
@@ -474,6 +470,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
async def get_emoji_thumbnail(
emoji_id: int,
token: Optional[str] = Query(None, description="访问令牌"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
@@ -481,21 +478,31 @@ async def get_emoji_thumbnail(
Args:
emoji_id: 表情包ID
token: 访问令牌(通过 query parameter
token: 访问令牌(通过 query parameter,用于向后兼容
maibot_session: Cookie 中的 token
authorization: Authorization header
Returns:
表情包图片文件
"""
try:
# 优先使用 query parameter 中的 token用于 img 标签)
if token:
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
else:
# 如果没有 query token则验证 Authorization header
verify_auth_token(authorization)
token_manager = get_token_manager()
is_valid = False
# 1. 优先使用 Cookie
if maibot_session and token_manager.verify_token(maibot_session):
is_valid = True
# 2. 其次使用 query parameter用于向后兼容 img 标签)
elif token and token_manager.verify_token(token):
is_valid = True
# 3. 最后使用 Authorization header
elif authorization and authorization.startswith("Bearer "):
auth_token = authorization.replace("Bearer ", "")
if token_manager.verify_token(auth_token):
is_valid = True
if not is_valid:
raise HTTPException(status_code=401, detail="Token 无效或已过期")
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
@@ -528,7 +535,7 @@ async def get_emoji_thumbnail(
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
async def batch_delete_emojis(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
批量删除表情包
@@ -540,7 +547,7 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
批量删除结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
if not request.emoji_ids:
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
@@ -601,6 +608,7 @@ async def upload_emoji(
description: DescriptionForm = "",
emotion: EmotionForm = "",
is_registered: IsRegisteredForm = True,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
@@ -617,7 +625,7 @@ async def upload_emoji(
上传结果和表情包信息
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
# 验证文件类型
if not file.content_type:
@@ -721,6 +729,7 @@ async def batch_upload_emoji(
files: EmojiFiles,
emotion: EmotionForm = "",
is_registered: IsRegisteredForm = True,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
@@ -736,7 +745,7 @@ async def batch_upload_emoji(
批量上传结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
results = {
"success": True,