merge: 同步上游 dev 最新内容
This commit is contained in:
@@ -112,7 +112,7 @@ class ChatHistoryManager:
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"content": msg.processed_plain_text or "",
|
||||
"timestamp": msg.timestamp.timestamp(),
|
||||
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
@@ -175,11 +175,7 @@ class ChatHistoryManager:
|
||||
|
||||
user_info = target_msg.message_info.user_info
|
||||
if not has_content:
|
||||
content_text = (
|
||||
target_msg.processed_plain_text
|
||||
or target_msg.display_message
|
||||
or ""
|
||||
)
|
||||
content_text = target_msg.processed_plain_text or ""
|
||||
data["target_message_content"] = content_text
|
||||
if not has_sender:
|
||||
data["target_message_sender_id"] = user_info.user_id or ""
|
||||
|
||||
@@ -2,19 +2,21 @@
|
||||
配置管理API路由
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Tuple, Union, get_args, get_origin
|
||||
import copy
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Tuple
|
||||
import types
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import list_prompt_templates
|
||||
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
|
||||
from src.config.config_base import AttributeData
|
||||
from src.config.config_base import AttributeData, ConfigBase
|
||||
from src.config.model_configs import (
|
||||
APIProvider,
|
||||
ModelInfo,
|
||||
@@ -63,6 +65,9 @@ class PromptFileInfo(BaseModel):
|
||||
name: str = Field(..., description="Prompt 文件名")
|
||||
size: int = Field(..., description="文件大小")
|
||||
modified_at: float = Field(..., description="最后修改时间戳")
|
||||
display_name: str = Field(default="", description="Prompt 展示名称")
|
||||
advanced: bool = Field(default=False, description="是否为高级 Prompt")
|
||||
description: str = Field(default="", description="Prompt 描述")
|
||||
|
||||
|
||||
class PromptCatalogResponse(BaseModel):
|
||||
@@ -129,6 +134,71 @@ def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _coerce_numeric_value(value: Any, target_type: Any) -> Any:
|
||||
"""根据配置字段类型,把旧 WebUI 可能写入的数字字符串还原为数字。"""
|
||||
if target_type is str:
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
if target_type is int:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed_value = float(value.strip())
|
||||
except ValueError:
|
||||
return value
|
||||
if parsed_value.is_integer():
|
||||
return int(parsed_value)
|
||||
return value
|
||||
|
||||
if target_type is float:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value.strip())
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_value_by_annotation(value: Any, annotation: Any) -> Any:
|
||||
"""递归按 ConfigBase 字段注解修正数据类型,避免保存时把数字写成字符串。"""
|
||||
value = _coerce_numeric_value(value, annotation)
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin in {Union, types.UnionType}:
|
||||
for candidate_type in args:
|
||||
if candidate_type is type(None):
|
||||
continue
|
||||
coerced_value = _coerce_value_by_annotation(value, candidate_type)
|
||||
if coerced_value != value or type(coerced_value) is not type(value):
|
||||
return coerced_value
|
||||
return value
|
||||
|
||||
if origin in {list, List} and isinstance(value, list) and args:
|
||||
item_type = args[0]
|
||||
return [_coerce_value_by_annotation(item, item_type) for item in value]
|
||||
|
||||
if origin in {dict, Dict} and isinstance(value, dict) and len(args) >= 2:
|
||||
value_type = args[1]
|
||||
return {key: _coerce_value_by_annotation(item, value_type) for key, item in value.items()}
|
||||
|
||||
if isinstance(value, dict) and isinstance(annotation, type) and issubclass(annotation, ConfigBase):
|
||||
return _coerce_config_numeric_values(value, annotation)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_config_numeric_values(data: Dict[str, Any], config_type: type[ConfigBase]) -> Dict[str, Any]:
|
||||
"""按配置类 schema 统一修正所有数字字段类型。"""
|
||||
for field_name, field_info in config_type.model_fields.items():
|
||||
if field_name in data:
|
||||
data[field_name] = _coerce_value_by_annotation(data[field_name], field_info.annotation)
|
||||
return data
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@@ -147,14 +217,20 @@ async def list_prompt_files():
|
||||
continue
|
||||
|
||||
language = language_dir.name
|
||||
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
|
||||
prompt_files: List[PromptFileInfo] = []
|
||||
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
|
||||
stat = prompt_file.stat()
|
||||
template_info = prompt_template_infos.get(prompt_file.stem)
|
||||
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
|
||||
prompt_files.append(
|
||||
PromptFileInfo(
|
||||
name=prompt_file.name,
|
||||
size=stat.st_size,
|
||||
modified_at=stat.st_mtime,
|
||||
display_name=metadata.display_name if metadata else "",
|
||||
advanced=metadata.advanced if metadata else False,
|
||||
description=metadata.description if metadata else "",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -347,6 +423,8 @@ async def get_model_config():
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
config_data = _coerce_config_numeric_values(config_data, Config)
|
||||
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(AttributeData(), copy.deepcopy(config_data))
|
||||
@@ -370,6 +448,8 @@ async def update_bot_config(config_data: ConfigBody):
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
config_data = _coerce_config_numeric_values(config_data, ModelConfig)
|
||||
|
||||
# 验证配置数据
|
||||
try:
|
||||
ModelConfig.from_dict(AttributeData(), copy.deepcopy(config_data))
|
||||
@@ -422,10 +502,13 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
Config.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
|
||||
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), Config)
|
||||
Config.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
config_data = plain_config_data
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
@@ -520,13 +603,14 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
ModelConfig.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
|
||||
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), ModelConfig)
|
||||
ModelConfig.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
models = plain_config_data.get("models", [])
|
||||
orphaned_models: List[str] = [
|
||||
str(model_name)
|
||||
for m in models
|
||||
@@ -539,6 +623,8 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
config_data = plain_config_data
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
|
||||
@@ -10,11 +10,12 @@ from sqlmodel import col, delete, select
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
EXCLUDE_IDS_QUERY = Query(None, description="需要排除的表达方式 ID")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
|
||||
@@ -28,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
||||
style: str
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
chat_name: Optional[str] = None
|
||||
create_date: Optional[float]
|
||||
checked: bool
|
||||
rejected: bool
|
||||
@@ -90,7 +92,61 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
|
||||
"""从最近消息中解析聊天显示名称。"""
|
||||
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.session_id) == chat_id)
|
||||
.order_by(col(Messages.timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
message = db_session.exec(statement).first()
|
||||
if not message:
|
||||
return None
|
||||
if message.group_id:
|
||||
return message.group_name or f"群聊{message.group_id}"
|
||||
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
|
||||
|
||||
|
||||
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
|
||||
"""从会话记录推断兜底显示名称。"""
|
||||
|
||||
if chat_session.group_id:
|
||||
return f"群聊{chat_session.group_id}"
|
||||
if chat_session.user_id:
|
||||
return f"用户{chat_session.user_id}"
|
||||
return chat_session.session_id
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
|
||||
try:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
return name
|
||||
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
|
||||
return name
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if session:
|
||||
if session.group_id:
|
||||
return f"群聊{session.group_id}"
|
||||
if session.user_id:
|
||||
return f"用户{session.user_id}"
|
||||
return chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
|
||||
"""将表达方式模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
@@ -101,38 +157,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||
chat_id = expression.session_id or ""
|
||||
return ExpressionResponse(
|
||||
id=expression.id if expression.id is not None else 0,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=expression.session_id or "",
|
||||
chat_id=chat_id,
|
||||
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
|
||||
create_date=create_date,
|
||||
checked=False,
|
||||
rejected=False,
|
||||
modified_by=None,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by.value if expression.modified_by else None,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
try:
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
return chat_id
|
||||
name = _chat_manager.get_session_name(chat_id)
|
||||
return name or chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称。
|
||||
|
||||
@@ -145,8 +184,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
for chat_id in chat_ids:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
result[chat_id] = name
|
||||
result[chat_id] = get_chat_name(chat_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
@@ -176,19 +214,43 @@ async def get_chat_list() -> ChatListResponse:
|
||||
ChatListResponse: 可用于下拉选择的聊天列表。
|
||||
"""
|
||||
try:
|
||||
chat_list = []
|
||||
chat_by_id: Dict[str, ChatInfo] = {}
|
||||
for session_id, session in _chat_manager.sessions.items():
|
||||
chat_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
for chat_session in session.exec(select(ChatSession)).all():
|
||||
if chat_session.session_id in chat_by_id:
|
||||
continue
|
||||
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
|
||||
chat_by_id[chat_session.session_id] = ChatInfo(
|
||||
chat_id=chat_session.session_id,
|
||||
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
|
||||
platform=chat_session.platform,
|
||||
is_group=bool(chat_session.group_id),
|
||||
)
|
||||
|
||||
expression_chat_ids = {
|
||||
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
|
||||
}
|
||||
for session_id in expression_chat_ids:
|
||||
if session_id in chat_by_id:
|
||||
continue
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=get_chat_name(session_id, session),
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list = list(chat_by_id.values())
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
@@ -252,7 +314,7 @@ async def get_expression_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
@@ -281,7 +343,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=data)
|
||||
|
||||
@@ -321,7 +383,7 @@ async def create_expression(
|
||||
session.add(expression)
|
||||
session.flush()
|
||||
expression_id = expression.id
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
|
||||
|
||||
@@ -375,7 +437,7 @@ async def update_expression(
|
||||
db_expression.session_id = update_data["session_id"]
|
||||
db_expression.last_active_time = update_data["last_active_time"]
|
||||
session.add(db_expression)
|
||||
data = expression_to_response(db_expression)
|
||||
data = expression_to_response(db_expression, session)
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
@@ -524,6 +586,22 @@ class ReviewStatsResponse(BaseModel):
|
||||
user_checked: int
|
||||
|
||||
|
||||
def apply_review_filter(statement: Any, filter_type: str) -> Any:
|
||||
"""按审核状态过滤表达方式查询。"""
|
||||
if filter_type == "unchecked":
|
||||
return statement.where(col(Expression.checked).is_(False))
|
||||
if filter_type == "passed":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
|
||||
if filter_type == "rejected":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
|
||||
return statement
|
||||
|
||||
|
||||
def count_expressions(session: Any, statement: Any) -> int:
|
||||
"""统计表达方式查询结果数量。"""
|
||||
return len(session.exec(statement).all())
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""获取审核统计数据。
|
||||
@@ -533,12 +611,24 @@ async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
unchecked = 0
|
||||
passed = 0
|
||||
rejected = 0
|
||||
ai_checked = 0
|
||||
user_checked = 0
|
||||
total = count_expressions(session, select(Expression.id))
|
||||
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
|
||||
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
|
||||
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
|
||||
ai_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.AI,
|
||||
),
|
||||
)
|
||||
user_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.USER,
|
||||
),
|
||||
)
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
@@ -571,8 +661,10 @@ async def get_review_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
|
||||
order: str = Query("latest", description="排序方式: latest/random"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
exclude_ids: Optional[List[int]] = EXCLUDE_IDS_QUERY,
|
||||
) -> ReviewListResponse:
|
||||
"""获取待审核或已审核的表达方式列表。
|
||||
|
||||
@@ -580,17 +672,16 @@ async def get_review_list(
|
||||
page: 页码。
|
||||
page_size: 每页数量。
|
||||
filter_type: 筛选类型,可选 unchecked、passed、rejected 或 all。
|
||||
order: 排序方式,可选 latest 或 random。
|
||||
search: 搜索关键词。
|
||||
chat_id: 聊天 ID 筛选条件。
|
||||
exclude_ids: 需要排除的表达方式 ID。
|
||||
|
||||
Returns:
|
||||
ReviewListResponse: 审核列表响应。
|
||||
"""
|
||||
try:
|
||||
statement = select(Expression)
|
||||
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
statement = statement.where(col(Expression.id) == -1)
|
||||
statement = apply_review_filter(select(Expression), filter_type)
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
@@ -603,11 +694,17 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
statement = statement.where(col(Expression.session_id) == chat_id)
|
||||
|
||||
# 排序:创建时间倒序
|
||||
statement = statement.order_by(
|
||||
case((col(Expression.create_time).is_(None), 1), else_=0),
|
||||
col(Expression.create_time).desc(),
|
||||
)
|
||||
if exclude_ids:
|
||||
statement = statement.where(~col(Expression.id).in_(exclude_ids))
|
||||
|
||||
if order == "random":
|
||||
statement = statement.order_by(func.random())
|
||||
else:
|
||||
# 排序:创建时间倒序
|
||||
statement = statement.order_by(
|
||||
case((col(Expression.create_time).is_(None), 1), else_=0),
|
||||
col(Expression.create_time).desc(),
|
||||
)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
statement = statement.offset(offset).limit(page_size)
|
||||
@@ -615,9 +712,7 @@ async def get_review_list(
|
||||
with get_db_session() as session:
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
count_statement = select(Expression.id)
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
count_statement = count_statement.where(col(Expression.id) == -1)
|
||||
count_statement = apply_review_filter(select(Expression.id), filter_type)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
@@ -625,7 +720,7 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
@@ -647,7 +742,7 @@ class BatchReviewItem(BaseModel):
|
||||
|
||||
id: int
|
||||
rejected: bool
|
||||
require_unchecked: bool = True # 默认要求未检查状态
|
||||
require_unchecked: bool = True # 前端保留的来源标记,人工审核提交时不再阻断覆盖
|
||||
|
||||
|
||||
class BatchReviewRequest(BaseModel):
|
||||
@@ -706,14 +801,6 @@ async def batch_review_expressions(
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked:
|
||||
results.append(
|
||||
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 更新状态
|
||||
with get_db_session() as session:
|
||||
db_expression = session.exec(
|
||||
@@ -727,6 +814,9 @@ async def batch_review_expressions(
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
db_expression.checked = True
|
||||
db_expression.rejected = item.rejected
|
||||
db_expression.modified_by = ModifiedBy.USER
|
||||
db_expression.last_active_time = datetime.now()
|
||||
session.add(db_expression)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user