Files
mai-bot/src/webui/routers/expression.py

843 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""表达方式管理 API 路由"""
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import case, func
from sqlmodel import col, delete, select
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
# Module-level logger for the expression WebUI routes.
logger = get_logger("webui.expression")
# Shared Query default for the `exclude_ids` parameter of the review list endpoint.
EXCLUDE_IDS_QUERY = Query(None, description="需要排除的表达方式 ID")
# Router for every /expression endpoint; all routes depend on require_auth.
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
class ExpressionResponse(BaseModel):
    """Serialized view of a single expression record."""

    id: int
    situation: str
    style: str
    # Timestamp as a float; expression_to_response fills in 0.0 when the record has no value.
    last_active_time: float
    chat_id: str
    # Human-readable chat name, when one could be resolved.
    chat_name: Optional[str] = None
    # Creation timestamp as a float; has no default, so it must be passed explicitly (may be None).
    create_date: Optional[float]
    checked: bool
    rejected: bool
    modified_by: Optional[str] = None  # 'ai', 'user', or None
class ExpressionListResponse(BaseModel):
    """Paginated expression list response."""

    success: bool
    # Number of expressions matching the filters (not just the current page).
    total: int
    page: int
    page_size: int
    # Expressions on the requested page.
    data: List[ExpressionResponse]
class ExpressionDetailResponse(BaseModel):
    """Response carrying the details of one expression."""

    success: bool
    data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
    """Request body for creating an expression."""

    situation: str
    style: str
    # Stored as the expression's session_id by create_expression.
    chat_id: str
class ExpressionUpdateRequest(BaseModel):
    """Request body for partially updating an expression.

    All fields are optional; update_expression only applies the fields that
    were actually present in the request.
    """

    situation: Optional[str] = None
    style: Optional[str] = None
    # Mapped to the database column session_id by update_expression.
    chat_id: Optional[str] = None
class ExpressionUpdateResponse(BaseModel):
    """Response returned after updating an expression."""

    success: bool
    message: str
    # The expression as it looks after the update.
    data: Optional[ExpressionResponse] = None
class ExpressionDeleteResponse(BaseModel):
    """Response returned after deleting one or more expressions."""

    success: bool
    message: str
class ExpressionCreateResponse(BaseModel):
    """Response returned after creating an expression."""

    success: bool
    message: str
    # The newly created expression.
    data: ExpressionResponse
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
    """Derive a chat display name from the most recent message of a session.

    Args:
        chat_id: Session ID whose messages are searched.
        db_session: Open database session used to run the query.

    Returns:
        Optional[str]: For a group message, the group name or a "群聊<group_id>"
        placeholder. Otherwise the user's card name, nickname, or a
        "用户<user_id>" placeholder. None when there is no message or nothing
        usable on it.
    """
    newest_first = col(Messages.timestamp).desc()
    query = select(Messages).where(col(Messages.session_id) == chat_id).order_by(newest_first).limit(1)
    latest = db_session.exec(query).first()
    if not latest:
        return None

    # Group chats: prefer the recorded group name, else a placeholder built from the ID.
    if latest.group_id:
        if latest.group_name:
            return latest.group_name
        return f"群聊{latest.group_id}"

    # Private chats: card name, then nickname, then a placeholder built from the user ID.
    if latest.user_cardname:
        return latest.user_cardname
    if latest.user_nickname:
        return latest.user_nickname
    if latest.user_id:
        return f"用户{latest.user_id}"
    return None
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
    """Build a fallback display name from a stored chat session record.

    Args:
        chat_session: Session record providing group_id, user_id and session_id.

    Returns:
        str: "群聊<group_id>" when a group ID is set, otherwise "用户<user_id>"
        when a user ID is set, otherwise the raw session ID.
    """
    group_id = chat_session.group_id
    if group_id:
        return f"群聊{group_id}"

    user_id = chat_session.user_id
    if user_id:
        return f"用户{user_id}"

    return chat_session.session_id
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
    """Resolve a display name for a chat ID.

    Sources are tried in order: the chat manager's session name, the latest
    stored message (only when a database session is supplied), and finally a
    placeholder built from the chat manager's session object.

    Args:
        chat_id: Chat session ID.
        db_session: Optional database session used to look up the latest message.

    Returns:
        str: The display name, or the original chat ID when nothing could be
        resolved or any lookup raised.
    """
    try:
        manager_name = _chat_manager.get_session_name(chat_id)
        if manager_name:
            return manager_name

        if db_session:
            message_name = get_chat_name_from_latest_message(chat_id, db_session)
            if message_name:
                return message_name

        live_session = _chat_manager.get_session_by_session_id(chat_id)
        if not live_session:
            return chat_id
        if live_session.group_id:
            return f"群聊{live_session.group_id}"
        if live_session.user_id:
            return f"用户{live_session.user_id}"
        return chat_id
    except Exception:
        # Best effort: any failure falls back to the raw ID.
        return chat_id
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
    """Convert an expression record into its API response model.

    Args:
        expression: Expression record loaded from the database.
        db_session: Optional database session forwarded to get_chat_name.

    Returns:
        ExpressionResponse: Response object that the WebUI can serialize directly.
    """
    session_key = expression.session_id or ""
    # Only resolve a name when the record is actually bound to a chat.
    display_name = get_chat_name(session_key, db_session) if session_key else None

    last_active = expression.last_active_time
    created = expression.create_time
    editor = expression.modified_by

    return ExpressionResponse(
        id=0 if expression.id is None else expression.id,
        situation=expression.situation,
        style=expression.style,
        last_active_time=last_active.timestamp() if last_active else 0.0,
        chat_id=session_key,
        chat_name=display_name,
        create_date=created.timestamp() if created else None,
        checked=expression.checked,
        rejected=expression.rejected,
        modified_by=editor.value if editor else None,
    )
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
    """Resolve display names for several chat IDs.

    Args:
        chat_ids: Chat session IDs to look up.

    Returns:
        Dict[str, str]: Mapping from chat ID to display name. IDs that were not
        resolved (including everything after an unexpected failure) map to themselves.
    """
    # Start from the identity mapping so every ID has a value whatever happens below.
    names: Dict[str, str] = {key: key for key in chat_ids}
    try:
        for key in chat_ids:
            names[key] = get_chat_name(key)
    except Exception as e:
        logger.warning(f"批量获取聊天名称失败: {e}")
    return names
class ChatInfo(BaseModel):
    """Basic information about one chat."""

    chat_id: str
    chat_name: str
    # Platform identifier; None when it is not known.
    platform: Optional[str] = None
    # True for group chats; defaults to False.
    is_group: bool = False
class ChatListResponse(BaseModel):
    """Response carrying the list of chats."""

    success: bool
    data: List[ChatInfo]
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list() -> ChatListResponse:
    """List every chat known to the chat manager, the session table, or the expression table.

    Returns:
        ChatListResponse: Chats sorted by display name, usable for a drop-down selector.
    """
    try:
        known_chats: Dict[str, ChatInfo] = {}

        # 1) Sessions currently held by the chat manager.
        for live_id, live_session in _chat_manager.sessions.items():
            known_chats[live_id] = ChatInfo(
                chat_id=live_id,
                chat_name=_chat_manager.get_session_name(live_id) or live_id,
                platform=live_session.platform,
                is_group=live_session.is_group_session,
            )

        with get_db_session() as db:
            # 2) Stored chat sessions that the chat manager did not report.
            for record in db.exec(select(ChatSession)).all():
                record_id = record.session_id
                if record_id not in known_chats:
                    latest_name = get_chat_name_from_latest_message(record_id, db)
                    known_chats[record_id] = ChatInfo(
                        chat_id=record_id,
                        chat_name=latest_name or get_chat_name_from_session_record(record),
                        platform=record.platform,
                        is_group=bool(record.group_id),
                    )

            # 3) Chat IDs referenced by expressions but missing from both sources above.
            referenced_ids = {sid for sid in db.exec(select(Expression.session_id)).all() if sid}
            for orphan_id in referenced_ids:
                if orphan_id not in known_chats:
                    known_chats[orphan_id] = ChatInfo(
                        chat_id=orphan_id,
                        chat_name=get_chat_name(orphan_id, db),
                        platform=None,
                        is_group=False,
                    )

        # Sort by display name.
        ordered = sorted(known_chats.values(), key=lambda info: info.chat_name)
        return ChatListResponse(success=True, data=ordered)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取聊天列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
) -> ExpressionListResponse:
    """Return a page of expressions.

    Args:
        page: Page number, starting at 1.
        page_size: Items per page, between 1 and 100.
        search: Keyword matched against both situation and style.
        chat_id: Restrict results to this chat ID.

    Returns:
        ExpressionListResponse: The requested page plus the total number of matches.

    Raises:
        HTTPException: 500 when the query fails.
    """

    def with_filters(statement: Any) -> Any:
        """Apply the search/chat filters shared by the page query and the count query."""
        if search:
            statement = statement.where(
                col(Expression.situation).contains(search) | col(Expression.style).contains(search)
            )
        if chat_id:
            statement = statement.where(col(Expression.session_id) == chat_id)
        return statement

    try:
        # Most recently active first; rows without last_active_time go last.
        page_statement = (
            with_filters(select(Expression))
            .order_by(
                case((col(Expression.last_active_time).is_(None), 1), else_=0),
                col(Expression.last_active_time).desc(),
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        # Count in the database instead of loading every matching ID into memory.
        count_statement = with_filters(select(func.count()).select_from(Expression))

        with get_db_session() as session:
            expressions = session.exec(page_statement).all()
            total = session.exec(count_statement).one()
            data = [expression_to_response(expr, session) for expr in expressions]

        return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取表达方式列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
    """Return the details of one expression.

    Args:
        expression_id: ID of the expression to fetch.

    Returns:
        ExpressionDetailResponse: Details of the requested expression.

    Raises:
        HTTPException: 404 when the expression does not exist, 500 on other failures.
    """
    try:
        lookup = select(Expression).where(col(Expression.id) == expression_id).limit(1)
        with get_db_session() as session:
            record = session.exec(lookup).first()
            if not record:
                raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
            payload = expression_to_response(record, session)
        return ExpressionDetailResponse(success=True, data=payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取表达方式详情失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
    request: ExpressionCreateRequest,
) -> ExpressionCreateResponse:
    """Create a new expression.

    Args:
        request: Situation, style and chat ID of the expression to create.

    Returns:
        ExpressionCreateResponse: Creation result together with the new expression.

    Raises:
        HTTPException: 500 when the expression cannot be created.
    """
    try:
        now = datetime.now()
        with get_db_session() as session:
            record = Expression(
                situation=request.situation,
                style=request.style,
                content_list="[]",
                count=0,
                last_active_time=now,
                create_time=now,
                session_id=request.chat_id,
            )
            session.add(record)
            # Flush before reading the ID and building the response.
            session.flush()
            new_id = record.id
            payload = expression_to_response(record, session)

        logger.info(f"表达方式已创建: ID={new_id}, situation={request.situation}")
        return ExpressionCreateResponse(success=True, message="表达方式创建成功", data=payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"创建表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
    expression_id: int,
    request: ExpressionUpdateRequest,
) -> ExpressionUpdateResponse:
    """Partially update an expression.

    Only fields present in the request body are changed; last_active_time is
    always refreshed.

    Args:
        expression_id: ID of the expression to update.
        request: Fields to change.

    Returns:
        ExpressionUpdateResponse: Update result together with the updated expression.

    Raises:
        HTTPException: 400 when no field is supplied or when situation/style is
            explicitly null, 404 when the expression does not exist, 500 on
            other failures.
    """
    try:
        # Only the fields the client actually sent.
        update_data = request.model_dump(exclude_unset=True)
        if not update_data:
            raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")

        # exclude_unset keeps fields that were explicitly sent as null. The response
        # model requires situation/style to be strings, so writing None would only
        # surface later as a 500; reject it up front instead.
        null_fields = [name for name in ("situation", "style") if name in update_data and update_data[name] is None]
        if null_fields:
            raise HTTPException(status_code=400, detail=f"字段不能为 null: {', '.join(null_fields)}")

        # Map the API field name to the database column name.
        if "chat_id" in update_data:
            update_data["session_id"] = update_data.pop("chat_id")

        # Every update refreshes the last-active timestamp.
        update_data["last_active_time"] = datetime.now()

        with get_db_session() as session:
            db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first()
            if not db_expression:
                raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")

            # Keys are limited to situation / style / session_id / last_active_time.
            for field_name, value in update_data.items():
                setattr(db_expression, field_name, value)
            session.add(db_expression)
            data = expression_to_response(db_expression, session)

        logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
        return ExpressionUpdateResponse(success=True, message=f"成功更新 {len(update_data)} 个字段", data=data)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"更新表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(expression_id: int) -> ExpressionDeleteResponse:
    """Delete one expression.

    Args:
        expression_id: ID of the expression to delete.

    Returns:
        ExpressionDeleteResponse: Deletion result.

    Raises:
        HTTPException: 404 when the expression does not exist, 500 on other failures.
    """
    try:
        by_id = col(Expression.id) == expression_id
        with get_db_session() as session:
            target = session.exec(select(Expression).where(by_id).limit(1)).first()
            if not target:
                raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")

            # Keep the situation text for the log line and the response message.
            removed_situation = target.situation
            session.exec(delete(Expression).where(by_id))

        logger.info(f"表达方式已删除: ID={expression_id}, situation={removed_situation}")
        return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {removed_situation}")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"删除表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
    """Request body for deleting several expressions at once."""

    # IDs of the expressions to delete.
    ids: List[int]
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(
    request: BatchDeleteRequest,
) -> ExpressionDeleteResponse:
    """Delete several expressions at once.

    IDs that do not exist are logged as a warning and skipped.

    Args:
        request: Request holding the IDs of the expressions to delete.

    Returns:
        ExpressionDeleteResponse: Result stating how many expressions were deleted.

    Raises:
        HTTPException: 400 when no ID is supplied, 500 on other failures.
    """
    try:
        if not request.ids:
            raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")

        # Look up and delete inside one session so both steps use the same session.
        with get_db_session() as session:
            lookup = select(Expression.id).where(col(Expression.id).in_(request.ids))
            found_ids = list(session.exec(lookup).all())

            # Report IDs that were requested but do not exist.
            if not_found_ids := set(request.ids) - set(found_ids):
                logger.warning(f"部分表达方式未找到: {not_found_ids}")

            deleted_count = 0
            # Skip the DELETE entirely when nothing matched.
            if found_ids:
                result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids)))
                deleted_count = result.rowcount or 0

        logger.info(f"批量删除了 {deleted_count} 个表达方式")
        return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"批量删除表达方式失败: {e}")
        raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats() -> Dict[str, Any]:
    """Return summary statistics about expressions.

    Returns:
        Dict[str, Any]: Under "data": the total number of expressions, the number
        created in the last 7 days, the number of distinct chats, and the ten
        chats with the most expressions.

    Raises:
        HTTPException: 500 when the statistics cannot be computed.
    """
    try:
        with get_db_session() as session:
            # Count in the database instead of loading every row.
            total = session.exec(select(func.count()).select_from(Expression)).one()

            per_chat_statement = (
                select(Expression.session_id, func.count())
                .where(col(Expression.session_id).is_not(None))
                .group_by(col(Expression.session_id))
            )
            # Empty-string session IDs are ignored, like NULL ones.
            chat_stats = {
                chat_id: count for chat_id, count in session.exec(per_chat_statement).all() if chat_id
            }

            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_statement = (
                select(func.count())
                .select_from(Expression)
                .where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago)
            )
            recent = session.exec(recent_statement).one()

        # Highest count first; ties are broken by chat ID so the output is deterministic.
        top_chats = dict(sorted(chat_stats.items(), key=lambda item: (-item[1], item[0]))[:10])
        return {
            "success": True,
            "data": {
                "total": total,
                "recent_7days": recent,
                "chat_count": len(chat_stats),
                "top_chats": top_chats,
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取统计数据失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
# ============ Review endpoints ============
class ReviewStatsResponse(BaseModel):
    """Review statistics response."""

    # Number of expressions overall.
    total: int
    # Expressions with checked == False.
    unchecked: int
    # Expressions that are checked and not rejected.
    passed: int
    # Expressions that are checked and rejected.
    rejected: int
    # Checked expressions whose modified_by is ModifiedBy.AI.
    ai_checked: int
    # Checked expressions whose modified_by is ModifiedBy.USER.
    user_checked: int
def apply_review_filter(statement: Any, filter_type: str) -> Any:
    """Restrict an expression query by review status.

    Args:
        statement: Select statement over Expression to narrow down.
        filter_type: "unchecked", "passed" or "rejected"; any other value
            (for example "all") leaves the statement untouched.

    Returns:
        Any: The statement with the matching WHERE conditions added.
    """
    if filter_type not in ("unchecked", "passed", "rejected"):
        return statement

    is_checked = col(Expression.checked)
    if filter_type == "unchecked":
        return statement.where(is_checked.is_(False))

    # "passed" and "rejected" are both checked; they differ only in the rejected flag.
    wants_rejected = filter_type == "rejected"
    return statement.where(is_checked.is_(True), col(Expression.rejected).is_(wants_rejected))
def count_expressions(session: Any, statement: Any) -> int:
    """Count the rows a select statement would return.

    The count is computed by the database (COUNT(*) over the statement used as
    a subquery) rather than by fetching every row and measuring the list.

    Args:
        session: Database session used to execute the query.
        statement: Select statement whose result rows are counted.

    Returns:
        int: Number of rows matched by the statement.
    """
    count_statement = select(func.count()).select_from(statement.subquery())
    return session.exec(count_statement).one()
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats() -> ReviewStatsResponse:
    """Return review statistics.

    Returns:
        ReviewStatsResponse: Counts by review status and by reviewer type.

    Raises:
        HTTPException: 500 when the statistics cannot be computed.
    """

    def by_status(filter_type: str) -> Any:
        """Build an ID query restricted to one review status."""
        return apply_review_filter(select(Expression.id), filter_type)

    def checked_by(reviewer: Any) -> Any:
        """Build an ID query for checked expressions last modified by the given reviewer."""
        return select(Expression.id).where(
            col(Expression.checked).is_(True),
            col(Expression.modified_by) == reviewer,
        )

    try:
        with get_db_session() as session:
            counts = {
                "total": count_expressions(session, select(Expression.id)),
                "unchecked": count_expressions(session, by_status("unchecked")),
                "passed": count_expressions(session, by_status("passed")),
                "rejected": count_expressions(session, by_status("rejected")),
                "ai_checked": count_expressions(session, checked_by(ModifiedBy.AI)),
                "user_checked": count_expressions(session, checked_by(ModifiedBy.USER)),
            }
        return ReviewStatsResponse(**counts)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取审核统计失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
class ReviewListResponse(BaseModel):
    """Paginated review list response."""

    success: bool
    # Number of expressions matching the filters (not just the current page).
    total: int
    page: int
    page_size: int
    # Expressions on the requested page.
    data: List[ExpressionResponse]
@router.get("/review/list", response_model=ReviewListResponse)
async def get_review_list(
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(20, ge=1, le=100, description="每页数量"),
    filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
    order: str = Query("latest", description="排序方式: latest/random"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
    exclude_ids: Optional[List[int]] = EXCLUDE_IDS_QUERY,
) -> ReviewListResponse:
    """Return a page of expressions for review.

    Args:
        page: Page number, starting at 1.
        page_size: Items per page, between 1 and 100.
        filter_type: One of unchecked, passed, rejected or all.
        order: "random" for random order; any other value sorts by creation
            time, newest first.
        search: Keyword matched against both situation and style.
        chat_id: Restrict results to this chat ID.
        exclude_ids: Expression IDs to leave out of the returned page.

    Returns:
        ReviewListResponse: The requested page plus the total number of matches.
        The total does not take exclude_ids into account.

    Raises:
        HTTPException: 500 when the query fails.
    """

    def with_filters(statement: Any) -> Any:
        """Apply the status/search/chat filters shared by the page query and the count query."""
        statement = apply_review_filter(statement, filter_type)
        if search:
            statement = statement.where(
                col(Expression.situation).contains(search) | col(Expression.style).contains(search)
            )
        if chat_id:
            statement = statement.where(col(Expression.session_id) == chat_id)
        return statement

    try:
        page_statement = with_filters(select(Expression))
        if exclude_ids:
            page_statement = page_statement.where(~col(Expression.id).in_(exclude_ids))

        if order == "random":
            page_statement = page_statement.order_by(func.random())
        else:
            # Newest first; rows without create_time go last.
            page_statement = page_statement.order_by(
                case((col(Expression.create_time).is_(None), 1), else_=0),
                col(Expression.create_time).desc(),
            )
        page_statement = page_statement.offset((page - 1) * page_size).limit(page_size)

        # Count in the database instead of loading every matching ID into memory.
        count_statement = with_filters(select(func.count()).select_from(Expression))

        with get_db_session() as session:
            expressions = session.exec(page_statement).all()
            total = session.exec(count_statement).one()
            data = [expression_to_response(expr, session) for expr in expressions]

        return ReviewListResponse(
            success=True,
            total=total,
            page=page,
            page_size=page_size,
            data=data,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"获取审核列表失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
class BatchReviewItem(BaseModel):
    """One entry of a batch review request."""

    # ID of the expression being reviewed.
    id: int
    # True to reject the expression, False to approve it.
    rejected: bool
    # Source marker kept for the frontend; manual review submissions no longer
    # block overwriting an already-reviewed expression.
    require_unchecked: bool = True
class BatchReviewRequest(BaseModel):
    """Request body for reviewing several expressions at once."""

    items: List[BatchReviewItem]
class BatchReviewResultItem(BaseModel):
    """Outcome of reviewing one expression in a batch."""

    id: int
    success: bool
    # Decision text on success, error description on failure.
    message: str
class BatchReviewResponse(BaseModel):
    """Response for a batch review request."""

    success: bool
    # Number of items in the request.
    total: int
    succeeded: int
    failed: int
    # One result entry per processed item.
    results: List[BatchReviewResultItem]
@router.post("/review/batch", response_model=BatchReviewResponse)
async def batch_review_expressions(
    request: BatchReviewRequest,
) -> BatchReviewResponse:
    """Review several expressions at once.

    Each item is processed in its own database session, so one failing item
    does not affect the others. Reviewed expressions are marked as checked,
    attributed to the user, and get a fresh last_active_time.

    Args:
        request: Batch review request.

    Returns:
        BatchReviewResponse: Per-expression results plus success/failure totals.

    Raises:
        HTTPException: 400 when the request has no items, 500 on unexpected failures.
    """
    try:
        if not request.items:
            raise HTTPException(status_code=400, detail="未提供要审核的表达方式")

        results: List[BatchReviewResultItem] = []
        succeeded = 0
        failed = 0

        for item in request.items:
            try:
                # One lookup per item; the record is updated in the same session.
                with get_db_session() as session:
                    db_expression = session.exec(
                        select(Expression).where(col(Expression.id) == item.id).limit(1)
                    ).first()
                    found = bool(db_expression)
                    if found:
                        db_expression.checked = True
                        db_expression.rejected = item.rejected
                        db_expression.modified_by = ModifiedBy.USER
                        db_expression.last_active_time = datetime.now()
                        session.add(db_expression)

                # Record the outcome only after the session context has exited, so a
                # failure while leaving it cannot yield both a success and a failure entry.
                if found:
                    results.append(
                        BatchReviewResultItem(id=item.id, success=True, message="拒绝" if item.rejected else "通过")
                    )
                    succeeded += 1
                else:
                    results.append(
                        BatchReviewResultItem(id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式")
                    )
                    failed += 1
            except Exception as e:
                # Keep processing the remaining items, but leave a trace of the failure.
                logger.warning(f"审核表达方式失败: ID={item.id}, error={e}")
                results.append(BatchReviewResultItem(id=item.id, success=False, message=str(e)))
                failed += 1

        logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
        return BatchReviewResponse(
            success=True, total=len(request.items), succeeded=succeeded, failed=failed, results=results
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"批量审核失败: {e}")
        raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e