feat:webui支持更加优化的模型配置,优化多处UI体验,支持设置视觉和cache价格,修复多重表达不生效的问题,修复表情包路径错误
This commit is contained in:
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
|
||||
from rich.console import Group, RenderableType
|
||||
@@ -103,6 +104,24 @@ class BaseMaisakaReplyGenerator:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦。\n是人类。"
|
||||
|
||||
@staticmethod
|
||||
def _select_reply_style() -> str:
|
||||
"""按配置概率选择本次 replyer 使用的表达风格。"""
|
||||
personality_config = global_config.personality
|
||||
reply_style = personality_config.reply_style
|
||||
candidate_styles = [style.strip() for style in personality_config.multiple_reply_style if style.strip()]
|
||||
|
||||
if not candidate_styles:
|
||||
return reply_style
|
||||
|
||||
probability = personality_config.multiple_probability
|
||||
if probability <= 0:
|
||||
return reply_style
|
||||
if random.random() > probability:
|
||||
return reply_style
|
||||
|
||||
return random.choice(candidate_styles)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
@@ -293,7 +312,7 @@ class BaseMaisakaReplyGenerator:
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
|
||||
replyer_at_block=self._build_replyer_at_block(),
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
reply_style=self._select_reply_style(),
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
@@ -58,7 +58,7 @@ LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0-pre.10"
|
||||
CONFIG_VERSION: str = "8.10.6"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.8"
|
||||
MODEL_CONFIG_VERSION: str = "1.15.3"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
|
||||
@@ -135,7 +135,6 @@ class ConfigBase(BaseModel, AttrDocBase):
|
||||
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
|
||||
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
|
||||
__ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名)
|
||||
__ui_merge_children__: ClassVar[List[str]] = [] # 在 WebUI 中并入当前配置卡片展示的子配置字段名
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
|
||||
|
||||
@@ -402,6 +402,7 @@ class TaskConfig(ConfigBase):
|
||||
"x-widget": "input",
|
||||
"x-icon": "alert-circle",
|
||||
"step": 0.1,
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||||
@@ -420,15 +421,6 @@ class TaskConfig(ConfigBase):
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
utils: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "wrench",
|
||||
},
|
||||
)
|
||||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||||
|
||||
replyer: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
@@ -436,17 +428,7 @@ class ModelTaskConfig(ConfigBase):
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""首要回复模型配置"""
|
||||
|
||||
learner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "graduation-cap",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
|
||||
"""回复模型配置"""
|
||||
|
||||
planner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
@@ -457,6 +439,25 @@ class ModelTaskConfig(ConfigBase):
|
||||
)
|
||||
"""规划模型配置"""
|
||||
|
||||
utils: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "wrench",
|
||||
},
|
||||
)
|
||||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||||
|
||||
learner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "graduation-cap",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
|
||||
|
||||
vlm: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
@@ -471,6 +472,7 @@ class ModelTaskConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "volume-2",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""语音识别模型配置"""
|
||||
|
||||
@@ -84,6 +84,8 @@ class PersonalityConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user-circle",
|
||||
"x-textarea-min-height": 40,
|
||||
"x-textarea-rows": 1,
|
||||
},
|
||||
)
|
||||
"""人格,建议200字以内,描述人格特质和身份特征;可以写完整设定。要求第二人称"""
|
||||
@@ -93,6 +95,8 @@ class PersonalityConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "message-square",
|
||||
"x-textarea-min-height": 40,
|
||||
"x-textarea-rows": 1,
|
||||
},
|
||||
)
|
||||
"""默认表达风格,描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容,建议1-2行"""
|
||||
@@ -321,7 +325,8 @@ class ChatConfig(ConfigBase):
|
||||
class MessageReceiveConfig(ConfigBase):
|
||||
"""消息接收配置类"""
|
||||
|
||||
__ui_parent__ = "response_post_process"
|
||||
__ui_label__ = "消息接收"
|
||||
__ui_icon__ = "message-square-text"
|
||||
|
||||
image_parse_threshold: int = Field(
|
||||
default=5,
|
||||
@@ -394,7 +399,7 @@ class TargetItem(ConfigBase):
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
__ui_parent__ = "emoji"
|
||||
__ui_parent__ = "a_memorix"
|
||||
|
||||
|
||||
global_memory: bool = Field(
|
||||
@@ -1051,7 +1056,7 @@ class LearningItem(ConfigBase):
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""是否启用表达学习"""
|
||||
"""是否使用表达"""
|
||||
|
||||
enable_learning: bool = Field(
|
||||
default=True,
|
||||
@@ -1060,7 +1065,7 @@ class LearningItem(ConfigBase):
|
||||
"x-icon": "graduation-cap",
|
||||
},
|
||||
)
|
||||
"""是否启用表达优化学习"""
|
||||
"""是否学习表达"""
|
||||
|
||||
enable_jargon_learning: bool = Field(
|
||||
default=False,
|
||||
@@ -1069,7 +1074,7 @@ class LearningItem(ConfigBase):
|
||||
"x-icon": "book",
|
||||
},
|
||||
)
|
||||
"""是否启用jargon学习"""
|
||||
"""是否学习黑话"""
|
||||
|
||||
class ExpressionGroup(ConfigBase):
|
||||
"""表达互通组配置类,若列表为空代表全局共享"""
|
||||
@@ -1177,7 +1182,8 @@ class ExpressionConfig(ConfigBase):
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
|
||||
__ui_parent__ = "emoji"
|
||||
__ui_label__ = "语音"
|
||||
__ui_icon__ = "mic"
|
||||
|
||||
enable_asr: bool = Field(
|
||||
default=False,
|
||||
@@ -1192,8 +1198,8 @@ class VoiceConfig(ConfigBase):
|
||||
class EmojiConfig(ConfigBase):
|
||||
"""表情包配置类"""
|
||||
|
||||
__ui_label__ = "功能"
|
||||
__ui_icon__ = "puzzle"
|
||||
__ui_label__ = "表情包"
|
||||
__ui_icon__ = "smile"
|
||||
|
||||
emoji_send_num: int = Field(
|
||||
default=25,
|
||||
@@ -1314,7 +1320,7 @@ class KeywordRuleConfig(ConfigBase):
|
||||
class KeywordReactionConfig(ConfigBase):
|
||||
"""关键词配置类"""
|
||||
|
||||
__ui_parent__ = "response_post_process"
|
||||
__ui_parent__ = "message_receive"
|
||||
|
||||
keyword_rules: list[KeywordRuleConfig] = Field(
|
||||
default_factory=lambda: [],
|
||||
@@ -1345,9 +1351,8 @@ class KeywordReactionConfig(ConfigBase):
|
||||
class ResponsePostProcessConfig(ConfigBase):
|
||||
"""回复后处理配置类"""
|
||||
|
||||
__ui_label__ = "处理"
|
||||
__ui_label__ = "后处理"
|
||||
__ui_icon__ = "settings"
|
||||
__ui_merge_children__ = ["chinese_typo", "response_splitter"]
|
||||
|
||||
enable_response_post_process: bool = Field(
|
||||
default=True,
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = get_logger("emoji")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
|
||||
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
|
||||
@@ -215,6 +215,17 @@ def _is_available_emoji_record(record: Images) -> bool:
|
||||
return record_path.exists() and record_path.is_file()
|
||||
|
||||
|
||||
def _resolve_existing_emoji_path(raw_path: str | Path | None) -> Optional[Path]:
|
||||
"""将表情包路径归一化;路径为空或不存在时返回 ``None``。"""
|
||||
if not raw_path:
|
||||
return None
|
||||
|
||||
record_path = Path(raw_path).absolute().resolve()
|
||||
if not record_path.exists() or not record_path.is_file():
|
||||
return None
|
||||
return record_path
|
||||
|
||||
|
||||
def _is_vlm_task_configured() -> bool:
|
||||
"""判断是否配置了可用于表情包识别和审核的视觉模型任务。"""
|
||||
|
||||
@@ -244,6 +255,7 @@ class EmojiManager:
|
||||
|
||||
self._emoji_num: int = 0
|
||||
self.emojis: list[MaiEmoji] = []
|
||||
self._known_emoji_file_paths: set[Path] = set()
|
||||
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
||||
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._reload_callback_registered: bool = False
|
||||
@@ -254,9 +266,9 @@ class EmojiManager:
|
||||
logger.info("启动表情包管理器")
|
||||
|
||||
def reload_runtime_config(self) -> None:
|
||||
"""响应配置热重载,唤醒维护循环以尽快应用最新配置。"""
|
||||
"""响应配置热重载,重置维护循环等待时间以应用最新配置。"""
|
||||
self._maintenance_wakeup_event.set()
|
||||
logger.info("[配置热重载] Emoji 模块配置已更新,将立即应用到维护循环")
|
||||
logger.info("[配置热重载] Emoji 模块配置已更新,将按新的检查间隔等待后执行维护")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""清理 EmojiManager 生命周期资源。"""
|
||||
@@ -948,33 +960,73 @@ class EmojiManager:
|
||||
|
||||
def check_emoji_file_integrity(self) -> None:
|
||||
"""
|
||||
检查表情包完整性,删除文件缺失的表情包记录
|
||||
检查表情包文件和数据库注册记录的一致性。
|
||||
|
||||
数据库记录存在但文件缺失时删除数据库记录;文件存在但没有数据库记录时删除文件。
|
||||
"""
|
||||
logger.info("[完整性检查] 开始检查表情包文件完整性...")
|
||||
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
|
||||
removal_count = 0
|
||||
for emoji in self.emojis:
|
||||
if not emoji.full_path.exists():
|
||||
logger.warning(f"[完整性检查] 表情包文件缺失,准备修改记录: {emoji.file_name}")
|
||||
to_delete_emojis.append((emoji, False))
|
||||
if not emoji.description:
|
||||
logger.warning(f"[完整性检查] 表情包记录缺失描述,准备删除记录: {emoji.file_name}")
|
||||
to_delete_emojis.append((emoji, True))
|
||||
_ensure_directories()
|
||||
logger.info("[完整性检查] 开始检查表情包文件和注册记录一致性...")
|
||||
tracked_paths: set[Path] = set()
|
||||
record_removal_count = 0
|
||||
file_removal_count = 0
|
||||
available_emojis: list[MaiEmoji] = []
|
||||
|
||||
for emoji, is_description_empty in to_delete_emojis:
|
||||
if self.delete_emoji(emoji, is_description_empty):
|
||||
self.emojis.remove(emoji)
|
||||
self._emoji_num -= 1
|
||||
removal_count += 1
|
||||
logger.info(f"[完整性检查] 成功删除缺失文件的表情包记录: {emoji.file_name}")
|
||||
else:
|
||||
logger.error(f"[完整性检查] 删除缺失文件的表情包记录失败: {emoji.file_name}")
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_type=ImageType.EMOJI)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
record_path = _resolve_existing_emoji_path(record.full_path)
|
||||
if record.no_file_flag or record_path is None:
|
||||
logger.warning(
|
||||
f"[完整性检查] 表情包数据库记录缺少实际文件,删除数据库记录: id={record.id}, path={record.full_path}"
|
||||
)
|
||||
session.delete(record)
|
||||
record_removal_count += 1
|
||||
continue
|
||||
if not record.is_registered or record.is_banned:
|
||||
tracked_paths.add(record_path)
|
||||
continue
|
||||
try:
|
||||
available_emojis.append(MaiEmoji.from_db_instance(record))
|
||||
tracked_paths.add(record_path)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"[完整性检查] 加载表情包记录时出错,将删除异常记录: {exc}\n记录ID: {record.id}, 路径: {record.full_path}"
|
||||
)
|
||||
session.delete(record)
|
||||
record_removal_count += 1
|
||||
|
||||
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
resolved_file = emoji_file.absolute().resolve()
|
||||
if resolved_file in tracked_paths:
|
||||
continue
|
||||
try:
|
||||
emoji_file.unlink()
|
||||
file_removal_count += 1
|
||||
logger.warning(f"[完整性检查] 表情包文件缺少数据库记录,删除文件: {emoji_file}")
|
||||
except Exception as exc:
|
||||
logger.error(f"[完整性检查] 删除无注册记录的表情包文件失败: {emoji_file}, error={exc}")
|
||||
|
||||
self.emojis = available_emojis
|
||||
self._known_emoji_file_paths = tracked_paths
|
||||
self._emoji_num = len(self.emojis)
|
||||
logger.info(
|
||||
f"[完整性检查] 表情包完整性检查完成,删除数据库记录 {record_removal_count} 条,删除图片文件 {file_removal_count} 个"
|
||||
)
|
||||
|
||||
async def periodic_emoji_maintenance(self) -> None:
|
||||
"""Run emoji maintenance tasks periodically."""
|
||||
while True:
|
||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||
try:
|
||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||
self._maintenance_wakeup_event.clear()
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
self._maintenance_wakeup_event.clear()
|
||||
|
||||
_ensure_directories()
|
||||
try:
|
||||
self.check_emoji_file_integrity()
|
||||
@@ -989,24 +1041,21 @@ class EmojiManager:
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
resolved_file = emoji_file.absolute().resolve()
|
||||
if resolved_file in self._known_emoji_file_paths:
|
||||
continue
|
||||
try:
|
||||
register_status = await self.register_emoji_by_filename(emoji_file)
|
||||
except Exception as e:
|
||||
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
|
||||
register_status = "failed"
|
||||
if register_status == "registered":
|
||||
self._known_emoji_file_paths.add(resolved_file)
|
||||
break
|
||||
if register_status == "skipped":
|
||||
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
|
||||
else:
|
||||
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
|
||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||
try:
|
||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self._maintenance_wakeup_event.clear()
|
||||
|
||||
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
|
||||
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
|
||||
|
||||
@@ -39,15 +39,12 @@ class ConfigSchemaGenerator:
|
||||
ui_parent = getattr(config_class, "__ui_parent__", "")
|
||||
ui_label = getattr(config_class, "__ui_label__", "")
|
||||
ui_icon = getattr(config_class, "__ui_icon__", "")
|
||||
ui_merge_children = getattr(config_class, "__ui_merge_children__", [])
|
||||
if ui_parent:
|
||||
schema["uiParent"] = ui_parent
|
||||
if ui_label:
|
||||
schema["uiLabel"] = ui_label
|
||||
if ui_icon:
|
||||
schema["uiIcon"] = ui_icon
|
||||
if ui_merge_children:
|
||||
schema["uiMergeChildren"] = list(ui_merge_children)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlmodel import col, delete, select
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
@@ -28,6 +28,7 @@ class ExpressionResponse(BaseModel):
|
||||
style: str
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
chat_name: Optional[str] = None
|
||||
create_date: Optional[float]
|
||||
checked: bool
|
||||
rejected: bool
|
||||
@@ -90,7 +91,61 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
|
||||
"""从最近消息中解析聊天显示名称。"""
|
||||
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.session_id) == chat_id)
|
||||
.order_by(col(Messages.timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
message = db_session.exec(statement).first()
|
||||
if not message:
|
||||
return None
|
||||
if message.group_id:
|
||||
return message.group_name or f"群聊{message.group_id}"
|
||||
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
|
||||
|
||||
|
||||
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
|
||||
"""从会话记录推断兜底显示名称。"""
|
||||
|
||||
if chat_session.group_id:
|
||||
return f"群聊{chat_session.group_id}"
|
||||
if chat_session.user_id:
|
||||
return f"用户{chat_session.user_id}"
|
||||
return chat_session.session_id
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
|
||||
try:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
return name
|
||||
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
|
||||
return name
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if session:
|
||||
if session.group_id:
|
||||
return f"群聊{session.group_id}"
|
||||
if session.user_id:
|
||||
return f"用户{session.user_id}"
|
||||
return chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
|
||||
"""将表达方式模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
@@ -101,38 +156,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||
chat_id = expression.session_id or ""
|
||||
return ExpressionResponse(
|
||||
id=expression.id if expression.id is not None else 0,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=expression.session_id or "",
|
||||
chat_id=chat_id,
|
||||
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
|
||||
create_date=create_date,
|
||||
checked=False,
|
||||
rejected=False,
|
||||
modified_by=None,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by.value if expression.modified_by else None,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
try:
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
return chat_id
|
||||
name = _chat_manager.get_session_name(chat_id)
|
||||
return name or chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称。
|
||||
|
||||
@@ -145,8 +183,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
for chat_id in chat_ids:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
result[chat_id] = name
|
||||
result[chat_id] = get_chat_name(chat_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
@@ -176,19 +213,43 @@ async def get_chat_list() -> ChatListResponse:
|
||||
ChatListResponse: 可用于下拉选择的聊天列表。
|
||||
"""
|
||||
try:
|
||||
chat_list = []
|
||||
chat_by_id: Dict[str, ChatInfo] = {}
|
||||
for session_id, session in _chat_manager.sessions.items():
|
||||
chat_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
for chat_session in session.exec(select(ChatSession)).all():
|
||||
if chat_session.session_id in chat_by_id:
|
||||
continue
|
||||
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
|
||||
chat_by_id[chat_session.session_id] = ChatInfo(
|
||||
chat_id=chat_session.session_id,
|
||||
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
|
||||
platform=chat_session.platform,
|
||||
is_group=bool(chat_session.group_id),
|
||||
)
|
||||
|
||||
expression_chat_ids = {
|
||||
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
|
||||
}
|
||||
for session_id in expression_chat_ids:
|
||||
if session_id in chat_by_id:
|
||||
continue
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=get_chat_name(session_id, session),
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list = list(chat_by_id.values())
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
@@ -252,7 +313,7 @@ async def get_expression_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
@@ -281,7 +342,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=data)
|
||||
|
||||
@@ -321,7 +382,7 @@ async def create_expression(
|
||||
session.add(expression)
|
||||
session.flush()
|
||||
expression_id = expression.id
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
|
||||
|
||||
@@ -375,7 +436,7 @@ async def update_expression(
|
||||
db_expression.session_id = update_data["session_id"]
|
||||
db_expression.last_active_time = update_data["last_active_time"]
|
||||
session.add(db_expression)
|
||||
data = expression_to_response(db_expression)
|
||||
data = expression_to_response(db_expression, session)
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
@@ -524,6 +585,22 @@ class ReviewStatsResponse(BaseModel):
|
||||
user_checked: int
|
||||
|
||||
|
||||
def apply_review_filter(statement: Any, filter_type: str) -> Any:
|
||||
"""按审核状态过滤表达方式查询。"""
|
||||
if filter_type == "unchecked":
|
||||
return statement.where(col(Expression.checked).is_(False))
|
||||
if filter_type == "passed":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
|
||||
if filter_type == "rejected":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
|
||||
return statement
|
||||
|
||||
|
||||
def count_expressions(session: Any, statement: Any) -> int:
|
||||
"""统计表达方式查询结果数量。"""
|
||||
return len(session.exec(statement).all())
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""获取审核统计数据。
|
||||
@@ -533,12 +610,24 @@ async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
unchecked = 0
|
||||
passed = 0
|
||||
rejected = 0
|
||||
ai_checked = 0
|
||||
user_checked = 0
|
||||
total = count_expressions(session, select(Expression.id))
|
||||
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
|
||||
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
|
||||
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
|
||||
ai_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.AI,
|
||||
),
|
||||
)
|
||||
user_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.USER,
|
||||
),
|
||||
)
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
@@ -587,10 +676,7 @@ async def get_review_list(
|
||||
ReviewListResponse: 审核列表响应。
|
||||
"""
|
||||
try:
|
||||
statement = select(Expression)
|
||||
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
statement = statement.where(col(Expression.id) == -1)
|
||||
statement = apply_review_filter(select(Expression), filter_type)
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
@@ -615,9 +701,7 @@ async def get_review_list(
|
||||
with get_db_session() as session:
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
count_statement = select(Expression.id)
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
count_statement = count_statement.where(col(Expression.id) == -1)
|
||||
count_statement = apply_review_filter(select(Expression.id), filter_type)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
@@ -625,7 +709,7 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
@@ -706,10 +790,10 @@ async def batch_review_expressions(
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked:
|
||||
# 冲突检测:未审核列表发起的操作只允许处理仍处于未审核状态的条目。
|
||||
if item.require_unchecked and expression.checked:
|
||||
results.append(
|
||||
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
|
||||
BatchReviewResultItem(id=item.id, success=False, message="该表达方式已被审核,请刷新列表后重试")
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
@@ -727,6 +811,9 @@ async def batch_review_expressions(
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
db_expression.checked = True
|
||||
db_expression.rejected = item.rejected
|
||||
db_expression.modified_by = ModifiedBy.USER
|
||||
db_expression.last_active_time = datetime.now()
|
||||
session.add(db_expression)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user