feat:webui支持更加优化的模型配置,优化多处UI体验,支持设置视觉和cache价格,修复多重表达不生效的问题,修复表情包路径错误

This commit is contained in:
SengokuCola
2026-05-04 22:52:41 +08:00
parent 14b7bc78a2
commit eea95c1961
38 changed files with 1188 additions and 454 deletions

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
import json
import random
import time
from rich.console import Group, RenderableType
@@ -103,6 +104,24 @@ class BaseMaisakaReplyGenerator:
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
return "你的名字是麦麦。\n是人类。"
@staticmethod
def _select_reply_style() -> str:
"""按配置概率选择本次 replyer 使用的表达风格。"""
personality_config = global_config.personality
reply_style = personality_config.reply_style
candidate_styles = [style.strip() for style in personality_config.multiple_reply_style if style.strip()]
if not candidate_styles:
return reply_style
probability = personality_config.multiple_probability
if probability <= 0:
return reply_style
if random.random() > probability:
return reply_style
return random.choice(candidate_styles)
@staticmethod
def _normalize_content(content: str, limit: int = 500) -> str:
normalized = " ".join((content or "").split())
@@ -293,7 +312,7 @@ class BaseMaisakaReplyGenerator:
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
replyer_at_block=self._build_replyer_at_block(),
identity=self._personality_prompt,
reply_style=global_config.personality.reply_style,
reply_style=self._select_reply_style(),
)
except Exception:
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"

View File

@@ -58,7 +58,7 @@ LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0-pre.10"
CONFIG_VERSION: str = "8.10.6"
MODEL_CONFIG_VERSION: str = "1.14.8"
MODEL_CONFIG_VERSION: str = "1.15.3"
logger = get_logger("config")

View File

@@ -135,7 +135,6 @@ class ConfigBase(BaseModel, AttrDocBase):
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
__ui_icon__: ClassVar[str] = "" # Tab 图标名称Lucide 图标名)
__ui_merge_children__: ClassVar[List[str]] = [] # 在 WebUI 中并入当前配置卡片展示的子配置字段名
@classmethod
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):

View File

@@ -402,6 +402,7 @@ class TaskConfig(ConfigBase):
"x-widget": "input",
"x-icon": "alert-circle",
"step": 0.1,
"advanced": True,
},
)
"""慢请求阈值(秒),超过此值会输出警告日志"""
@@ -420,15 +421,6 @@ class TaskConfig(ConfigBase):
class ModelTaskConfig(ConfigBase):
"""模型配置类"""
utils: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "wrench",
},
)
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
replyer: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
@@ -436,17 +428,7 @@ class ModelTaskConfig(ConfigBase):
"x-icon": "message-square",
},
)
"""首要回复模型配置"""
learner: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "graduation-cap",
"advanced": True,
},
)
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
"""回复模型配置"""
planner: TaskConfig = Field(
default_factory=TaskConfig,
@@ -457,6 +439,25 @@ class ModelTaskConfig(ConfigBase):
)
"""规划模型配置"""
utils: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "wrench",
},
)
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
learner: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "graduation-cap",
"advanced": True,
},
)
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
vlm: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
@@ -471,6 +472,7 @@ class ModelTaskConfig(ConfigBase):
json_schema_extra={
"x-widget": "custom",
"x-icon": "volume-2",
"advanced": True,
},
)
"""语音识别模型配置"""

View File

@@ -84,6 +84,8 @@ class PersonalityConfig(ConfigBase):
json_schema_extra={
"x-widget": "textarea",
"x-icon": "user-circle",
"x-textarea-min-height": 40,
"x-textarea-rows": 1,
},
)
"""人格建议200字以内描述人格特质和身份特征可以写完整设定。要求第二人称"""
@@ -93,6 +95,8 @@ class PersonalityConfig(ConfigBase):
json_schema_extra={
"x-widget": "textarea",
"x-icon": "message-square",
"x-textarea-min-height": 40,
"x-textarea-rows": 1,
},
)
"""默认表达风格描述麦麦说话的表达风格表达习惯如要修改可以酌情新增内容建议1-2行"""
@@ -321,7 +325,8 @@ class ChatConfig(ConfigBase):
class MessageReceiveConfig(ConfigBase):
"""消息接收配置类"""
__ui_parent__ = "response_post_process"
__ui_label__ = "消息接收"
__ui_icon__ = "message-square-text"
image_parse_threshold: int = Field(
default=5,
@@ -394,7 +399,7 @@ class TargetItem(ConfigBase):
class MemoryConfig(ConfigBase):
"""记忆配置类"""
__ui_parent__ = "emoji"
__ui_parent__ = "a_memorix"
global_memory: bool = Field(
@@ -1051,7 +1056,7 @@ class LearningItem(ConfigBase):
"x-icon": "message-square",
},
)
"""是否用表达学习"""
"""是否使用表达"""
enable_learning: bool = Field(
default=True,
@@ -1060,7 +1065,7 @@ class LearningItem(ConfigBase):
"x-icon": "graduation-cap",
},
)
"""是否启用表达优化学习"""
"""是否学习表达"""
enable_jargon_learning: bool = Field(
default=False,
@@ -1069,7 +1074,7 @@ class LearningItem(ConfigBase):
"x-icon": "book",
},
)
"""是否启用jargon学习"""
"""是否学习黑话"""
class ExpressionGroup(ConfigBase):
"""表达互通组配置类,若列表为空代表全局共享"""
@@ -1177,7 +1182,8 @@ class ExpressionConfig(ConfigBase):
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
__ui_parent__ = "emoji"
__ui_label__ = "语音"
__ui_icon__ = "mic"
enable_asr: bool = Field(
default=False,
@@ -1192,8 +1198,8 @@ class VoiceConfig(ConfigBase):
class EmojiConfig(ConfigBase):
"""表情包配置类"""
__ui_label__ = "功能"
__ui_icon__ = "puzzle"
__ui_label__ = "表情包"
__ui_icon__ = "smile"
emoji_send_num: int = Field(
default=25,
@@ -1314,7 +1320,7 @@ class KeywordRuleConfig(ConfigBase):
class KeywordReactionConfig(ConfigBase):
"""关键词配置类"""
__ui_parent__ = "response_post_process"
__ui_parent__ = "message_receive"
keyword_rules: list[KeywordRuleConfig] = Field(
default_factory=lambda: [],
@@ -1345,9 +1351,8 @@ class KeywordReactionConfig(ConfigBase):
class ResponsePostProcessConfig(ConfigBase):
"""回复后处理配置类"""
__ui_label__ = "处理"
__ui_label__ = "处理"
__ui_icon__ = "settings"
__ui_merge_children__ = ["chinese_typo", "response_splitter"]
enable_response_post_process: bool = Field(
default=True,

View File

@@ -29,7 +29,7 @@ logger = get_logger("emoji")
install(extra_lines=3)
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DATA_DIR = PROJECT_ROOT / "data"
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
@@ -215,6 +215,17 @@ def _is_available_emoji_record(record: Images) -> bool:
return record_path.exists() and record_path.is_file()
def _resolve_existing_emoji_path(raw_path: str | Path | None) -> Optional[Path]:
"""将表情包路径归一化;路径为空或不存在时返回 ``None``。"""
if not raw_path:
return None
record_path = Path(raw_path).absolute().resolve()
if not record_path.exists() or not record_path.is_file():
return None
return record_path
def _is_vlm_task_configured() -> bool:
"""判断是否配置了可用于表情包识别和审核的视觉模型任务。"""
@@ -244,6 +255,7 @@ class EmojiManager:
self._emoji_num: int = 0
self.emojis: list[MaiEmoji] = []
self._known_emoji_file_paths: set[Path] = set()
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
self._reload_callback_registered: bool = False
@@ -254,9 +266,9 @@ class EmojiManager:
logger.info("启动表情包管理器")
def reload_runtime_config(self) -> None:
"""响应配置热重载,唤醒维护循环以尽快应用最新配置。"""
"""响应配置热重载,重置维护循环等待时间以应用最新配置。"""
self._maintenance_wakeup_event.set()
logger.info("[配置热重载] Emoji 模块配置已更新,将立即应用到维护循环")
logger.info("[配置热重载] Emoji 模块配置已更新,将按新的检查间隔等待后执行维护")
def shutdown(self) -> None:
"""清理 EmojiManager 生命周期资源。"""
@@ -948,33 +960,73 @@ class EmojiManager:
def check_emoji_file_integrity(self) -> None:
"""
检查表情包完整性,删除文件缺失的表情包记录
检查表情包文件和数据库注册记录的一致性。
数据库记录存在但文件缺失时删除数据库记录;文件存在但没有数据库记录时删除文件。
"""
logger.info("[完整性检查] 开始检查表情包文件完整性...")
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
removal_count = 0
for emoji in self.emojis:
if not emoji.full_path.exists():
logger.warning(f"[完整性检查] 表情包文件缺失,准备修改记录: {emoji.file_name}")
to_delete_emojis.append((emoji, False))
if not emoji.description:
logger.warning(f"[完整性检查] 表情包记录缺失描述,准备删除记录: {emoji.file_name}")
to_delete_emojis.append((emoji, True))
_ensure_directories()
logger.info("[完整性检查] 开始检查表情包文件和注册记录一致性...")
tracked_paths: set[Path] = set()
record_removal_count = 0
file_removal_count = 0
available_emojis: list[MaiEmoji] = []
for emoji, is_description_empty in to_delete_emojis:
if self.delete_emoji(emoji, is_description_empty):
self.emojis.remove(emoji)
self._emoji_num -= 1
removal_count += 1
logger.info(f"[完整性检查] 成功删除缺失文件的表情包记录: {emoji.file_name}")
else:
logger.error(f"[完整性检查] 删除缺失文件的表情包记录失败: {emoji.file_name}")
with get_db_session() as session:
statement = select(Images).filter_by(image_type=ImageType.EMOJI)
records = session.exec(statement).all()
for record in records:
record_path = _resolve_existing_emoji_path(record.full_path)
if record.no_file_flag or record_path is None:
logger.warning(
f"[完整性检查] 表情包数据库记录缺少实际文件,删除数据库记录: id={record.id}, path={record.full_path}"
)
session.delete(record)
record_removal_count += 1
continue
if not record.is_registered or record.is_banned:
tracked_paths.add(record_path)
continue
try:
available_emojis.append(MaiEmoji.from_db_instance(record))
tracked_paths.add(record_path)
except Exception as exc:
logger.error(
f"[完整性检查] 加载表情包记录时出错,将删除异常记录: {exc}\n记录ID: {record.id}, 路径: {record.full_path}"
)
session.delete(record)
record_removal_count += 1
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
for emoji_file in EMOJI_DIR.iterdir():
if not emoji_file.is_file():
continue
resolved_file = emoji_file.absolute().resolve()
if resolved_file in tracked_paths:
continue
try:
emoji_file.unlink()
file_removal_count += 1
logger.warning(f"[完整性检查] 表情包文件缺少数据库记录,删除文件: {emoji_file}")
except Exception as exc:
logger.error(f"[完整性检查] 删除无注册记录的表情包文件失败: {emoji_file}, error={exc}")
self.emojis = available_emojis
self._known_emoji_file_paths = tracked_paths
self._emoji_num = len(self.emojis)
logger.info(
f"[完整性检查] 表情包完整性检查完成,删除数据库记录 {record_removal_count} 条,删除图片文件 {file_removal_count}"
)
async def periodic_emoji_maintenance(self) -> None:
"""Run emoji maintenance tasks periodically."""
while True:
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
try:
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
self._maintenance_wakeup_event.clear()
continue
except asyncio.TimeoutError:
self._maintenance_wakeup_event.clear()
_ensure_directories()
try:
self.check_emoji_file_integrity()
@@ -989,24 +1041,21 @@ class EmojiManager:
for emoji_file in EMOJI_DIR.iterdir():
if not emoji_file.is_file():
continue
resolved_file = emoji_file.absolute().resolve()
if resolved_file in self._known_emoji_file_paths:
continue
try:
register_status = await self.register_emoji_by_filename(emoji_file)
except Exception as e:
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
register_status = "failed"
if register_status == "registered":
self._known_emoji_file_paths.add(resolved_file)
break
if register_status == "skipped":
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
else:
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
try:
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
except asyncio.TimeoutError:
pass
finally:
self._maintenance_wakeup_event.clear()
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""

View File

@@ -39,15 +39,12 @@ class ConfigSchemaGenerator:
ui_parent = getattr(config_class, "__ui_parent__", "")
ui_label = getattr(config_class, "__ui_label__", "")
ui_icon = getattr(config_class, "__ui_icon__", "")
ui_merge_children = getattr(config_class, "__ui_merge_children__", [])
if ui_parent:
schema["uiParent"] = ui_parent
if ui_label:
schema["uiLabel"] = ui_label
if ui_icon:
schema["uiIcon"] = ui_icon
if ui_merge_children:
schema["uiMergeChildren"] = list(ui_merge_children)
return schema

View File

@@ -10,7 +10,7 @@ from sqlmodel import col, delete, select
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
@@ -28,6 +28,7 @@ class ExpressionResponse(BaseModel):
style: str
last_active_time: float
chat_id: str
chat_name: Optional[str] = None
create_date: Optional[float]
checked: bool
rejected: bool
@@ -90,7 +91,61 @@ class ExpressionCreateResponse(BaseModel):
data: ExpressionResponse
def expression_to_response(expression: Expression) -> ExpressionResponse:
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
"""从最近消息中解析聊天显示名称。"""
statement = (
select(Messages)
.where(col(Messages.session_id) == chat_id)
.order_by(col(Messages.timestamp).desc())
.limit(1)
)
message = db_session.exec(statement).first()
if not message:
return None
if message.group_id:
return message.group_name or f"群聊{message.group_id}"
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
"""从会话记录推断兜底显示名称。"""
if chat_session.group_id:
return f"群聊{chat_session.group_id}"
if chat_session.user_id:
return f"用户{chat_session.user_id}"
return chat_session.session_id
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
if name := _chat_manager.get_session_name(chat_id):
return name
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
return name
session = _chat_manager.get_session_by_session_id(chat_id)
if session:
if session.group_id:
return f"群聊{session.group_id}"
if session.user_id:
return f"用户{session.user_id}"
return chat_id
except Exception:
return chat_id
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
"""将表达方式模型转换为响应对象。
Args:
@@ -101,38 +156,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
"""
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
create_date = expression.create_time.timestamp() if expression.create_time else None
chat_id = expression.session_id or ""
return ExpressionResponse(
id=expression.id if expression.id is not None else 0,
situation=expression.situation,
style=expression.style,
last_active_time=last_active_time,
chat_id=expression.session_id or "",
chat_id=chat_id,
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
create_date=create_date,
checked=False,
rejected=False,
modified_by=None,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by.value if expression.modified_by else None,
)
def get_chat_name(chat_id: str) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
return chat_id
name = _chat_manager.get_session_name(chat_id)
return name or chat_id
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称。
@@ -145,8 +183,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
for chat_id in chat_ids:
if name := _chat_manager.get_session_name(chat_id):
result[chat_id] = name
result[chat_id] = get_chat_name(chat_id)
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
@@ -176,19 +213,43 @@ async def get_chat_list() -> ChatListResponse:
ChatListResponse: 可用于下拉选择的聊天列表。
"""
try:
chat_list = []
chat_by_id: Dict[str, ChatInfo] = {}
for session_id, session in _chat_manager.sessions.items():
chat_name = _chat_manager.get_session_name(session_id) or session_id
chat_list.append(
ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
with get_db_session() as session:
for chat_session in session.exec(select(ChatSession)).all():
if chat_session.session_id in chat_by_id:
continue
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
chat_by_id[chat_session.session_id] = ChatInfo(
chat_id=chat_session.session_id,
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
expression_chat_ids = {
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
}
for session_id in expression_chat_ids:
if session_id in chat_by_id:
continue
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=get_chat_name(session_id, session),
platform=None,
is_group=False,
)
# 按名称排序
chat_list = list(chat_by_id.values())
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
@@ -252,7 +313,7 @@ async def get_expression_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@@ -281,7 +342,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
data = expression_to_response(expression)
data = expression_to_response(expression, session)
return ExpressionDetailResponse(success=True, data=data)
@@ -321,7 +382,7 @@ async def create_expression(
session.add(expression)
session.flush()
expression_id = expression.id
data = expression_to_response(expression)
data = expression_to_response(expression, session)
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
@@ -375,7 +436,7 @@ async def update_expression(
db_expression.session_id = update_data["session_id"]
db_expression.last_active_time = update_data["last_active_time"]
session.add(db_expression)
data = expression_to_response(db_expression)
data = expression_to_response(db_expression, session)
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
@@ -524,6 +585,22 @@ class ReviewStatsResponse(BaseModel):
user_checked: int
def apply_review_filter(statement: Any, filter_type: str) -> Any:
"""按审核状态过滤表达方式查询。"""
if filter_type == "unchecked":
return statement.where(col(Expression.checked).is_(False))
if filter_type == "passed":
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
if filter_type == "rejected":
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
return statement
def count_expressions(session: Any, statement: Any) -> int:
"""统计表达方式查询结果数量。"""
return len(session.exec(statement).all())
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats() -> ReviewStatsResponse:
"""获取审核统计数据。
@@ -533,12 +610,24 @@ async def get_review_stats() -> ReviewStatsResponse:
"""
try:
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
passed = 0
rejected = 0
ai_checked = 0
user_checked = 0
total = count_expressions(session, select(Expression.id))
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
ai_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.AI,
),
)
user_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.USER,
),
)
return ReviewStatsResponse(
total=total,
@@ -587,10 +676,7 @@ async def get_review_list(
ReviewListResponse: 审核列表响应。
"""
try:
statement = select(Expression)
if filter_type in {"unchecked", "passed", "rejected"}:
statement = statement.where(col(Expression.id) == -1)
statement = apply_review_filter(select(Expression), filter_type)
# all 不需要额外过滤
# 搜索过滤
@@ -615,9 +701,7 @@ async def get_review_list(
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if filter_type in {"unchecked", "passed", "rejected"}:
count_statement = count_statement.where(col(Expression.id) == -1)
count_statement = apply_review_filter(select(Expression.id), filter_type)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
@@ -625,7 +709,7 @@ async def get_review_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ReviewListResponse(
success=True,
@@ -706,10 +790,10 @@ async def batch_review_expressions(
failed += 1
continue
# 冲突检测
if item.require_unchecked:
# 冲突检测:未审核列表发起的操作只允许处理仍处于未审核状态的条目。
if item.require_unchecked and expression.checked:
results.append(
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
BatchReviewResultItem(id=item.id, success=False, message="该表达方式已被审核,请刷新列表后重试")
)
failed += 1
continue
@@ -727,6 +811,9 @@ async def batch_review_expressions(
)
failed += 1
continue
db_expression.checked = True
db_expression.rejected = item.rejected
db_expression.modified_by = ModifiedBy.USER
db_expression.last_active_time = datetime.now()
session.add(db_expression)