diff --git a/src/learners/expression_auto_check_task.py b/src/learners/expression_auto_check_task.py index a860f2a3..d20d2c81 100644 --- a/src/learners/expression_auto_check_task.py +++ b/src/learners/expression_auto_check_task.py @@ -17,9 +17,9 @@ from sqlmodel import select from src.common.database.database import get_db_session from src.common.database.database_model import Expression +from src.common.database.database_model import ModifiedBy from src.common.logger import get_logger from src.config.config import global_config -from src.learners.expression_review_store import get_review_state, set_review_state from src.learners.expression_utils import check_expression_suitability from src.manager.async_task_manager import AsyncTask @@ -53,7 +53,7 @@ class ExpressionAutoCheckTask(AsyncTask): statement = select(Expression) all_expressions = session.exec(statement).all() - unevaluated_expressions = [expr for expr in all_expressions if not get_review_state(expr.id)["checked"]] + unevaluated_expressions = [expr for expr in all_expressions if not expr.checked] if not unevaluated_expressions: logger.info("没有未检查的表达方式") @@ -62,9 +62,7 @@ class ExpressionAutoCheckTask(AsyncTask): selected_count = min(count, len(unevaluated_expressions)) selected = random.sample(unevaluated_expressions, selected_count) - logger.info( - f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条" - ) + logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条") return selected except Exception as e: @@ -86,8 +84,18 @@ class ExpressionAutoCheckTask(AsyncTask): expression.style, ) + if error: + logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}") + return False + try: - set_review_state(expression.id, True, not suitable, "ai") + with get_db_session() as session: + expr = session.exec(select(Expression).where(Expression.id == expression.id)).first() + if expr: + expr.checked = True + expr.rejected = not suitable + expr.modified_by = ModifiedBy.AI + session.add(expr) status = "通过" if suitable else "不通过" # 保留这段注释,方便后续需要时恢复更详细的审核日志。 @@ -98,9 +106,6 @@ class ExpressionAutoCheckTask(AsyncTask): # f"Reason: {reason[:50]}..." # ) - if error: - logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}") - logger.debug(f"表达方式 [ID: {expression.id}] 评估完成: {status}, reason={reason}") return suitable diff --git a/src/webui/routers/expression.py b/src/webui/routers/expression.py index 7283c350..0b131b4d 100644 --- a/src/webui/routers/expression.py +++ b/src/webui/routers/expression.py @@ -96,10 +96,7 @@ def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional """从最近消息中解析聊天显示名称。""" statement = ( - select(Messages) - .where(col(Messages.session_id) == chat_id) - .order_by(col(Messages.timestamp).desc()) - .limit(1) + select(Messages).where(col(Messages.session_id) == chat_id).order_by(col(Messages.timestamp).desc()).limit(1) ) message = db_session.exec(statement).first() if not message: @@ -236,9 +233,7 @@ async def get_chat_list() -> ChatListResponse: is_group=bool(chat_session.group_id), ) - expression_chat_ids = { - chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id - } + expression_chat_ids = {chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id} for session_id in expression_chat_ids: if session_id in chat_by_id: continue