Files
mai-bot/src/webui/routers/expression.py
DrSmoothl 7255cc5602 fix(webui): remove references to deleted Expression fields
- 移除 ExpressionUpdateRequest 中的 checked/rejected/require_unchecked 字段
- 移除更新逻辑中的 setattr
- 添加 chat_id → session_id 映射
2026-02-17 19:58:29 +08:00

822 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime, timedelta
from sqlalchemy import case, func
from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.chat.message_receive.chat_stream import get_chat_manager
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.expression")

# Router for every expression-management endpoint below: mounted under
# /expression and grouped under the "Expression" tag in the OpenAPI docs.
router = APIRouter(tags=["Expression"], prefix="/expression")
class ExpressionResponse(BaseModel):
    """Expression payload returned by the WebUI expression endpoints."""

    id: int
    situation: str
    style: str
    # POSIX timestamp in seconds (expression_to_response sends 0.0 when unset)
    last_active_time: float
    # expression_to_response fills this from Expression.session_id ("" when unset)
    chat_id: str
    # POSIX timestamp in seconds of Expression.create_time, or None
    create_date: Optional[float]
    # Legacy review flags: expression_to_response always sends False for both
    checked: bool
    rejected: bool
    modified_by: Optional[str] = None  # 'ai', 'user', or None (expression_to_response always sends None)


class ExpressionListResponse(BaseModel):
    """One page of expressions plus pagination metadata."""

    success: bool
    # Number of rows matching the filters, across all pages
    total: int
    page: int
    page_size: int
    data: List[ExpressionResponse]


class ExpressionDetailResponse(BaseModel):
    """Response wrapping a single expression."""

    success: bool
    data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
    """Request body for creating an expression."""

    situation: str
    style: str
    # create_expression stores this as Expression.session_id
    chat_id: str


class ExpressionUpdateRequest(BaseModel):
    """Request body for a partial (PATCH) update.

    update_expression dumps this model with ``exclude_unset=True``, so only the
    fields explicitly present in the request body are applied.
    """

    situation: Optional[str] = None
    style: Optional[str] = None
    # Mapped to Expression.session_id before being written
    chat_id: Optional[str] = None


class ExpressionUpdateResponse(BaseModel):
    """Response for an update request."""

    success: bool
    message: str
    data: Optional[ExpressionResponse] = None


class ExpressionDeleteResponse(BaseModel):
    """Response shared by the single and batch delete endpoints."""

    success: bool
    message: str


class ExpressionCreateResponse(BaseModel):
    """Response for a create request."""

    success: bool
    message: str
    data: ExpressionResponse
def verify_auth_token(
    maibot_session: Optional[str] = None,
    authorization: Optional[str] = None,
) -> bool:
    """Check the WebUI credentials supplied via cookie and/or header.

    Delegates to ``verify_auth_token_from_cookie_or_header`` with the
    ``maibot_session`` cookie value and the ``Authorization`` header value, and
    returns whatever that helper returns.

    NOTE(review): every endpoint in this module discards the return value, so
    the helper is presumably expected to raise (e.g. HTTP 401) when
    authentication fails — confirm in ``src.webui.core``.
    """
    credentials = (maibot_session, authorization)
    return verify_auth_token_from_cookie_or_header(*credentials)
def expression_to_response(expression: Expression) -> ExpressionResponse:
    """Convert an ``Expression`` row into the API response model.

    Fallbacks:
    - a missing ``last_active_time`` is reported as ``0.0``;
    - a missing ``create_time`` is reported as ``None``;
    - a missing primary key is reported as ``0``;
    - a missing ``session_id`` is reported as ``""``.

    The review-related response fields are hard-coded: ``checked`` and
    ``rejected`` are always ``False``, and ``modified_by`` is always ``None``.
    """
    active_at = expression.last_active_time
    created_at = expression.create_time
    return ExpressionResponse(
        id=0 if expression.id is None else expression.id,
        situation=expression.situation,
        style=expression.style,
        last_active_time=active_at.timestamp() if active_at else 0.0,
        chat_id=expression.session_id or "",
        create_date=created_at.timestamp() if created_at else None,
        checked=False,
        rejected=False,
        modified_by=None,
    )
def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable name for ``chat_id``.

    The group name wins, then the user's nickname. The raw ``chat_id`` is
    returned when:
    - the stream is unknown;
    - the stream has neither name;
    - any exception is raised during the lookup.
    """
    fallback = chat_id
    try:
        stream = get_chat_manager().get_stream(chat_id)
        if stream:
            group = stream.group_info
            if group and group.group_name:
                return group.group_name
            user = stream.user_info
            if user and user.user_nickname:
                return user.user_nickname
    except Exception:
        return fallback
    return fallback
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
    """Resolve display names for several chats at once.

    Every id initially maps to itself. For each id whose stream is found, the
    group name replaces it; failing that, the user's nickname does.

    If the lookup raises part-way through, a warning is logged and the names
    resolved so far are still returned.
    """
    names = {chat_id: chat_id for chat_id in chat_ids}  # default: the raw id
    try:
        manager = get_chat_manager()
        for cid in chat_ids:
            stream = manager.get_stream(cid)
            if stream:
                if stream.group_info and stream.group_info.group_name:
                    names[cid] = stream.group_info.group_name
                elif stream.user_info and stream.user_info.user_nickname:
                    names[cid] = stream.user_info.user_nickname
    except Exception as e:
        logger.warning(f"批量获取聊天名称失败: {e}")
    return names
class ChatInfo(BaseModel):
    """One selectable chat in the WebUI chat picker."""

    chat_id: str
    chat_name: str
    platform: Optional[str] = None
    is_group: bool = False


class ChatListResponse(BaseModel):
    """Response carrying every known chat."""

    success: bool
    data: List[ChatInfo]


@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
    """List every chat stream known to the chat manager (for dropdown selection).

    Each entry is labelled with the group name, else the user's nickname, else
    the raw stream id. The list is sorted by that label.

    Args:
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ChatListResponse with the sorted chats.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        chats: List[ChatInfo] = []
        for stream_id, stream in get_chat_manager().streams.items():
            group = stream.group_info
            if group and group.group_name:
                label = group.group_name
            elif stream.user_info and stream.user_info.user_nickname:
                label = stream.user_info.user_nickname
            else:
                label = stream_id
            chats.append(
                ChatInfo(
                    chat_id=stream_id,
                    chat_name=label,
                    platform=stream.platform,
                    is_group=bool(group and group.group_id),
                )
            )

        chats.sort(key=lambda info: info.chat_name)
        return ChatListResponse(success=True, data=chats)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取聊天列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Return one page of expressions, most recently active first.

    Args:
        page: 1-based page number.
        page_size: items per page (1-100).
        search: optional keyword matched as a literal substring of
            ``situation`` or ``style`` (``%``/``_`` are escaped, not wildcards).
        chat_id: optional chat filter, compared with ``Expression.session_id``.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ExpressionListResponse with the requested page and the total number of
        rows matching the filters.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        # Build the WHERE criteria once, so the page query and the count query
        # cannot drift apart.
        conditions = []
        if search:
            conditions.append(
                col(Expression.situation).contains(search, autoescape=True)
                | col(Expression.style).contains(search, autoescape=True)
            )
        if chat_id:
            conditions.append(col(Expression.session_id) == chat_id)

        page_statement = (
            select(Expression)
            .where(*conditions)
            # Most recently active first; rows without last_active_time go last.
            .order_by(
                case((col(Expression.last_active_time).is_(None), 1), else_=0),
                col(Expression.last_active_time).desc(),
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        # COUNT(*) in the database instead of loading every matching id into memory.
        count_statement = select(func.count()).select_from(Expression).where(*conditions)

        with get_db_session() as session:
            expressions = session.exec(page_statement).all()
            total = session.exec(count_statement).one()
            # Serialize while the instances are still attached to the open session.
            data = [expression_to_response(expr) for expr in expressions]

        return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取表达方式列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
    expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
    """Return a single expression looked up by primary key.

    Args:
        expression_id: primary key of the expression.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ExpressionDetailResponse wrapping the expression.

    Raises:
        HTTPException: 404 when the id is unknown, 500 on unexpected errors.
    """
    try:
        verify_auth_token(maibot_session, authorization)
        lookup = select(Expression).where(col(Expression.id) == expression_id).limit(1)
        with get_db_session() as session:
            row = session.exec(lookup).first()
            if not row:
                raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
            payload = expression_to_response(row)
        return ExpressionDetailResponse(success=True, data=payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取表达方式详情失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
    request: ExpressionCreateRequest,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Create a new expression.

    The new row has:
    - empty ``context``/``up_content``;
    - ``content_list`` set to ``"[]"`` and ``count`` set to 0;
    - both ``last_active_time`` and ``create_time`` set to the current local
      time.

    Args:
        request: situation, style and chat id of the new expression.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ExpressionCreateResponse carrying the stored expression, including its
        database-assigned ``id``.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        current_time = datetime.now()
        with get_db_session() as session:
            expression = Expression(
                situation=request.situation,
                style=request.style,
                context="",
                up_content="",
                content_list="[]",
                count=0,
                last_active_time=current_time,
                create_time=current_time,
                session_id=request.chat_id,
            )
            session.add(expression)
            # add() does not emit the INSERT, so the primary key is still None at
            # this point (expression_to_response would report it as 0). flush()
            # sends the INSERT inside the current transaction and fills in
            # expression.id. Serializing here also avoids reading the instance
            # after the session has committed/closed.
            session.flush()
            data = expression_to_response(expression)

        logger.info(f"表达方式已创建: ID={data.id}, situation={request.situation}")
        return ExpressionCreateResponse(success=True, message="表达方式创建成功", data=data)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"创建表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
    expression_id: int,
    request: ExpressionUpdateRequest,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Partially update an expression; only the fields sent by the client change.

    The API field ``chat_id`` is written to ``Expression.session_id``. Every
    successful update also refreshes ``last_active_time``.

    Lookup and write happen in a single session, so the row cannot disappear
    between the existence check and the update.

    Args:
        expression_id: primary key of the expression to update.
        request: fields to change; fields absent from the body are untouched.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ExpressionUpdateResponse with the updated expression.

    Raises:
        HTTPException: 404 if the expression does not exist, 400 if the body
            sets no field, 500 on any unexpected error.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        # exclude_unset: only the fields explicitly present in the body (PATCH semantics).
        update_data = request.model_dump(exclude_unset=True)
        # Map the API field name onto the database column name.
        if "chat_id" in update_data:
            update_data["session_id"] = update_data.pop("chat_id")
        # Number of fields the client asked to change; last_active_time is added implicitly below.
        requested_count = len(update_data)

        with get_db_session() as session:
            statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
            expression = session.exec(statement).first()
            if not expression:
                raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
            if not update_data:
                raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")

            update_data["last_active_time"] = datetime.now()
            for field, value in update_data.items():
                # Ignore keys the model has no attribute for instead of failing.
                if hasattr(expression, field):
                    setattr(expression, field, value)
            session.add(expression)
            # Serialize while the instance is still attached to the open session.
            data = expression_to_response(expression)

        logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
        return ExpressionUpdateResponse(success=True, message=f"成功更新 {requested_count} 个字段", data=data)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"更新表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
    expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
    """Delete a single expression by primary key.

    The existence check and the DELETE run in the same session/transaction, so
    the row cannot change between the two.

    Args:
        expression_id: primary key of the expression to delete.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ExpressionDeleteResponse whose message names the deleted situation.

    Raises:
        HTTPException: 404 when the id is unknown, 500 on unexpected errors.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        with get_db_session() as session:
            statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
            expression = session.exec(statement).first()
            if not expression:
                raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
            # Keep the situation for the log/response message before the row is gone.
            situation = expression.situation
            session.exec(delete(Expression).where(col(Expression.id) == expression_id))

        logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
        return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"删除表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
    """Request body for deleting several expressions at once."""

    ids: List[int]


@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(
    request: BatchDeleteRequest,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Delete every expression whose id is listed in the request.

    Ids that do not exist are logged as a warning but are not an error. The
    lookup and the DELETE share one session/transaction.

    Args:
        request: the ids to delete.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ExpressionDeleteResponse reporting how many rows were deleted.

    Raises:
        HTTPException: 400 when ``ids`` is empty, 500 on unexpected errors.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        if not request.ids:
            raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")

        deleted_count = 0
        with get_db_session() as session:
            existing = select(Expression.id).where(col(Expression.id).in_(request.ids))
            found_ids = list(session.exec(existing).all())

            not_found_ids = set(request.ids) - set(found_ids)
            if not_found_ids:
                logger.warning(f"部分表达方式未找到: {not_found_ids}")

            # Nothing matched: skip the DELETE round-trip entirely.
            if found_ids:
                result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids)))
                deleted_count = result.rowcount or 0

        logger.info(f"批量删除了 {deleted_count} 个表达方式")
        return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"批量删除表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(
    maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
    """Summarize the expression table.

    All counting is done in the database (COUNT / GROUP BY) instead of loading
    whole columns into Python.

    Args:
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ``{"success": True, "data": {...}}`` where ``data`` contains:
        - ``total``: number of expressions;
        - ``recent_7days``: expressions created in the last 7 days;
        - ``chat_count``: distinct non-empty session ids;
        - ``top_chats``: up to 10 ``session_id -> count`` entries, largest first.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        seven_days_ago = datetime.now() - timedelta(days=7)
        with get_db_session() as session:
            total = session.exec(select(func.count()).select_from(Expression)).one()

            # Per-chat counts; rows without a chat (NULL or empty session_id) are skipped.
            per_chat_statement = (
                select(Expression.session_id, func.count())
                .where(col(Expression.session_id).is_not(None), col(Expression.session_id) != "")
                .group_by(col(Expression.session_id))
            )
            chat_stats = {chat: count for chat, count in session.exec(per_chat_statement).all()}

            recent_statement = (
                select(func.count())
                .select_from(Expression)
                .where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago)
            )
            recent = session.exec(recent_statement).one()

        top_chats = dict(sorted(chat_stats.items(), key=lambda item: item[1], reverse=True)[:10])
        return {
            "success": True,
            "data": {
                "total": total,
                "recent_7days": recent,
                "chat_count": len(chat_stats),
                "top_chats": top_chats,
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取统计数据失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
# ============ Review endpoints ============
class ReviewStatsResponse(BaseModel):
    """Review statistics."""

    total: int
    unchecked: int
    passed: int
    rejected: int
    ai_checked: int
    user_checked: int


@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
    """Return review statistics.

    Only ``total`` is computed, with a database COUNT(*). No review state is
    read from the Expression model here, so every review counter is reported
    as 0.

    Args:
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ReviewStatsResponse.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        with get_db_session() as session:
            total = session.exec(select(func.count()).select_from(Expression)).one()

        return ReviewStatsResponse(
            total=total,
            unchecked=0,
            passed=0,
            rejected=0,
            ai_checked=0,
            user_checked=0,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取审核统计失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
class ReviewListResponse(BaseModel):
    """One page of expressions for the review view."""

    success: bool
    total: int
    page: int
    page_size: int
    data: List[ExpressionResponse]


@router.get("/review/list", response_model=ReviewListResponse)
async def get_review_list(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """List expressions for review, newest first.

    Expression no longer carries review state, so the ``unchecked``,
    ``passed`` and ``rejected`` filters can never match anything. They return
    an empty page without querying. Any other value (including ``all``)
    applies no review filter.

    Args:
        page: 1-based page number.
        page_size: items per page (1-100).
        filter_type: unchecked / passed / rejected / all.
        search: optional keyword matched as a literal substring of
            ``situation`` or ``style`` (``%``/``_`` are escaped, not wildcards).
        chat_id: optional chat filter, compared with ``Expression.session_id``.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        ReviewListResponse with the requested page and the total match count.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        if filter_type in {"unchecked", "passed", "rejected"}:
            return ReviewListResponse(success=True, total=0, page=page, page_size=page_size, data=[])

        # Shared WHERE criteria for the page query and the count query.
        conditions = []
        if search:
            conditions.append(
                col(Expression.situation).contains(search, autoescape=True)
                | col(Expression.style).contains(search, autoescape=True)
            )
        if chat_id:
            conditions.append(col(Expression.session_id) == chat_id)

        page_statement = (
            select(Expression)
            .where(*conditions)
            # Newest first; rows without create_time go last.
            .order_by(
                case((col(Expression.create_time).is_(None), 1), else_=0),
                col(Expression.create_time).desc(),
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        count_statement = select(func.count()).select_from(Expression).where(*conditions)

        with get_db_session() as session:
            expressions = session.exec(page_statement).all()
            total = session.exec(count_statement).one()
            # Serialize while the instances are still attached to the open session.
            data = [expression_to_response(expr) for expr in expressions]

        return ReviewListResponse(
            success=True,
            total=total,
            page=page,
            page_size=page_size,
            data=data,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取审核列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
class BatchReviewItem(BaseModel):
    """One entry of a batch review request."""

    id: int
    # Only selects the result message ("通过" vs "拒绝"); batch_review_expressions persists no review state
    rejected: bool
    # Defaults to requiring the "unchecked" state; batch_review_expressions always fails such items now
    require_unchecked: bool = True


class BatchReviewRequest(BaseModel):
    """Request body for a batch review."""

    items: List[BatchReviewItem]


class BatchReviewResultItem(BaseModel):
    """Outcome of reviewing one expression."""

    id: int
    success: bool
    message: str


class BatchReviewResponse(BaseModel):
    """Aggregated outcome of a batch review."""

    success: bool
    total: int
    succeeded: int
    failed: int
    results: List[BatchReviewResultItem]


def _review_single_expression(item: BatchReviewItem) -> BatchReviewResultItem:
    """Apply one review item in its own session and describe the outcome.

    Lookup and update share a single session. The success result is built only
    after the session block exits, so a failing commit propagates to the caller
    instead of being reported as success.
    """
    with get_db_session() as session:
        statement = select(Expression).where(col(Expression.id) == item.id).limit(1)
        expression = session.exec(statement).first()
        if not expression:
            return BatchReviewResultItem(id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式")
        # Conflict detection would need a stored review state, which Expression
        # does not have, so items that require it are rejected.
        if item.require_unchecked:
            return BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
        expression.last_active_time = datetime.now()
        session.add(expression)
    return BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")


@router.post("/review/batch", response_model=BatchReviewResponse)
async def batch_review_expressions(
    request: BatchReviewRequest,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
):
    """Review several expressions, each independently (best effort).

    A failure on one item is recorded in its result and does not abort the rest
    of the batch. A successful review only refreshes ``last_active_time``.

    Args:
        request: the items to review.
        maibot_session: session cookie used for authentication.
        authorization: ``Authorization`` header used for authentication.

    Returns:
        BatchReviewResponse with per-item results and success/failure counts.

    Raises:
        HTTPException: 400 when ``items`` is empty, 500 on unexpected errors.
    """
    try:
        verify_auth_token(maibot_session, authorization)

        if not request.items:
            raise HTTPException(status_code=400, detail="未提供要审核的表达方式")

        results: List[BatchReviewResultItem] = []
        for item in request.items:
            try:
                results.append(_review_single_expression(item))
            except Exception as e:
                # Deliberate best effort: record the error and keep going.
                results.append(BatchReviewResultItem(id=item.id, success=False, message=str(e)))

        succeeded = sum(1 for result in results if result.success)
        failed = len(results) - succeeded

        logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
        return BatchReviewResponse(
            success=True, total=len(request.items), succeeded=succeeded, failed=failed, results=results
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"批量审核失败: {e}")
        raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e