feat: Enhance authentication mechanism to support token retrieval from both Cookie and Header

- Added a new auth module to manage authentication-related functions.
- Updated existing routes in expression_routes, person_routes, plugin_routes, and routes to utilize the new authentication methods.
- Implemented CORS middleware in webui_server for development environment support.
- Introduced functions to set and clear authentication cookies.
- Enhanced token verification to prioritize Cookie over Header for improved security and flexibility.
This commit is contained in:
墨梓柒
2025-11-30 15:53:39 +08:00
parent fdc0a87c31
commit c790dcb705
7 changed files with 429 additions and 148 deletions

View File

@@ -1,11 +1,12 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from .token_manager import get_token_manager
from .auth import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
@@ -87,18 +88,12 @@ class ExpressionCreateResponse(BaseModel):
data: ExpressionResponse
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def expression_to_response(expression: Expression) -> ExpressionResponse:
@@ -162,7 +157,7 @@ class ChatListResponse(BaseModel):
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(authorization: Optional[str] = Header(None)):
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取所有聊天列表(用于下拉选择)
@@ -173,7 +168,7 @@ async def get_chat_list(authorization: Optional[str] = Header(None)):
聊天列表
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
chat_list = []
for cs in ChatStreams.select():
@@ -205,6 +200,7 @@ async def get_expression_list(
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
@@ -221,7 +217,7 @@ async def get_expression_list(
表达方式列表
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Expression.select()
@@ -265,7 +261,7 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取表达方式详细信息
@@ -277,7 +273,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
表达方式详细信息
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
@@ -294,7 +290,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
创建新的表达方式
@@ -306,7 +302,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
创建结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
current_time = time.time()
@@ -336,7 +332,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
增量更新表达方式(只更新提供的字段)
@@ -350,7 +346,7 @@ async def update_expression(
更新结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
@@ -386,7 +382,7 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
删除表达方式
@@ -398,7 +394,7 @@ async def delete_expression(expression_id: int, authorization: Optional[str] = H
删除结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
@@ -429,7 +425,7 @@ class BatchDeleteRequest(BaseModel):
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
批量删除表达方式
@@ -441,7 +437,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
删除结果
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
@@ -470,7 +466,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
@router.get("/stats/summary")
async def get_expression_stats(authorization: Optional[str] = Header(None)):
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取表达方式统计数据
@@ -481,7 +477,7 @@ async def get_expression_stats(authorization: Optional[str] = Header(None)):
统计数据
"""
try:
verify_auth_token(authorization)
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()