feat: Enhance authentication mechanism to support token retrieval from both Cookie and Header
- Added a new auth module to manage authentication-related functions. - Updated existing routes in expression_routes, person_routes, plugin_routes, and routes to utilize the new authentication methods. - Implemented CORS middleware in webui_server for development environment support. - Introduced functions to set and clear authentication cookies. - Enhanced token verification to prioritize Cookie over Header for improved security and flexibility.
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
@@ -87,18 +88,12 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
@@ -162,7 +157,7 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
@@ -173,7 +168,7 @@ async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
@@ -205,6 +200,7 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -221,7 +217,7 @@ async def get_expression_list(
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
@@ -265,7 +261,7 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_detail(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
@@ -277,7 +273,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -294,7 +290,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def create_expression(request: ExpressionCreateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
@@ -306,7 +302,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
@@ -336,7 +332,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||
expression_id: int, request: ExpressionUpdateRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
@@ -350,7 +346,7 @@ async def update_expression(
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -386,7 +382,7 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def delete_expression(expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
@@ -398,7 +394,7 @@ async def delete_expression(expression_id: int, authorization: Optional[str] = H
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -429,7 +425,7 @@ class BatchDeleteRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
@@ -441,7 +437,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
@@ -470,7 +466,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
@@ -481,7 +477,7 @@ async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user