重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject
This commit is contained in:
@@ -3,11 +3,16 @@
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from sqlalchemy import case
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy import case, func
|
||||
from sqlmodel import col, select, delete
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
|
||||
@@ -98,30 +103,32 @@ def verify_auth_token(
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""将 Expression 模型转换为响应对象"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||
return ExpressionResponse(
|
||||
id=expression.id,
|
||||
id=expression.id if expression.id is not None else 0,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=expression.last_active_time,
|
||||
chat_id=expression.chat_id,
|
||||
create_date=expression.create_date,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=expression.session_id or "",
|
||||
create_date=create_date,
|
||||
checked=False,
|
||||
rejected=False,
|
||||
modified_by=None,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream:
|
||||
# 优先使用群聊名称,否则使用用户昵称
|
||||
if chat_stream.group_name:
|
||||
return chat_stream.group_name
|
||||
elif chat_stream.user_nickname:
|
||||
return chat_stream.user_nickname
|
||||
return chat_id # 找不到时返回原始ID
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return chat_id
|
||||
if chat_stream.group_info and chat_stream.group_info.group_name:
|
||||
return chat_stream.group_info.group_name
|
||||
if chat_stream.user_info and chat_stream.user_info.user_nickname:
|
||||
return chat_stream.user_info.user_nickname
|
||||
return chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
@@ -130,12 +137,15 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称"""
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
|
||||
for cs in chat_streams:
|
||||
if cs.group_name:
|
||||
result[cs.stream_id] = cs.group_name
|
||||
elif cs.user_nickname:
|
||||
result[cs.stream_id] = cs.user_nickname
|
||||
chat_manager = get_chat_manager()
|
||||
for chat_id in chat_ids:
|
||||
chat_stream = chat_manager.get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
continue
|
||||
if chat_stream.group_info and chat_stream.group_info.group_name:
|
||||
result[chat_id] = chat_stream.group_info.group_name
|
||||
elif chat_stream.user_info and chat_stream.user_info.user_nickname:
|
||||
result[chat_id] = chat_stream.user_info.user_nickname
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
@@ -172,14 +182,17 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
|
||||
for stream_id, stream in get_chat_manager().streams.items():
|
||||
chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None
|
||||
if not chat_name and stream.user_info and stream.user_info.user_nickname:
|
||||
chat_name = stream.user_info.user_nickname
|
||||
chat_name = chat_name or stream_id
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=cs.stream_id,
|
||||
chat_id=stream_id,
|
||||
chat_name=chat_name,
|
||||
platform=cs.platform,
|
||||
is_group=bool(cs.group_id),
|
||||
platform=stream.platform,
|
||||
is_group=bool(stream.group_info and stream.group_info.group_id),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -221,29 +234,39 @@ async def get_expression_list(
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
statement = select(Expression)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
|
||||
statement = statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
statement = statement.where(col(Expression.session_id) == chat_id)
|
||||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
query = query.order_by(
|
||||
case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
|
||||
statement = statement.order_by(
|
||||
case((col(Expression.last_active_time).is_(None), 1), else_=0),
|
||||
col(Expression.last_active_time).desc(),
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
statement = statement.offset(offset).limit(page_size)
|
||||
|
||||
with get_db_session() as session:
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
count_statement = select(Expression.id)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
)
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
|
||||
# 转换为响应对象
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
@@ -272,7 +295,9 @@ async def get_expression_detail(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
@@ -305,16 +330,22 @@ async def create_expression(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = time.time()
|
||||
current_time = datetime.now()
|
||||
|
||||
# 创建表达方式
|
||||
expression = Expression.create(
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
chat_id=request.chat_id,
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
)
|
||||
with get_db_session() as session:
|
||||
expression = Expression(
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
context="",
|
||||
up_content="",
|
||||
content_list="[]",
|
||||
count=0,
|
||||
last_active_time=current_time,
|
||||
create_time=current_time,
|
||||
session_id=request.chat_id,
|
||||
)
|
||||
session.add(expression)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
|
||||
@@ -350,16 +381,18 @@ async def update_expression(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 冲突检测:如果要求未检查状态,但已经被检查了
|
||||
if request.require_unchecked and expression.checked:
|
||||
if request.require_unchecked and getattr(expression, "checked", False):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
|
||||
detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表",
|
||||
)
|
||||
|
||||
# 只更新提供的字段
|
||||
@@ -376,13 +409,18 @@ async def update_expression(
|
||||
update_data["modified_by"] = "user"
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data["last_active_time"] = time.time()
|
||||
update_data["last_active_time"] = datetime.now()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(expression, field, value)
|
||||
|
||||
expression.save()
|
||||
with get_db_session() as session:
|
||||
db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first()
|
||||
if not db_expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
for field, value in update_data.items():
|
||||
if hasattr(db_expression, field):
|
||||
setattr(db_expression, field, value)
|
||||
session.add(db_expression)
|
||||
expression = db_expression
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
@@ -414,7 +452,9 @@ async def delete_expression(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||
expression = session.exec(statement).first()
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
@@ -423,7 +463,8 @@ async def delete_expression(
|
||||
situation = expression.situation
|
||||
|
||||
# 执行删除
|
||||
expression.delete_instance()
|
||||
with get_db_session() as session:
|
||||
session.exec(delete(Expression).where(col(Expression.id) == expression_id))
|
||||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
|
||||
@@ -465,8 +506,9 @@ async def batch_delete_expressions(
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
|
||||
# 查找所有要删除的表达方式
|
||||
expressions = Expression.select().where(Expression.id.in_(request.ids))
|
||||
found_ids = [expr.id for expr in expressions]
|
||||
with get_db_session() as session:
|
||||
statements = select(Expression.id).where(col(Expression.id).in_(request.ids))
|
||||
found_ids = [expr_id for expr_id in session.exec(statements).all()]
|
||||
|
||||
# 检查是否有未找到的ID
|
||||
not_found_ids = set(request.ids) - set(found_ids)
|
||||
@@ -474,7 +516,9 @@ async def batch_delete_expressions(
|
||||
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
||||
|
||||
# 执行批量删除
|
||||
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
|
||||
with get_db_session() as session:
|
||||
result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids)))
|
||||
deleted_count = result.rowcount or 0
|
||||
|
||||
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
||||
|
||||
@@ -503,21 +547,21 @@ async def get_expression_stats(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
|
||||
# 按 chat_id 统计
|
||||
chat_stats = {}
|
||||
for expr in Expression.select(Expression.chat_id):
|
||||
chat_id = expr.chat_id
|
||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||
chat_stats = {}
|
||||
for chat_id in session.exec(select(Expression.session_id)).all():
|
||||
if chat_id:
|
||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||
|
||||
# 获取最近创建的记录数(7天内)
|
||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||
recent = (
|
||||
Expression.select()
|
||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||
.count()
|
||||
)
|
||||
seven_days_ago = datetime.now() - timedelta(days=7)
|
||||
recent_statement = (
|
||||
select(func.count())
|
||||
.select_from(Expression)
|
||||
.where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago)
|
||||
)
|
||||
recent = session.exec(recent_statement).one()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -561,12 +605,13 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
unchecked = Expression.select().where(Expression.checked == False).count()
|
||||
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
|
||||
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
|
||||
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
|
||||
user_checked = Expression.select().where(Expression.modified_by == "user").count()
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
unchecked = 0
|
||||
passed = 0
|
||||
rejected = 0
|
||||
ai_checked = 0
|
||||
user_checked = 0
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
@@ -620,31 +665,44 @@ async def get_review_list(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
query = Expression.select()
|
||||
statement = select(Expression)
|
||||
|
||||
# 根据筛选类型过滤
|
||||
if filter_type == "unchecked":
|
||||
query = query.where(Expression.checked == False)
|
||||
elif filter_type == "passed":
|
||||
query = query.where((Expression.checked == True) & (Expression.rejected == False))
|
||||
elif filter_type == "rejected":
|
||||
query = query.where((Expression.checked == True) & (Expression.rejected == True))
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
statement = statement.where(col(Expression.id) == -1)
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
|
||||
statement = statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
statement = statement.where(col(Expression.session_id) == chat_id)
|
||||
|
||||
# 排序:创建时间倒序
|
||||
query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
|
||||
statement = statement.order_by(
|
||||
case((col(Expression.create_time).is_(None), 1), else_=0),
|
||||
col(Expression.create_time).desc(),
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
statement = statement.offset(offset).limit(page_size)
|
||||
|
||||
with get_db_session() as session:
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
count_statement = select(Expression.id)
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
count_statement = count_statement.where(col(Expression.id) == -1)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
)
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
@@ -720,7 +778,8 @@ async def batch_review_expressions(
|
||||
|
||||
for item in request.items:
|
||||
try:
|
||||
expression = Expression.get_or_none(Expression.id == item.id)
|
||||
with get_db_session() as session:
|
||||
expression = session.exec(select(Expression).where(col(Expression.id) == item.id).limit(1)).first()
|
||||
|
||||
if not expression:
|
||||
results.append(
|
||||
@@ -730,23 +789,28 @@ async def batch_review_expressions(
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked and expression.checked:
|
||||
if item.require_unchecked:
|
||||
results.append(
|
||||
BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
|
||||
)
|
||||
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 更新状态
|
||||
expression.checked = True
|
||||
expression.rejected = item.rejected
|
||||
expression.modified_by = "user"
|
||||
expression.last_active_time = time.time()
|
||||
expression.save()
|
||||
with get_db_session() as session:
|
||||
db_expression = session.exec(
|
||||
select(Expression).where(col(Expression.id) == item.id).limit(1)
|
||||
).first()
|
||||
if not db_expression:
|
||||
results.append(
|
||||
BatchReviewResultItem(
|
||||
id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式"
|
||||
)
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
db_expression.last_active_time = datetime.now()
|
||||
session.add(db_expression)
|
||||
|
||||
results.append(
|
||||
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
|
||||
|
||||
Reference in New Issue
Block a user