merge: 同步上游 dev 最新内容
This commit is contained in:
@@ -78,7 +78,7 @@ class ChatManager:
|
||||
"""初始化聊天管理器"""
|
||||
try:
|
||||
await self.load_all_sessions_from_db()
|
||||
logger.info(f"已加载 {len(self.sessions)} 个会话记录到内存中")
|
||||
logger.debug(f"已加载 {len(self.sessions)} 个会话记录到内存中")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化聊天管理器出现错误: {e}")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
|
||||
from rich.console import Group, RenderableType
|
||||
@@ -103,6 +104,24 @@ class BaseMaisakaReplyGenerator:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦。\n是人类。"
|
||||
|
||||
@staticmethod
|
||||
def _select_reply_style() -> str:
|
||||
"""按配置概率选择本次 replyer 使用的表达风格。"""
|
||||
personality_config = global_config.personality
|
||||
reply_style = personality_config.reply_style
|
||||
candidate_styles = [style.strip() for style in personality_config.multiple_reply_style if style.strip()]
|
||||
|
||||
if not candidate_styles:
|
||||
return reply_style
|
||||
|
||||
probability = personality_config.multiple_probability
|
||||
if probability <= 0:
|
||||
return reply_style
|
||||
if random.random() > probability:
|
||||
return reply_style
|
||||
|
||||
return random.choice(candidate_styles)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
@@ -293,7 +312,7 @@ class BaseMaisakaReplyGenerator:
|
||||
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
|
||||
replyer_at_block=self._build_replyer_at_block(),
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
reply_style=self._select_reply_style(),
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
@@ -8,8 +8,6 @@ import random
|
||||
import re
|
||||
import time
|
||||
|
||||
import jieba
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
@@ -912,110 +910,3 @@ def parse_keywords_string(keywords_input) -> list[str]:
|
||||
return [keywords_str] if keywords_str else []
|
||||
|
||||
|
||||
def cut_key_words(concept_name: str) -> list[str]:
|
||||
"""对概念名称进行jieba分词,并过滤掉关键词列表中的关键词"""
|
||||
concept_name_tokens = list(jieba.cut(concept_name))
|
||||
|
||||
# 定义常见连词、停用词与标点
|
||||
conjunctions = {"和", "与", "及", "跟", "以及", "并且", "而且", "或", "或者", "并"}
|
||||
stop_words = {
|
||||
"的",
|
||||
"了",
|
||||
"呢",
|
||||
"吗",
|
||||
"吧",
|
||||
"啊",
|
||||
"哦",
|
||||
"恩",
|
||||
"嗯",
|
||||
"呀",
|
||||
"嘛",
|
||||
"哇",
|
||||
"在",
|
||||
"是",
|
||||
"很",
|
||||
"也",
|
||||
"又",
|
||||
"就",
|
||||
"都",
|
||||
"还",
|
||||
"更",
|
||||
"最",
|
||||
"被",
|
||||
"把",
|
||||
"给",
|
||||
"对",
|
||||
"和",
|
||||
"与",
|
||||
"及",
|
||||
"跟",
|
||||
"并",
|
||||
"而且",
|
||||
"或者",
|
||||
"或",
|
||||
"以及",
|
||||
}
|
||||
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
|
||||
|
||||
# 清理空白并初步过滤纯标点
|
||||
cleaned_tokens = []
|
||||
for tok in concept_name_tokens:
|
||||
t = tok.strip()
|
||||
if not t:
|
||||
continue
|
||||
# 去除纯标点
|
||||
if all(ch in chinese_punctuations for ch in t):
|
||||
continue
|
||||
cleaned_tokens.append(t)
|
||||
|
||||
# 合并连词两侧的词(仅当两侧都存在且不是标点/停用词时)
|
||||
merged_tokens = []
|
||||
i = 0
|
||||
n = len(cleaned_tokens)
|
||||
while i < n:
|
||||
tok = cleaned_tokens[i]
|
||||
if tok in conjunctions and merged_tokens and i + 1 < n:
|
||||
left = merged_tokens[-1]
|
||||
right = cleaned_tokens[i + 1]
|
||||
# 左右都需要是有效词
|
||||
if (
|
||||
left
|
||||
and right
|
||||
and left not in conjunctions
|
||||
and right not in conjunctions
|
||||
and left not in stop_words
|
||||
and right not in stop_words
|
||||
and not all(ch in chinese_punctuations for ch in left)
|
||||
and not all(ch in chinese_punctuations for ch in right)
|
||||
):
|
||||
# 合并为一个新词,并替换掉左侧与跳过右侧
|
||||
combined = f"{left}{tok}{right}"
|
||||
merged_tokens[-1] = combined
|
||||
i += 2
|
||||
continue
|
||||
# 常规推进
|
||||
merged_tokens.append(tok)
|
||||
i += 1
|
||||
|
||||
# 二次过滤:去除停用词、单字符纯标点与无意义项
|
||||
result_tokens = []
|
||||
seen = set()
|
||||
# ban_words = set(getattr(global_config.memory, "memory_ban_words", []) or [])
|
||||
for tok in merged_tokens:
|
||||
if tok in conjunctions:
|
||||
# 独立连词丢弃
|
||||
continue
|
||||
if tok in stop_words:
|
||||
continue
|
||||
# if tok in ban_words:
|
||||
# continue
|
||||
if all(ch in chinese_punctuations for ch in tok):
|
||||
continue
|
||||
if tok.strip() == "":
|
||||
continue
|
||||
if tok not in seen:
|
||||
seen.add(tok)
|
||||
result_tokens.append(tok)
|
||||
|
||||
filtered_concept_name_tokens = result_tokens
|
||||
return filtered_concept_name_tokens
|
||||
|
||||
@@ -79,7 +79,6 @@ class BufferCLI:
|
||||
)
|
||||
message.raw_message = MessageSequence([TextComponent(text=user_text)])
|
||||
message.processed_plain_text = user_text
|
||||
message.display_message = user_text
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
@@ -59,7 +59,6 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
self.reply_to: Optional[str] = None
|
||||
|
||||
self.processed_plain_text: Optional[str] = None
|
||||
self.display_message: Optional[str] = None
|
||||
self.raw_message: MessageSequence
|
||||
|
||||
@classmethod
|
||||
@@ -86,7 +85,6 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
obj.reply_to = db_record.reply_to
|
||||
obj.session_id = db_record.session_id
|
||||
obj.processed_plain_text = db_record.processed_plain_text
|
||||
obj.display_message = db_record.display_message
|
||||
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(db_record.raw_content)
|
||||
return obj
|
||||
|
||||
@@ -113,7 +111,6 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||
is_notify=self.is_notify,
|
||||
raw_content=MessageUtils.from_MaiSeq_to_db_record_msg(self.raw_message),
|
||||
processed_plain_text=self.processed_plain_text,
|
||||
display_message=self.display_message,
|
||||
additional_config=additional_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -51,7 +51,6 @@ class Messages(SQLModel, table=True):
|
||||
# 消息内容
|
||||
raw_content: bytes = Field(sa_column=Column(LargeBinary)) # msgpack后的原始消息内容
|
||||
processed_plain_text: Optional[str] = Field(default=None) # 平面化处理后的纯文本消息
|
||||
display_message: Optional[str] = Field(default=None) # 显示的消息内容(被放入Prompt)
|
||||
|
||||
# 其他配置
|
||||
additional_config: Optional[str] = Field(default=None) # 额外配置,JSON格式存储
|
||||
|
||||
@@ -8,12 +8,14 @@ from .registry import MigrationRegistry
|
||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||
from .schema import SQLiteSchemaInspector
|
||||
from .v2_to_v3 import migrate_v2_to_v3
|
||||
from .v3_to_v4 import migrate_v3_to_v4
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
EMPTY_SCHEMA_VERSION = 0
|
||||
LEGACY_V1_SCHEMA_VERSION = 1
|
||||
V2_SCHEMA_VERSION = 2
|
||||
LATEST_SCHEMA_VERSION = 3
|
||||
V3_SCHEMA_VERSION = 3
|
||||
LATEST_SCHEMA_VERSION = 4
|
||||
|
||||
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||
"chat_streams",
|
||||
@@ -78,9 +80,46 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
return None
|
||||
if not snapshot.has_column("person_info", "user_nickname"):
|
||||
return None
|
||||
if snapshot.has_column("mai_messages", "display_message"):
|
||||
return None
|
||||
return LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
class V3SchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
"""v3 schema 结构探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "v3_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""检测数据库是否为 v3 结构。"""
|
||||
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return None
|
||||
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
|
||||
return None
|
||||
if snapshot.has_table("action_records"):
|
||||
return None
|
||||
if snapshot.has_table("thinking_questions"):
|
||||
return None
|
||||
if snapshot.has_column("images", "emotion"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_hash"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "full_path"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_type"):
|
||||
return None
|
||||
if not snapshot.has_column("chat_history", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("person_info", "user_nickname"):
|
||||
return None
|
||||
if not snapshot.has_column("mai_messages", "display_message"):
|
||||
return None
|
||||
return V3_SCHEMA_VERSION
|
||||
|
||||
|
||||
class V2SchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
"""v2 schema 结构探测器。"""
|
||||
|
||||
@@ -174,6 +213,7 @@ def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
|
||||
|
||||
return [
|
||||
LatestSchemaVersionDetector(),
|
||||
V3SchemaVersionDetector(),
|
||||
V2SchemaVersionDetector(),
|
||||
LegacyV1SchemaDetector(),
|
||||
]
|
||||
@@ -211,10 +251,17 @@ def build_default_migration_registry() -> MigrationRegistry:
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=V2_SCHEMA_VERSION,
|
||||
version_to=LATEST_SCHEMA_VERSION,
|
||||
version_to=V3_SCHEMA_VERSION,
|
||||
name="v2_to_v3",
|
||||
description="移除废弃表,并将 emoji 标签统一收敛到 description 字段。",
|
||||
handler=migrate_v2_to_v3,
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=V3_SCHEMA_VERSION,
|
||||
version_to=LATEST_SCHEMA_VERSION,
|
||||
name="v3_to_v4",
|
||||
description="移除 mai_messages.display_message 弃用列。",
|
||||
handler=migrate_v3_to_v4,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -489,9 +489,6 @@ def _build_legacy_message_additional_config(row: Mapping[str, Any]) -> Optional[
|
||||
|
||||
legacy_fields = {
|
||||
"intercept_message_level": row.get("intercept_message_level"),
|
||||
"interest_value": row.get("interest_value"),
|
||||
"key_words": row.get("key_words"),
|
||||
"key_words_lite": row.get("key_words_lite"),
|
||||
"priority_info": row.get("priority_info"),
|
||||
"priority_mode": row.get("priority_mode"),
|
||||
"selected_expressions": row.get("selected_expressions"),
|
||||
|
||||
155
src/common/database/migrations/v3_to_v4.py
Normal file
155
src/common/database/migrations/v3_to_v4.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""v3 schema 升级到 v4 的迁移逻辑。"""
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .models import MigrationExecutionContext
|
||||
from .schema import SQLiteSchemaInspector
|
||||
|
||||
logger = get_logger("database_migration")
|
||||
|
||||
_V3_MESSAGES_BACKUP_TABLE = "__v3_mai_messages_backup"
|
||||
_V4_MESSAGES_CREATE_SQL = """
|
||||
CREATE TABLE mai_messages (
|
||||
id INTEGER NOT NULL,
|
||||
message_id VARCHAR(255) NOT NULL,
|
||||
timestamp DATETIME,
|
||||
platform VARCHAR(100) NOT NULL,
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
user_nickname VARCHAR(255) NOT NULL,
|
||||
user_cardname VARCHAR(255),
|
||||
group_id VARCHAR(255),
|
||||
group_name VARCHAR(255),
|
||||
is_mentioned BOOLEAN NOT NULL,
|
||||
is_at BOOLEAN NOT NULL,
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
reply_to VARCHAR(255),
|
||||
is_emoji BOOLEAN NOT NULL,
|
||||
is_picture BOOLEAN NOT NULL,
|
||||
is_command BOOLEAN NOT NULL,
|
||||
is_notify BOOLEAN NOT NULL,
|
||||
raw_content BLOB,
|
||||
processed_plain_text VARCHAR,
|
||||
additional_config VARCHAR,
|
||||
PRIMARY KEY (id)
|
||||
)
|
||||
"""
|
||||
_V4_MESSAGES_INDEX_STATEMENTS = (
|
||||
"CREATE INDEX ix_mai_messages_group_id ON mai_messages (group_id)",
|
||||
"CREATE INDEX ix_mai_messages_message_id ON mai_messages (message_id)",
|
||||
"CREATE INDEX ix_mai_messages_platform ON mai_messages (platform)",
|
||||
"CREATE INDEX ix_mai_messages_session_id ON mai_messages (session_id)",
|
||||
"CREATE INDEX ix_mai_messages_user_id ON mai_messages (user_id)",
|
||||
"CREATE INDEX ix_mai_messages_user_nickname ON mai_messages (user_nickname)",
|
||||
)
|
||||
|
||||
|
||||
def migrate_v3_to_v4(context: MigrationExecutionContext) -> None:
|
||||
"""执行 v3 到 v4 的 schema 迁移。"""
|
||||
|
||||
connection = context.connection
|
||||
total_records = _count_table_rows(connection, "mai_messages")
|
||||
context.start_progress(
|
||||
total_tables=1,
|
||||
total_records=total_records,
|
||||
description="v3 -> v4 迁移进度",
|
||||
table_unit_name="表",
|
||||
record_unit_name="记录",
|
||||
)
|
||||
|
||||
migrated_message_rows = _migrate_messages_table_to_v4(connection)
|
||||
context.advance_progress(
|
||||
records=migrated_message_rows,
|
||||
completed_tables=1,
|
||||
item_name="mai_messages",
|
||||
)
|
||||
|
||||
logger.info(f"v3 -> v4 数据库迁移完成: mai_messages重建={migrated_message_rows}")
|
||||
|
||||
|
||||
def _count_table_rows(connection: Connection, table_name: str) -> int:
|
||||
"""统计表记录数,不存在时返回 0。"""
|
||||
|
||||
schema_inspector = SQLiteSchemaInspector()
|
||||
if not schema_inspector.table_exists(connection, table_name):
|
||||
return 0
|
||||
row = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"')).first()
|
||||
return int(row[0]) if row else 0
|
||||
|
||||
|
||||
def _migrate_messages_table_to_v4(connection: Connection) -> int:
|
||||
"""重建 ``mai_messages`` 表并移除弃用的 ``display_message`` 列。"""
|
||||
|
||||
schema_inspector = SQLiteSchemaInspector()
|
||||
if not schema_inspector.table_exists(connection, "mai_messages"):
|
||||
return 0
|
||||
if not schema_inspector.get_table_schema(connection, "mai_messages").has_column("display_message"):
|
||||
return _count_table_rows(connection, "mai_messages")
|
||||
if schema_inspector.table_exists(connection, _V3_MESSAGES_BACKUP_TABLE):
|
||||
raise DatabaseMigrationExecutionError(
|
||||
f"检测到残留备份表 {_V3_MESSAGES_BACKUP_TABLE},无法安全执行 v3 -> v4 mai_messages 迁移。"
|
||||
)
|
||||
|
||||
connection.exec_driver_sql(f'ALTER TABLE "mai_messages" RENAME TO "{_V3_MESSAGES_BACKUP_TABLE}"')
|
||||
connection.exec_driver_sql(_V4_MESSAGES_CREATE_SQL)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
f"""
|
||||
INSERT INTO mai_messages (
|
||||
id,
|
||||
message_id,
|
||||
timestamp,
|
||||
platform,
|
||||
user_id,
|
||||
user_nickname,
|
||||
user_cardname,
|
||||
group_id,
|
||||
group_name,
|
||||
is_mentioned,
|
||||
is_at,
|
||||
session_id,
|
||||
reply_to,
|
||||
is_emoji,
|
||||
is_picture,
|
||||
is_command,
|
||||
is_notify,
|
||||
raw_content,
|
||||
processed_plain_text,
|
||||
additional_config
|
||||
)
|
||||
SELECT
|
||||
id,
|
||||
message_id,
|
||||
timestamp,
|
||||
platform,
|
||||
user_id,
|
||||
user_nickname,
|
||||
user_cardname,
|
||||
group_id,
|
||||
group_name,
|
||||
is_mentioned,
|
||||
is_at,
|
||||
session_id,
|
||||
reply_to,
|
||||
is_emoji,
|
||||
is_picture,
|
||||
is_command,
|
||||
is_notify,
|
||||
raw_content,
|
||||
COALESCE(NULLIF(processed_plain_text, ''), display_message),
|
||||
additional_config
|
||||
FROM "{_V3_MESSAGES_BACKUP_TABLE}"
|
||||
ORDER BY id
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
migrated_rows = _count_table_rows(connection, "mai_messages")
|
||||
connection.exec_driver_sql(f'DROP TABLE "{_V3_MESSAGES_BACKUP_TABLE}"')
|
||||
for statement in _V4_MESSAGES_INDEX_STATEMENTS:
|
||||
connection.exec_driver_sql(statement)
|
||||
return migrated_rows
|
||||
@@ -829,20 +829,19 @@ def initialize_logging(verbose: bool = True):
|
||||
reconfigure_existing_loggers()
|
||||
|
||||
# 启动日志清理任务
|
||||
start_log_cleanup_task(verbose=verbose)
|
||||
start_log_cleanup_task()
|
||||
|
||||
# 只在 verbose=True 时输出详细的初始化信息
|
||||
if verbose:
|
||||
logger = get_logger("logger")
|
||||
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
|
||||
|
||||
logger.info("日志系统已初始化:")
|
||||
logger.info(f" - 控制台级别: {console_level}")
|
||||
logger.info(f" - 文件级别: {file_level}")
|
||||
max_log_files = max(1, int(LOG_CONFIG.get("max_log_files", 30) or 30))
|
||||
log_cleanup_days = max(1, int(LOG_CONFIG.get("log_cleanup_days", 30) or 30))
|
||||
logger.info(f" - 轮转份数: {max_log_files}个文件|自动清理: {log_cleanup_days}天前的日志")
|
||||
logger.info(
|
||||
f"日志系统已初始化:控制台={console_level},文件={file_level},"
|
||||
f"轮转={max_log_files}个文件,清理={log_cleanup_days}天前"
|
||||
)
|
||||
|
||||
|
||||
def cleanup_old_logs():
|
||||
@@ -875,12 +874,8 @@ def cleanup_old_logs():
|
||||
logger.error(f"清理旧日志文件时出错: {e}")
|
||||
|
||||
|
||||
def start_log_cleanup_task(verbose: bool = True):
|
||||
"""启动日志清理任务
|
||||
|
||||
Args:
|
||||
verbose: 是否输出启动信息。默认为 True。
|
||||
"""
|
||||
def start_log_cleanup_task():
|
||||
"""启动日志清理任务"""
|
||||
global _cleanup_task_started
|
||||
|
||||
# 防止重复启动清理任务
|
||||
@@ -897,12 +892,6 @@ def start_log_cleanup_task(verbose: bool = True):
|
||||
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
if verbose:
|
||||
logger = get_logger("logger")
|
||||
max_log_files = max(1, int(LOG_CONFIG.get("max_log_files", 30) or 30))
|
||||
log_cleanup_days = max(1, int(LOG_CONFIG.get("log_cleanup_days", 30) or 30))
|
||||
logger.info(f"已启动日志清理任务,将自动清理{log_cleanup_days}天前的日志文件(轮转份数限制: {max_log_files}个文件)")
|
||||
|
||||
|
||||
def shutdown_logging():
|
||||
"""优雅关闭日志系统,释放所有文件句柄"""
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from tomlkit import parse as parse_toml
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@@ -22,6 +26,19 @@ STRICT_ENV_VALUES = {"1", "true", "yes", "on"}
|
||||
extract_prompt_placeholders = extract_placeholders
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptMetadata:
|
||||
display_name: str = ""
|
||||
advanced: bool = False
|
||||
description: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptTemplateInfo:
|
||||
path: Path
|
||||
metadata: PromptMetadata
|
||||
|
||||
|
||||
def get_prompts_root(prompts_root: Path | None = None) -> Path:
|
||||
return (prompts_root or PROMPTS_ROOT).resolve()
|
||||
|
||||
@@ -80,24 +97,86 @@ def iter_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
|
||||
|
||||
|
||||
def _raise_duplicate_prompt_name(name: str, first_path: Path, second_path: Path, prompts_root: Path) -> None:
|
||||
path_a = first_path.relative_to(prompts_root).as_posix()
|
||||
path_b = second_path.relative_to(prompts_root).as_posix()
|
||||
raise ValueError(
|
||||
t(
|
||||
"prompt.duplicate_template_name",
|
||||
name=name,
|
||||
path_a=first_path.relative_to(prompts_root),
|
||||
path_b=second_path.relative_to(prompts_root),
|
||||
path_a=path_a,
|
||||
path_b=path_b,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]:
|
||||
prompt_paths: dict[str, Path] = {}
|
||||
def _coerce_metadata(raw_metadata: Any) -> PromptMetadata:
|
||||
if not isinstance(raw_metadata, dict):
|
||||
return PromptMetadata()
|
||||
|
||||
display_name = raw_metadata.get("display_name", "")
|
||||
advanced = raw_metadata.get("advanced", False)
|
||||
description = raw_metadata.get("description", "")
|
||||
|
||||
return PromptMetadata(
|
||||
display_name=display_name if isinstance(display_name, str) else "",
|
||||
advanced=advanced if isinstance(advanced, bool) else False,
|
||||
description=description if isinstance(description, str) else "",
|
||||
)
|
||||
|
||||
|
||||
def _read_metadata_file(metadata_path: Path) -> dict[str, Any]:
|
||||
if not metadata_path.is_file():
|
||||
return {}
|
||||
|
||||
try:
|
||||
if metadata_path.suffix == ".json":
|
||||
metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
|
||||
else:
|
||||
metadata = parse_toml(metadata_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
logger.warning("读取 Prompt 元信息文件 %s 失败:%s", metadata_path, exc)
|
||||
return {}
|
||||
|
||||
return dict(metadata) if isinstance(metadata, dict) else {}
|
||||
|
||||
|
||||
def _extract_template_metadata(metadata: dict[str, Any], prompt_name: str) -> dict[str, Any]:
|
||||
templates = metadata.get("templates")
|
||||
if isinstance(templates, dict) and isinstance(templates.get(prompt_name), dict):
|
||||
return dict(templates[prompt_name])
|
||||
|
||||
prompt_metadata = metadata.get(prompt_name)
|
||||
if isinstance(prompt_metadata, dict):
|
||||
return dict(prompt_metadata)
|
||||
|
||||
return metadata if any(key in metadata for key in ("display_name", "advanced", "description")) else {}
|
||||
|
||||
|
||||
def _load_prompt_metadata(prompt_path: Path) -> PromptMetadata:
|
||||
prompt_name = prompt_path.stem
|
||||
metadata_sources = (
|
||||
prompt_path.with_name(f"{prompt_name}.meta.toml"),
|
||||
prompt_path.with_name(f"{prompt_name}.meta.json"),
|
||||
prompt_path.parent / ".meta.toml",
|
||||
prompt_path.parent / ".meta.json",
|
||||
)
|
||||
|
||||
merged_metadata: dict[str, Any] = {}
|
||||
for metadata_path in reversed(metadata_sources):
|
||||
raw_metadata = _read_metadata_file(metadata_path)
|
||||
merged_metadata.update(_extract_template_metadata(raw_metadata, prompt_name))
|
||||
|
||||
return _coerce_metadata(merged_metadata)
|
||||
|
||||
|
||||
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, PromptTemplateInfo]:
|
||||
prompt_paths: dict[str, PromptTemplateInfo] = {}
|
||||
for prompt_path in iter_prompt_files(directory):
|
||||
prompt_name = prompt_path.stem
|
||||
existing_path = prompt_paths.get(prompt_name)
|
||||
if existing_path is not None:
|
||||
_raise_duplicate_prompt_name(prompt_name, existing_path, prompt_path, prompts_root)
|
||||
prompt_paths[prompt_name] = prompt_path
|
||||
existing_info = prompt_paths.get(prompt_name)
|
||||
if existing_info is not None:
|
||||
_raise_duplicate_prompt_name(prompt_name, existing_info.path, prompt_path, prompts_root)
|
||||
prompt_paths[prompt_name] = PromptTemplateInfo(path=prompt_path, metadata=_load_prompt_metadata(prompt_path))
|
||||
return prompt_paths
|
||||
|
||||
|
||||
@@ -115,11 +194,11 @@ def _iter_locale_candidates(requested_locale: str) -> list[str]:
|
||||
return locale_candidates
|
||||
|
||||
|
||||
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]:
|
||||
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, PromptTemplateInfo]:
|
||||
resolved_prompts_root = get_prompts_root(prompts_root)
|
||||
requested_locale = normalize_locale(locale or get_locale())
|
||||
|
||||
prompt_paths: dict[str, Path] = {}
|
||||
prompt_paths: dict[str, PromptTemplateInfo] = {}
|
||||
for directory in _iter_prompt_template_layers(resolved_prompts_root, requested_locale):
|
||||
prompt_paths.update(_scan_prompt_directory(directory, resolved_prompts_root))
|
||||
|
||||
@@ -149,7 +228,7 @@ def resolve_prompt_path(
|
||||
else:
|
||||
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
|
||||
if normalized_name in prompt_paths:
|
||||
return prompt_paths[normalized_name]
|
||||
return prompt_paths[normalized_name].path
|
||||
|
||||
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))
|
||||
|
||||
|
||||
@@ -56,9 +56,9 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0-pre.10"
|
||||
CONFIG_VERSION: str = "8.10.6"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.8"
|
||||
MMC_VERSION: str = "1.0.0-pre.11"
|
||||
CONFIG_VERSION: str = "8.10.7"
|
||||
MODEL_CONFIG_VERSION: str = "1.15.3"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
|
||||
@@ -135,7 +135,6 @@ class ConfigBase(BaseModel, AttrDocBase):
|
||||
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
|
||||
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
|
||||
__ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名)
|
||||
__ui_merge_children__: ClassVar[List[str]] = [] # 在 WebUI 中并入当前配置卡片展示的子配置字段名
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
|
||||
|
||||
@@ -343,10 +343,10 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
|
||||
reasons: list[str] = []
|
||||
|
||||
bot = _as_dict(data.get("bot"))
|
||||
if bot is not None and isinstance(bot.get("qq_account"), str) and not bot["qq_account"].strip():
|
||||
bot["qq_account"] = 0
|
||||
if bot is not None and isinstance(bot.get("qq_account"), int):
|
||||
bot["qq_account"] = str(bot["qq_account"]) if bot["qq_account"] > 0 else ""
|
||||
migrated_any = True
|
||||
reasons.append("bot.qq_account_empty")
|
||||
reasons.append("bot.qq_account_int_to_string")
|
||||
|
||||
chat = _as_dict(data.get("chat"))
|
||||
if chat is not None and _migrate_chat_talk_value_rules(chat):
|
||||
|
||||
@@ -351,6 +351,7 @@ class ModelInfo(ConfigBase):
|
||||
Gemini 客户端会按自身支持的字段筛选并映射到 GenerateContentConfig、EmbedContentConfig 或音频请求配置中。"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
self.model_identifier = self.model_identifier.strip()
|
||||
if not self.model_identifier:
|
||||
raise ValueError(t("config.model_identifier_empty_generic"))
|
||||
if not self.name:
|
||||
@@ -402,6 +403,7 @@ class TaskConfig(ConfigBase):
|
||||
"x-widget": "input",
|
||||
"x-icon": "alert-circle",
|
||||
"step": 0.1,
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||||
@@ -420,15 +422,6 @@ class TaskConfig(ConfigBase):
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
utils: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "wrench",
|
||||
},
|
||||
)
|
||||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||||
|
||||
replyer: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
@@ -436,17 +429,7 @@ class ModelTaskConfig(ConfigBase):
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""首要回复模型配置"""
|
||||
|
||||
learner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "graduation-cap",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
|
||||
"""回复模型配置"""
|
||||
|
||||
planner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
@@ -457,6 +440,25 @@ class ModelTaskConfig(ConfigBase):
|
||||
)
|
||||
"""规划模型配置"""
|
||||
|
||||
utils: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "wrench",
|
||||
},
|
||||
)
|
||||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||||
|
||||
learner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "graduation-cap",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
|
||||
|
||||
vlm: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
@@ -471,6 +473,7 @@ class ModelTaskConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "volume-2",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""语音识别模型配置"""
|
||||
|
||||
@@ -4,6 +4,17 @@ import re
|
||||
|
||||
from .config_base import ConfigBase, Field
|
||||
|
||||
RULE_TYPE_OPTION_DESCRIPTIONS = {
|
||||
"group": "群聊聊天流,item_id 填群号或群聊 ID",
|
||||
"private": "私聊聊天流,item_id 填用户 ID",
|
||||
}
|
||||
|
||||
VISUAL_MODE_OPTION_DESCRIPTIONS = {
|
||||
"auto": "根据模型信息自动选择文本或多模态模式",
|
||||
"text": "纯文本模式,不向模型发送视觉输入",
|
||||
"multimodal": "多模态模式,会向模型发送视觉输入",
|
||||
}
|
||||
|
||||
"""
|
||||
须知:
|
||||
1. 本文件中记录了所有的配置项
|
||||
@@ -19,7 +30,7 @@ class ExampleConfig(ConfigBase):
|
||||
class BotConfig(ConfigBase):
|
||||
"""机器人配置类"""
|
||||
|
||||
__ui_label__ = "基本信息"
|
||||
__ui_label__ = "基础"
|
||||
__ui_icon__ = "bot"
|
||||
|
||||
platform: str = Field(
|
||||
@@ -29,17 +40,19 @@ class BotConfig(ConfigBase):
|
||||
"x-icon": "wifi",
|
||||
"x-layout": "inline-right",
|
||||
"x-input-width": "12rem",
|
||||
"x-row": "bot-platform-account",
|
||||
},
|
||||
)
|
||||
"""平台"""
|
||||
|
||||
qq_account: int = Field(
|
||||
default=0,
|
||||
qq_account: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "user",
|
||||
"x-layout": "inline-right",
|
||||
"x-input-width": "12rem",
|
||||
"x-row": "bot-platform-account",
|
||||
},
|
||||
)
|
||||
"""QQ账号"""
|
||||
@@ -76,6 +89,7 @@ class BotConfig(ConfigBase):
|
||||
class PersonalityConfig(ConfigBase):
|
||||
"""人格配置类"""
|
||||
|
||||
__ui_parent__ = "bot"
|
||||
__ui_label__ = "人格"
|
||||
__ui_icon__ = "user-circle"
|
||||
|
||||
@@ -84,6 +98,8 @@ class PersonalityConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user-circle",
|
||||
"x-textarea-min-height": 40,
|
||||
"x-textarea-rows": 1,
|
||||
},
|
||||
)
|
||||
"""人格,建议200字以内,描述人格特质和身份特征;可以写完整设定。要求第二人称"""
|
||||
@@ -93,6 +109,8 @@ class PersonalityConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "message-square",
|
||||
"x-textarea-min-height": 40,
|
||||
"x-textarea-rows": 1,
|
||||
},
|
||||
)
|
||||
"""默认表达风格,描述麦麦说话的表达风格,表达习惯,如要修改,可以酌情新增内容,建议1-2行"""
|
||||
@@ -137,6 +155,8 @@ class VisualConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
"x-option-descriptions": VISUAL_MODE_OPTION_DESCRIPTIONS,
|
||||
"x-row": "visual-modes",
|
||||
},
|
||||
)
|
||||
"""规划器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
@@ -146,6 +166,8 @@ class VisualConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "git-branch",
|
||||
"x-option-descriptions": VISUAL_MODE_OPTION_DESCRIPTIONS,
|
||||
"x-row": "visual-modes",
|
||||
},
|
||||
)
|
||||
"""回复器模式,auto根据模型信息自动选择,text为纯文本模式,multimodal为多模态模式"""
|
||||
@@ -158,7 +180,13 @@ class TalkRulesItem(ConfigBase):
|
||||
item_id: str = ""
|
||||
"""用户ID,与平台一起留空表示全局"""
|
||||
|
||||
rule_type: Literal["group", "private"] = "group"
|
||||
rule_type: Literal["group", "private"] = Field(
|
||||
default="group",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
|
||||
},
|
||||
)
|
||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||
|
||||
time: str = ""
|
||||
@@ -181,6 +209,7 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "slider",
|
||||
"x-icon": "message-circle",
|
||||
"x-row": "talk-values",
|
||||
"step": 0.1,
|
||||
},
|
||||
)
|
||||
@@ -193,6 +222,7 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "slider",
|
||||
"x-icon": "message-circle",
|
||||
"x-row": "talk-values",
|
||||
"step": 0.1,
|
||||
},
|
||||
)
|
||||
@@ -203,11 +233,19 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "at-sign",
|
||||
"x-row": "reply-switches",
|
||||
},
|
||||
)
|
||||
"""是否启用提及必回复"""
|
||||
|
||||
inevitable_at_reply: bool = Field(default=True)
|
||||
inevitable_at_reply: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "at-sign",
|
||||
"x-row": "reply-switches",
|
||||
},
|
||||
)
|
||||
"""是否启用at必回复"""
|
||||
|
||||
enable_at: bool = Field(
|
||||
@@ -235,6 +273,9 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "layers",
|
||||
"x-layout": "inline-right",
|
||||
"x-input-width": "12rem",
|
||||
"x-row": "context-sizes",
|
||||
},
|
||||
)
|
||||
"""上下文长度"""
|
||||
@@ -244,6 +285,9 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "layers",
|
||||
"x-layout": "inline-right",
|
||||
"x-input-width": "12rem",
|
||||
"x-row": "context-sizes",
|
||||
},
|
||||
)
|
||||
"""私聊上下文长度"""
|
||||
@@ -321,7 +365,8 @@ class ChatConfig(ConfigBase):
|
||||
class MessageReceiveConfig(ConfigBase):
|
||||
"""消息接收配置类"""
|
||||
|
||||
__ui_parent__ = "response_post_process"
|
||||
__ui_label__ = "消息接收"
|
||||
__ui_icon__ = "message-square-text"
|
||||
|
||||
image_parse_threshold: int = Field(
|
||||
default=5,
|
||||
@@ -386,6 +431,7 @@ class TargetItem(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "users",
|
||||
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
|
||||
},
|
||||
)
|
||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||
@@ -394,7 +440,7 @@ class TargetItem(ConfigBase):
|
||||
class MemoryConfig(ConfigBase):
|
||||
"""记忆配置类"""
|
||||
|
||||
__ui_parent__ = "emoji"
|
||||
__ui_parent__ = "a_memorix"
|
||||
|
||||
|
||||
global_memory: bool = Field(
|
||||
@@ -1040,6 +1086,7 @@ class LearningItem(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "users",
|
||||
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
|
||||
},
|
||||
)
|
||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||
@@ -1051,7 +1098,7 @@ class LearningItem(ConfigBase):
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""是否启用表达学习"""
|
||||
"""是否使用表达"""
|
||||
|
||||
enable_learning: bool = Field(
|
||||
default=True,
|
||||
@@ -1060,7 +1107,7 @@ class LearningItem(ConfigBase):
|
||||
"x-icon": "graduation-cap",
|
||||
},
|
||||
)
|
||||
"""是否启用表达优化学习"""
|
||||
"""是否学习表达"""
|
||||
|
||||
enable_jargon_learning: bool = Field(
|
||||
default=False,
|
||||
@@ -1069,7 +1116,7 @@ class LearningItem(ConfigBase):
|
||||
"x-icon": "book",
|
||||
},
|
||||
)
|
||||
"""是否启用jargon学习"""
|
||||
"""是否学习黑话"""
|
||||
|
||||
class ExpressionGroup(ConfigBase):
|
||||
"""表达互通组配置类,若列表为空代表全局共享"""
|
||||
@@ -1177,7 +1224,8 @@ class ExpressionConfig(ConfigBase):
|
||||
class VoiceConfig(ConfigBase):
|
||||
"""语音识别配置类"""
|
||||
|
||||
__ui_parent__ = "emoji"
|
||||
__ui_label__ = "语音"
|
||||
__ui_icon__ = "mic"
|
||||
|
||||
enable_asr: bool = Field(
|
||||
default=False,
|
||||
@@ -1192,8 +1240,8 @@ class VoiceConfig(ConfigBase):
|
||||
class EmojiConfig(ConfigBase):
|
||||
"""表情包配置类"""
|
||||
|
||||
__ui_label__ = "功能"
|
||||
__ui_icon__ = "puzzle"
|
||||
__ui_label__ = "表情包"
|
||||
__ui_icon__ = "smile"
|
||||
|
||||
emoji_send_num: int = Field(
|
||||
default=25,
|
||||
@@ -1254,16 +1302,6 @@ class EmojiConfig(ConfigBase):
|
||||
)
|
||||
"""是否启用表情包过滤,只有符合该要求的表情包才会被保存"""
|
||||
|
||||
filtration_prompt: str = Field(
|
||||
default="符合公序良俗",
|
||||
json_schema_extra={
|
||||
"advanced": True,
|
||||
"x-widget": "input",
|
||||
"x-icon": "shield",
|
||||
},
|
||||
)
|
||||
"""表情包过滤要求,只有符合该要求的表情包才会被保存"""
|
||||
|
||||
|
||||
class KeywordRuleConfig(ConfigBase):
|
||||
"""关键词规则配置类"""
|
||||
@@ -1314,7 +1352,7 @@ class KeywordRuleConfig(ConfigBase):
|
||||
class KeywordReactionConfig(ConfigBase):
|
||||
"""关键词配置类"""
|
||||
|
||||
__ui_parent__ = "response_post_process"
|
||||
__ui_parent__ = "message_receive"
|
||||
|
||||
keyword_rules: list[KeywordRuleConfig] = Field(
|
||||
default_factory=lambda: [],
|
||||
@@ -1345,9 +1383,8 @@ class KeywordReactionConfig(ConfigBase):
|
||||
class ResponsePostProcessConfig(ConfigBase):
|
||||
"""回复后处理配置类"""
|
||||
|
||||
__ui_label__ = "处理"
|
||||
__ui_label__ = "后处理"
|
||||
__ui_icon__ = "settings"
|
||||
__ui_merge_children__ = ["chinese_typo", "response_splitter"]
|
||||
|
||||
enable_response_post_process: bool = Field(
|
||||
default=True,
|
||||
@@ -1742,6 +1779,7 @@ class ExtraPromptItem(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "users",
|
||||
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
|
||||
},
|
||||
)
|
||||
"""聊天流类型,group(群聊)或private(私聊)"""
|
||||
@@ -2387,14 +2425,14 @@ class MCPServerItemConfig(ConfigBase):
|
||||
)
|
||||
"""是否启用当前 MCP 服务器"""
|
||||
|
||||
transport: Literal["stdio", "streamable_http"] = Field(
|
||||
transport: Literal["stdio", "streamable_http", "sse"] = Field(
|
||||
default="stdio",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "shuffle",
|
||||
},
|
||||
)
|
||||
"""传输方式,可选 `stdio` 或 `streamable_http`"""
|
||||
"""传输方式,可选 `stdio`、`streamable_http` 或 `sse`"""
|
||||
|
||||
command: str = Field(
|
||||
default="",
|
||||
@@ -2483,6 +2521,9 @@ class MCPServerItemConfig(ConfigBase):
|
||||
if self.transport == "streamable_http" and not self.url.strip():
|
||||
raise ValueError(f"MCP 服务器 {self.name} 使用 streamable_http 时必须填写 url")
|
||||
|
||||
if self.transport == "sse" and not self.url.strip():
|
||||
raise ValueError(f"MCP 服务器 {self.name} 使用 sse 时必须填写 url")
|
||||
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = get_logger("emoji")
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
DATA_DIR = PROJECT_ROOT / "data"
|
||||
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
|
||||
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
|
||||
@@ -215,6 +215,17 @@ def _is_available_emoji_record(record: Images) -> bool:
|
||||
return record_path.exists() and record_path.is_file()
|
||||
|
||||
|
||||
def _resolve_existing_emoji_path(raw_path: str | Path | None) -> Optional[Path]:
|
||||
"""将表情包路径归一化;路径为空或不存在时返回 ``None``。"""
|
||||
if not raw_path:
|
||||
return None
|
||||
|
||||
record_path = Path(raw_path).absolute().resolve()
|
||||
if not record_path.exists() or not record_path.is_file():
|
||||
return None
|
||||
return record_path
|
||||
|
||||
|
||||
def _is_vlm_task_configured() -> bool:
|
||||
"""判断是否配置了可用于表情包识别和审核的视觉模型任务。"""
|
||||
|
||||
@@ -244,6 +255,7 @@ class EmojiManager:
|
||||
|
||||
self._emoji_num: int = 0
|
||||
self.emojis: list[MaiEmoji] = []
|
||||
self._known_emoji_file_paths: set[Path] = set()
|
||||
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
|
||||
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._reload_callback_registered: bool = False
|
||||
@@ -254,9 +266,9 @@ class EmojiManager:
|
||||
logger.info("启动表情包管理器")
|
||||
|
||||
def reload_runtime_config(self) -> None:
|
||||
"""响应配置热重载,唤醒维护循环以尽快应用最新配置。"""
|
||||
"""响应配置热重载,重置维护循环等待时间以应用最新配置。"""
|
||||
self._maintenance_wakeup_event.set()
|
||||
logger.info("[配置热重载] Emoji 模块配置已更新,将立即应用到维护循环")
|
||||
logger.info("[配置热重载] Emoji 模块配置已更新,将按新的检查间隔等待后执行维护")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""清理 EmojiManager 生命周期资源。"""
|
||||
@@ -793,6 +805,7 @@ class EmojiManager:
|
||||
emoji_replace_prompt_template.add_context("emoji_num", str(self._emoji_num))
|
||||
emoji_replace_prompt_template.add_context("emoji_num_max", str(global_config.emoji.max_reg_num))
|
||||
emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list))
|
||||
emoji_replace_prompt_template.add_context("description", new_emoji.description or "无描述")
|
||||
emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template)
|
||||
|
||||
decision_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
@@ -902,11 +915,10 @@ class EmojiManager:
|
||||
# 表情包审查
|
||||
if global_config.emoji.content_filtration:
|
||||
try:
|
||||
filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
|
||||
filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
|
||||
filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
|
||||
review_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
|
||||
review_prompt = await prompt_manager.render_prompt(review_prompt_template)
|
||||
filtration_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
filtration_prompt,
|
||||
review_prompt,
|
||||
image_base64,
|
||||
image_format,
|
||||
)
|
||||
@@ -948,33 +960,73 @@ class EmojiManager:
|
||||
|
||||
def check_emoji_file_integrity(self) -> None:
|
||||
"""
|
||||
检查表情包完整性,删除文件缺失的表情包记录
|
||||
检查表情包文件和数据库注册记录的一致性。
|
||||
|
||||
数据库记录存在但文件缺失时删除数据库记录;文件存在但没有数据库记录时删除文件。
|
||||
"""
|
||||
logger.info("[完整性检查] 开始检查表情包文件完整性...")
|
||||
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
|
||||
removal_count = 0
|
||||
for emoji in self.emojis:
|
||||
if not emoji.full_path.exists():
|
||||
logger.warning(f"[完整性检查] 表情包文件缺失,准备修改记录: {emoji.file_name}")
|
||||
to_delete_emojis.append((emoji, False))
|
||||
if not emoji.description:
|
||||
logger.warning(f"[完整性检查] 表情包记录缺失描述,准备删除记录: {emoji.file_name}")
|
||||
to_delete_emojis.append((emoji, True))
|
||||
_ensure_directories()
|
||||
logger.info("[完整性检查] 开始检查表情包文件和注册记录一致性...")
|
||||
tracked_paths: set[Path] = set()
|
||||
record_removal_count = 0
|
||||
file_removal_count = 0
|
||||
available_emojis: list[MaiEmoji] = []
|
||||
|
||||
for emoji, is_description_empty in to_delete_emojis:
|
||||
if self.delete_emoji(emoji, is_description_empty):
|
||||
self.emojis.remove(emoji)
|
||||
self._emoji_num -= 1
|
||||
removal_count += 1
|
||||
logger.info(f"[完整性检查] 成功删除缺失文件的表情包记录: {emoji.file_name}")
|
||||
else:
|
||||
logger.error(f"[完整性检查] 删除缺失文件的表情包记录失败: {emoji.file_name}")
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_type=ImageType.EMOJI)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
record_path = _resolve_existing_emoji_path(record.full_path)
|
||||
if record.no_file_flag or record_path is None:
|
||||
logger.warning(
|
||||
f"[完整性检查] 表情包数据库记录缺少实际文件,删除数据库记录: id={record.id}, path={record.full_path}"
|
||||
)
|
||||
session.delete(record)
|
||||
record_removal_count += 1
|
||||
continue
|
||||
if not record.is_registered or record.is_banned:
|
||||
tracked_paths.add(record_path)
|
||||
continue
|
||||
try:
|
||||
available_emojis.append(MaiEmoji.from_db_instance(record))
|
||||
tracked_paths.add(record_path)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"[完整性检查] 加载表情包记录时出错,将删除异常记录: {exc}\n记录ID: {record.id}, 路径: {record.full_path}"
|
||||
)
|
||||
session.delete(record)
|
||||
record_removal_count += 1
|
||||
|
||||
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
resolved_file = emoji_file.absolute().resolve()
|
||||
if resolved_file in tracked_paths:
|
||||
continue
|
||||
try:
|
||||
emoji_file.unlink()
|
||||
file_removal_count += 1
|
||||
logger.warning(f"[完整性检查] 表情包文件缺少数据库记录,删除文件: {emoji_file}")
|
||||
except Exception as exc:
|
||||
logger.error(f"[完整性检查] 删除无注册记录的表情包文件失败: {emoji_file}, error={exc}")
|
||||
|
||||
self.emojis = available_emojis
|
||||
self._known_emoji_file_paths = tracked_paths
|
||||
self._emoji_num = len(self.emojis)
|
||||
logger.info(
|
||||
f"[完整性检查] 表情包完整性检查完成,删除数据库记录 {record_removal_count} 条,删除图片文件 {file_removal_count} 个"
|
||||
)
|
||||
|
||||
async def periodic_emoji_maintenance(self) -> None:
|
||||
"""Run emoji maintenance tasks periodically."""
|
||||
while True:
|
||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||
try:
|
||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||
self._maintenance_wakeup_event.clear()
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
self._maintenance_wakeup_event.clear()
|
||||
|
||||
_ensure_directories()
|
||||
try:
|
||||
self.check_emoji_file_integrity()
|
||||
@@ -989,24 +1041,21 @@ class EmojiManager:
|
||||
for emoji_file in EMOJI_DIR.iterdir():
|
||||
if not emoji_file.is_file():
|
||||
continue
|
||||
resolved_file = emoji_file.absolute().resolve()
|
||||
if resolved_file in self._known_emoji_file_paths:
|
||||
continue
|
||||
try:
|
||||
register_status = await self.register_emoji_by_filename(emoji_file)
|
||||
except Exception as e:
|
||||
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
|
||||
register_status = "failed"
|
||||
if register_status == "registered":
|
||||
self._known_emoji_file_paths.add(resolved_file)
|
||||
break
|
||||
if register_status == "skipped":
|
||||
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
|
||||
else:
|
||||
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
|
||||
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
|
||||
try:
|
||||
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
self._maintenance_wakeup_event.clear()
|
||||
|
||||
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
|
||||
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""
|
||||
|
||||
@@ -585,6 +585,9 @@ def _parse_tool_arguments(
|
||||
Raises:
|
||||
RespParseException: 当参数无法解析为字典时抛出。
|
||||
"""
|
||||
if not raw_arguments.strip():
|
||||
return {}
|
||||
|
||||
try:
|
||||
if parse_mode == ToolArgumentParseMode.STRICT:
|
||||
arguments: Any = json.loads(raw_arguments)
|
||||
|
||||
@@ -207,7 +207,7 @@ async def handle_tool(
|
||||
sent_message = await send_service._send_to_target_with_message(
|
||||
message_sequence=reply_sequence,
|
||||
stream_id=tool_ctx.runtime.session_id,
|
||||
display_message=segment,
|
||||
processed_plain_text=segment,
|
||||
set_reply=segment_set_quote,
|
||||
reply_message=target_message if segment_set_quote else None,
|
||||
selected_expressions=reply_result.selected_expression_ids or None,
|
||||
|
||||
@@ -13,7 +13,7 @@ from PIL import Image as PILImage
|
||||
from PIL import ImageDraw, ImageFont
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import _is_vlm_task_configured, emoji_manager
|
||||
from src.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||
@@ -38,6 +38,7 @@ _EMOJI_SUB_AGENT_MAX_TOKENS = 240
|
||||
_EMOJI_MAX_CANDIDATE_COUNT = 64
|
||||
_EMOJI_CANDIDATE_TILE_SIZE = 256
|
||||
_EMOJI_SUCCESS_MESSAGE = "表情包发送成功"
|
||||
_EMOJI_VLM_NOT_CONFIGURED_MESSAGE = "错误,没有配置视觉模型,无法使用表情包功能"
|
||||
|
||||
|
||||
class EmojiSelectionResult(BaseModel):
|
||||
@@ -298,6 +299,13 @@ def _resolve_emoji_selector_model_task_name() -> str:
|
||||
return "vlm"
|
||||
|
||||
|
||||
def _is_missing_visual_model_error(exc: Exception) -> bool:
|
||||
"""判断是否为未配置视觉模型导致的选择失败。"""
|
||||
|
||||
error_text = str(exc)
|
||||
return _EMOJI_VLM_NOT_CONFIGURED_MESSAGE in error_text or "未找到名为 '' 的模型" in error_text
|
||||
|
||||
|
||||
async def _select_emoji_with_sub_agent(
|
||||
tool_ctx: BuiltinToolRuntimeContext,
|
||||
reasoning: str,
|
||||
@@ -351,13 +359,17 @@ async def _select_emoji_with_sub_agent(
|
||||
request_messages.append(candidate_llm_message)
|
||||
serialized_request_messages = serialize_prompt_messages(request_messages)
|
||||
|
||||
model_task_name = _resolve_emoji_selector_model_task_name()
|
||||
if model_task_name == "vlm" and not _is_vlm_task_configured():
|
||||
raise RuntimeError(_EMOJI_VLM_NOT_CONFIGURED_MESSAGE)
|
||||
|
||||
selection_started_at = datetime.now()
|
||||
response = await tool_ctx.runtime.run_sub_agent(
|
||||
context_message_limit=_EMOJI_SUB_AGENT_CONTEXT_LIMIT,
|
||||
system_prompt=system_prompt,
|
||||
extra_messages=[prompt_message, candidate_message],
|
||||
max_tokens=_EMOJI_SUB_AGENT_MAX_TOKENS,
|
||||
model_task_name=_resolve_emoji_selector_model_task_name(),
|
||||
model_task_name=model_task_name,
|
||||
)
|
||||
selection_duration_ms = round((datetime.now() - selection_started_at).total_seconds() * 1000, 2)
|
||||
|
||||
@@ -448,7 +460,10 @@ async def handle_tool(
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception(f"{tool_ctx.runtime.log_prefix} 发送表情包时发生异常: {exc}")
|
||||
structured_result["message"] = f"发送表情包时发生异常:{exc}"
|
||||
if _is_missing_visual_model_error(exc):
|
||||
structured_result["message"] = _EMOJI_VLM_NOT_CONFIGURED_MESSAGE
|
||||
else:
|
||||
structured_result["message"] = f"发送表情包时发生异常:{exc}"
|
||||
return tool_ctx.build_failure_result(
|
||||
invocation.tool_name,
|
||||
structured_result["message"],
|
||||
|
||||
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("maisaka_reasoning_engine")
|
||||
|
||||
TIMING_GATE_CONTEXT_DROP_HEAD_RATIO = 0.7
|
||||
TIMING_GATE_MAX_TOKENS = 384
|
||||
TIMING_GATE_MAX_ATTEMPTS = 3
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
HISTORY_SILENT_TOOL_NAMES = {"finish"}
|
||||
@@ -140,7 +139,6 @@ class MaisakaReasoningEngine:
|
||||
system_prompt=system_prompt,
|
||||
request_kind="timing_gate",
|
||||
interrupt_flag=None,
|
||||
max_tokens=TIMING_GATE_MAX_TOKENS,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class MCPServerRuntimeConfig:
|
||||
"""单个 MCP 服务器的运行时配置。"""
|
||||
|
||||
name: str
|
||||
transport: Literal["stdio", "streamable_http"] = "stdio"
|
||||
transport: Literal["stdio", "streamable_http", "sse"] = "stdio"
|
||||
command: str = ""
|
||||
args: list[str] = field(default_factory=list)
|
||||
env: dict[str, str] = field(default_factory=dict)
|
||||
@@ -65,13 +65,15 @@ class MCPServerRuntimeConfig:
|
||||
"""返回当前服务器的传输类型。
|
||||
|
||||
Returns:
|
||||
str: ``stdio``、``streamable_http`` 或 ``unknown``。
|
||||
str: ``stdio``、``streamable_http``、``sse`` 或 ``unknown``。
|
||||
"""
|
||||
|
||||
if self.transport == "stdio" and self.command:
|
||||
return "stdio"
|
||||
if self.transport == "streamable_http" and self.url:
|
||||
return "streamable_http"
|
||||
if self.transport == "sse" and self.url:
|
||||
return "sse"
|
||||
return "unknown"
|
||||
|
||||
def build_http_headers(self) -> dict[str, str]:
|
||||
|
||||
@@ -43,16 +43,26 @@ try:
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
|
||||
try:
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
SSE_AVAILABLE = True
|
||||
except ImportError:
|
||||
SSE_AVAILABLE = False
|
||||
sse_client = None # type: ignore[assignment]
|
||||
|
||||
MCP_AVAILABLE = True
|
||||
STREAMABLE_HTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
MCP_AVAILABLE = False
|
||||
STREAMABLE_HTTP_AVAILABLE = False
|
||||
SSE_AVAILABLE = False
|
||||
ClientSession = None # type: ignore[assignment,misc]
|
||||
StdioServerParameters = None # type: ignore[assignment,misc]
|
||||
mcp_types = None # type: ignore[assignment]
|
||||
stdio_client = None # type: ignore[assignment]
|
||||
streamable_http_client = None # type: ignore[assignment]
|
||||
sse_client = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class MCPConnection:
|
||||
@@ -139,6 +149,8 @@ class MCPConnection:
|
||||
return await self._connect_stdio()
|
||||
if self.config.transport_type == "streamable_http":
|
||||
return await self._connect_streamable_http()
|
||||
if self.config.transport_type == "sse":
|
||||
return await self._connect_sse()
|
||||
|
||||
raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {self.config.transport}")
|
||||
|
||||
@@ -184,16 +196,53 @@ class MCPConnection:
|
||||
self._session_id_getter = session_id_getter
|
||||
return read_stream, write_stream
|
||||
|
||||
def _build_http_client(self) -> httpx.AsyncClient:
|
||||
"""构建 Streamable HTTP 使用的 `httpx` 客户端。
|
||||
async def _connect_sse(self) -> tuple[Any, Any]:
|
||||
"""建立 SSE 传输连接。
|
||||
|
||||
Returns:
|
||||
tuple[Any, Any]: 读写流对象。
|
||||
"""
|
||||
|
||||
if not SSE_AVAILABLE or sse_client is None:
|
||||
raise ImportError("当前环境未安装可用的 MCP SSE 客户端")
|
||||
if not self.config.url:
|
||||
raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置")
|
||||
|
||||
read_stream, write_stream = await self._exit_stack.enter_async_context(
|
||||
sse_client(
|
||||
url=self.config.url,
|
||||
headers=self.config.build_http_headers(),
|
||||
timeout=self.config.http_timeout_seconds,
|
||||
sse_read_timeout=self.config.read_timeout_seconds,
|
||||
httpx_client_factory=self._build_http_client,
|
||||
)
|
||||
)
|
||||
return read_stream, write_stream
|
||||
|
||||
def _build_http_client(
|
||||
self,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""构建 httpx 客户端。
|
||||
|
||||
Args:
|
||||
headers: 合并到配置请求头的额外请求头。
|
||||
timeout: 覆盖的 httpx 超时配置。
|
||||
auth: 附加认证。
|
||||
|
||||
Returns:
|
||||
httpx.AsyncClient: 预配置的异步 HTTP 客户端。
|
||||
"""
|
||||
|
||||
del auth
|
||||
merged_headers = self.config.build_http_headers()
|
||||
if headers:
|
||||
merged_headers.update(headers)
|
||||
return httpx.AsyncClient(
|
||||
headers=self.config.build_http_headers(),
|
||||
timeout=httpx.Timeout(self.config.http_timeout_seconds),
|
||||
headers=merged_headers,
|
||||
timeout=timeout or httpx.Timeout(self.config.http_timeout_seconds),
|
||||
)
|
||||
|
||||
async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any:
|
||||
|
||||
@@ -309,7 +309,7 @@ class MCPManager:
|
||||
provider_type="mcp",
|
||||
icons=[build_tool_icon(item) for item in getattr(tool, "icons", []) or []],
|
||||
annotation=build_tool_annotation(getattr(tool, "annotations", None)),
|
||||
metadata={"server_name": server_name} | getattr(tool, "meta", {}),
|
||||
metadata={"server_name": server_name} | (getattr(tool, "meta", {}) or {}),
|
||||
)
|
||||
)
|
||||
return tool_specs
|
||||
|
||||
@@ -190,7 +190,7 @@ class RuntimeCoreCapabilityMixin:
|
||||
content=command,
|
||||
stream_id=stream_id,
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
display_message=str(args.get("display_message", "")),
|
||||
processed_plain_text=str(args.get("processed_plain_text", "")),
|
||||
sync_to_maisaka_history=sync_to_maisaka_history,
|
||||
maisaka_source_kind=maisaka_source_kind,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ class RuntimeCoreCapabilityMixin:
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=stream_id,
|
||||
display_message=str(args.get("display_message", "")),
|
||||
processed_plain_text=str(args.get("processed_plain_text", "")),
|
||||
typing=bool(args.get("typing", False)),
|
||||
storage_message=bool(args.get("storage_message", True)),
|
||||
sync_to_maisaka_history=sync_to_maisaka_history,
|
||||
|
||||
@@ -296,11 +296,13 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@staticmethod
|
||||
def _serialize_messages(messages: list) -> List[Any]:
|
||||
def _serialize_messages(messages: list, include_binary_data: bool = True) -> List[Any]:
|
||||
result: List[Any] = []
|
||||
for msg in messages:
|
||||
if all(hasattr(msg, attr) for attr in ("message_id", "timestamp", "platform", "message_info", "raw_message")):
|
||||
result.append(dict(PluginMessageUtils._session_message_to_dict(msg)))
|
||||
result.append(
|
||||
dict(PluginMessageUtils._session_message_to_dict(msg, include_binary_data=include_binary_data))
|
||||
)
|
||||
elif hasattr(msg, "model_dump"):
|
||||
result.append(msg.model_dump())
|
||||
elif hasattr(msg, "__dict__"):
|
||||
@@ -321,7 +323,12 @@ class RuntimeDataCapabilityMixin:
|
||||
message_id=message_id,
|
||||
chat_id=str(args.get("chat_id") or args.get("stream_id") or "").strip() or None,
|
||||
)
|
||||
serialized_message = self._serialize_messages([message])[0] if message is not None else None
|
||||
include_binary_data = bool(args.get("include_binary_data", False))
|
||||
serialized_message = (
|
||||
self._serialize_messages([message], include_binary_data=include_binary_data)[0]
|
||||
if message is not None
|
||||
else None
|
||||
)
|
||||
return {"success": True, "message": serialized_message}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.message.get_by_id] 执行失败: {e}", exc_info=True)
|
||||
@@ -338,7 +345,13 @@ class RuntimeDataCapabilityMixin:
|
||||
limit_mode=args.get("limit_mode", "latest"),
|
||||
filter_mai=args.get("filter_mai", False),
|
||||
)
|
||||
return {"success": True, "messages": self._serialize_messages(messages)}
|
||||
return {
|
||||
"success": True,
|
||||
"messages": self._serialize_messages(
|
||||
messages,
|
||||
include_binary_data=bool(args.get("include_binary_data", False)),
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.message.get_by_time] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -360,7 +373,13 @@ class RuntimeDataCapabilityMixin:
|
||||
filter_mai=args.get("filter_mai", False),
|
||||
filter_command=args.get("filter_command", False),
|
||||
)
|
||||
return {"success": True, "messages": self._serialize_messages(messages)}
|
||||
return {
|
||||
"success": True,
|
||||
"messages": self._serialize_messages(
|
||||
messages,
|
||||
include_binary_data=bool(args.get("include_binary_data", False)),
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.message.get_by_time_in_chat] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
@@ -385,7 +404,13 @@ class RuntimeDataCapabilityMixin:
|
||||
limit_mode=args.get("limit_mode", "latest"),
|
||||
filter_mai=args.get("filter_mai", False),
|
||||
)
|
||||
return {"success": True, "messages": self._serialize_messages(messages)}
|
||||
return {
|
||||
"success": True,
|
||||
"messages": self._serialize_messages(
|
||||
messages,
|
||||
include_binary_data=bool(args.get("include_binary_data", False)),
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.message.get_recent] 执行失败: {e}", exc_info=True)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
@@ -56,12 +56,14 @@ class MessageDict(TypedDict, total=False):
|
||||
session_id: str
|
||||
reply_to: Optional[str]
|
||||
processed_plain_text: Optional[str]
|
||||
display_message: Optional[str]
|
||||
|
||||
|
||||
class PluginMessageUtils:
|
||||
@staticmethod
|
||||
def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
|
||||
def _message_sequence_to_dict(
|
||||
message_sequence: MessageSequence,
|
||||
include_binary_data: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""将消息组件序列转换为插件运行时使用的字典结构。
|
||||
|
||||
Args:
|
||||
@@ -70,10 +72,16 @@ class PluginMessageUtils:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
|
||||
"""
|
||||
return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
|
||||
return [
|
||||
PluginMessageUtils._component_to_dict(component, include_binary_data=include_binary_data)
|
||||
for component in message_sequence.components
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
|
||||
def _component_to_dict(
|
||||
component: StandardMessageComponents,
|
||||
include_binary_data: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""将单个消息组件转换为插件运行时字典结构。
|
||||
|
||||
Args:
|
||||
@@ -91,8 +99,10 @@ class PluginMessageUtils:
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
if include_binary_data and (
|
||||
binary_data_base64 := PluginMessageUtils._binary_component_to_base64(component, "image")
|
||||
):
|
||||
serialized["binary_data_base64"] = binary_data_base64
|
||||
return serialized
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
@@ -101,8 +111,10 @@ class PluginMessageUtils:
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
if include_binary_data and (
|
||||
binary_data_base64 := PluginMessageUtils._binary_component_to_base64(component, "emoji")
|
||||
):
|
||||
serialized["binary_data_base64"] = binary_data_base64
|
||||
return serialized
|
||||
|
||||
if isinstance(component, VoiceComponent):
|
||||
@@ -111,7 +123,7 @@ class PluginMessageUtils:
|
||||
"data": component.content,
|
||||
"hash": component.binary_hash,
|
||||
}
|
||||
if component.binary_data:
|
||||
if include_binary_data and component.binary_data:
|
||||
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
|
||||
return serialized
|
||||
|
||||
@@ -140,13 +152,53 @@ class PluginMessageUtils:
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return {
|
||||
"type": "forward",
|
||||
"data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
|
||||
"data": [
|
||||
PluginMessageUtils._forward_component_to_dict(item, include_binary_data=include_binary_data)
|
||||
for item in component.forward_components
|
||||
],
|
||||
}
|
||||
|
||||
return {"type": "dict", "data": component.data}
|
||||
|
||||
@staticmethod
|
||||
def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
|
||||
def _binary_component_to_base64(component: Any, image_type: str) -> str:
|
||||
"""将图片或表情组件转换为 Base64,必要时通过 hash 从图片库加载文件。"""
|
||||
|
||||
if component.binary_data:
|
||||
return base64.b64encode(component.binary_data).decode("utf-8")
|
||||
|
||||
binary_hash = str(component.binary_hash or "").strip()
|
||||
if not binary_hash:
|
||||
return ""
|
||||
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
|
||||
target_image_type = ImageType.IMAGE if image_type == "image" else ImageType.EMOJI
|
||||
with get_db_session(auto_commit=False) as db:
|
||||
statement = select(Images).filter_by(image_hash=binary_hash, image_type=target_image_type).limit(1)
|
||||
image_record = db.exec(statement).first()
|
||||
if image_record is None or image_record.no_file_flag:
|
||||
return ""
|
||||
|
||||
image_path = Path(image_record.full_path)
|
||||
if not image_path.is_file():
|
||||
return ""
|
||||
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
|
||||
except Exception as exc:
|
||||
logger.debug("通过 hash 加载历史媒体失败: type=%s hash=%s error=%s", image_type, binary_hash, exc)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _forward_component_to_dict(
|
||||
component: ForwardComponent,
|
||||
include_binary_data: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""将单个转发节点组件转换为字典结构。
|
||||
|
||||
Args:
|
||||
@@ -160,7 +212,10 @@ class PluginMessageUtils:
|
||||
"user_nickname": component.user_nickname,
|
||||
"user_cardname": component.user_cardname,
|
||||
"message_id": component.message_id,
|
||||
"content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
|
||||
"content": [
|
||||
PluginMessageUtils._component_to_dict(item, include_binary_data=include_binary_data)
|
||||
for item in component.content
|
||||
],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -341,7 +396,10 @@ class PluginMessageUtils:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
|
||||
def _session_message_to_dict(
|
||||
session_message: SessionMessage,
|
||||
include_binary_data: bool = True,
|
||||
) -> MessageDict:
|
||||
"""
|
||||
将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
|
||||
|
||||
@@ -357,7 +415,10 @@ class PluginMessageUtils:
|
||||
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
|
||||
platform=session_message.platform,
|
||||
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
|
||||
raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
|
||||
raw_message=PluginMessageUtils._message_sequence_to_dict(
|
||||
session_message.raw_message,
|
||||
include_binary_data=include_binary_data,
|
||||
),
|
||||
is_mentioned=session_message.is_mentioned,
|
||||
is_at=session_message.is_at,
|
||||
is_emoji=session_message.is_emoji,
|
||||
@@ -372,8 +433,6 @@ class PluginMessageUtils:
|
||||
message_dict["reply_to"] = session_message.reply_to
|
||||
if session_message.processed_plain_text is not None:
|
||||
message_dict["processed_plain_text"] = session_message.processed_plain_text
|
||||
if session_message.display_message is not None:
|
||||
message_dict["display_message"] = session_message.display_message
|
||||
|
||||
return message_dict
|
||||
|
||||
@@ -485,8 +544,5 @@ class PluginMessageUtils:
|
||||
session_message.processed_plain_text, str
|
||||
):
|
||||
session_message.processed_plain_text = None
|
||||
session_message.display_message = message_dict.get("display_message")
|
||||
if session_message.display_message is not None and not isinstance(session_message.display_message, str):
|
||||
session_message.display_message = None
|
||||
|
||||
return session_message
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import sys
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry
|
||||
from src.llm_models.model_client.plugin_client import PluginLLMClient
|
||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||
@@ -18,7 +18,6 @@ from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
from src.plugin_runtime import (
|
||||
ENV_BLOCKED_PLUGIN_REASONS,
|
||||
ENV_EXTERNAL_PLUGIN_IDS,
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT,
|
||||
ENV_HOST_VERSION,
|
||||
ENV_IPC_ADDRESS,
|
||||
ENV_PLUGIN_DIRS,
|
||||
@@ -1409,12 +1408,9 @@ class PluginRunnerSupervisor:
|
||||
Returns:
|
||||
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
|
||||
"""
|
||||
global_config_snapshot = config_manager.get_global_config().model_dump(mode="json")
|
||||
global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json")
|
||||
return {
|
||||
ENV_BLOCKED_PLUGIN_REASONS: json.dumps(self._blocked_plugin_reasons, ensure_ascii=False),
|
||||
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
|
||||
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
|
||||
ENV_HOST_VERSION: PROTOCOL_VERSION,
|
||||
ENV_IPC_ADDRESS: self._transport.get_address(),
|
||||
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),
|
||||
|
||||
@@ -274,12 +274,12 @@ class PromptManager:
|
||||
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
|
||||
"""
|
||||
prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
|
||||
for prompt_name, prompt_file in prompt_templates.items():
|
||||
for prompt_name, prompt_template in prompt_templates.items():
|
||||
try:
|
||||
template, need_save = self._load_prompt_template(prompt_name)
|
||||
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
|
||||
except Exception as exc:
|
||||
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
|
||||
logger.error(f"加载 Prompt 文件 '{prompt_template.path}' 时出错,错误信息: {exc}")
|
||||
raise
|
||||
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||
if prompt_file.stem in prompt_templates:
|
||||
|
||||
@@ -43,8 +43,6 @@ def _build_readable_line(
|
||||
def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
|
||||
normalized: List[SessionMessage] = []
|
||||
for message in messages:
|
||||
if not message.processed_plain_text:
|
||||
message.processed_plain_text = message.display_message or ""
|
||||
normalized.append(message)
|
||||
return normalized
|
||||
|
||||
|
||||
@@ -73,9 +73,9 @@ def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpe
|
||||
"type": "string",
|
||||
"description": "目标会话 ID。",
|
||||
},
|
||||
"display_message": {
|
||||
"processed_plain_text": {
|
||||
"type": "string",
|
||||
"description": "展示层文本。",
|
||||
"description": "可选的预处理纯文本内容。",
|
||||
},
|
||||
"typing": {
|
||||
"type": "boolean",
|
||||
@@ -97,7 +97,7 @@ def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpe
|
||||
required=[
|
||||
"message",
|
||||
"stream_id",
|
||||
"display_message",
|
||||
"processed_plain_text",
|
||||
"typing",
|
||||
"set_reply",
|
||||
"storage_message",
|
||||
@@ -494,7 +494,7 @@ def _build_outbound_log_preview(message: SessionMessage, max_length: int = 160)
|
||||
Returns:
|
||||
str: 适用于日志展示的消息摘要。
|
||||
"""
|
||||
preview_text = (message.processed_plain_text or message.display_message or "").strip()
|
||||
preview_text = (message.processed_plain_text or "").strip()
|
||||
if not preview_text:
|
||||
preview_text = f"[{_describe_message_sequence(message.raw_message)}]"
|
||||
|
||||
@@ -507,7 +507,7 @@ def _build_outbound_log_preview(message: SessionMessage, max_length: int = 160)
|
||||
def _build_outbound_session_message(
|
||||
message_sequence: MessageSequence,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
processed_plain_text: str = "",
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
) -> Optional[SessionMessage]:
|
||||
@@ -516,7 +516,7 @@ def _build_outbound_session_message(
|
||||
Args:
|
||||
message_sequence: 待发送的消息组件序列。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于界面展示的文本内容。
|
||||
processed_plain_text: 可选的预处理纯文本内容。
|
||||
reply_message: 被回复的锚点消息。
|
||||
selected_expressions: 可选的表情候选索引列表。
|
||||
|
||||
@@ -571,7 +571,7 @@ def _build_outbound_session_message(
|
||||
)
|
||||
outbound_message.raw_message = _clone_message_sequence(message_sequence)
|
||||
outbound_message.session_id = target_stream.session_id
|
||||
outbound_message.display_message = display_message
|
||||
outbound_message.processed_plain_text = processed_plain_text.strip() or _build_processed_plain_text(outbound_message)
|
||||
outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None
|
||||
message_flags = _detect_outbound_message_flags(outbound_message.raw_message)
|
||||
outbound_message.is_emoji = message_flags["is_emoji"]
|
||||
@@ -619,7 +619,8 @@ async def _prepare_message_for_platform_io(
|
||||
raise ValueError("set_reply=True 时必须提供 reply_message_id")
|
||||
_ensure_reply_component(message, reply_message_id)
|
||||
|
||||
message.processed_plain_text = _build_processed_plain_text(message)
|
||||
if set_reply or not message.processed_plain_text:
|
||||
message.processed_plain_text = _build_processed_plain_text(message)
|
||||
if typing:
|
||||
typing_time = calculate_typing_time(
|
||||
input_string=message.processed_plain_text or "",
|
||||
@@ -935,7 +936,7 @@ async def send_session_message(
|
||||
async def _send_to_target(
|
||||
message_sequence: MessageSequence,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
processed_plain_text: str = "",
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
@@ -950,7 +951,7 @@ async def _send_to_target(
|
||||
await _send_to_target_with_message(
|
||||
message_sequence=message_sequence,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
processed_plain_text=processed_plain_text,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message=reply_message,
|
||||
@@ -967,7 +968,7 @@ async def _send_to_target(
|
||||
async def _send_to_target_with_message(
|
||||
message_sequence: MessageSequence,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
processed_plain_text: str = "",
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
@@ -982,7 +983,7 @@ async def _send_to_target_with_message(
|
||||
Args:
|
||||
message_sequence: 待发送的消息组件序列。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于界面展示的文本内容。
|
||||
processed_plain_text: 可选的预处理纯文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
set_reply: 是否在发送时附带引用回复。
|
||||
reply_message: 被回复的消息对象。
|
||||
@@ -1004,7 +1005,7 @@ async def _send_to_target_with_message(
|
||||
outbound_message = _build_outbound_session_message(
|
||||
message_sequence=message_sequence,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
processed_plain_text=processed_plain_text,
|
||||
reply_message=reply_message,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
@@ -1015,7 +1016,7 @@ async def _send_to_target_with_message(
|
||||
"send_service.after_build_message",
|
||||
outbound_message,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
processed_plain_text=processed_plain_text,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
@@ -1068,7 +1069,6 @@ async def text_to_stream_with_message(
|
||||
return await _send_to_target_with_message(
|
||||
message_sequence=MessageSequence(components=[TextComponent(text=text)]),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message=reply_message,
|
||||
@@ -1133,7 +1133,6 @@ async def emoji_to_stream_with_message(
|
||||
return await _send_to_target_with_message(
|
||||
message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
@@ -1202,7 +1201,6 @@ async def image_to_stream(
|
||||
return await _send_to_target(
|
||||
message_sequence=_build_message_sequence_from_custom_message("image", image_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
@@ -1216,7 +1214,7 @@ async def custom_to_stream(
|
||||
message_type: str,
|
||||
content: str | Dict[str, Any],
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
processed_plain_text: str = "",
|
||||
typing: bool = False,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
set_reply: bool = False,
|
||||
@@ -1231,7 +1229,7 @@ async def custom_to_stream(
|
||||
message_type: 自定义消息类型。
|
||||
content: 自定义消息内容。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于展示的文本内容。
|
||||
processed_plain_text: 可选的预处理纯文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
reply_message: 被回复的消息对象。
|
||||
set_reply: 是否附带引用回复。
|
||||
@@ -1244,7 +1242,7 @@ async def custom_to_stream(
|
||||
return await _send_to_target(
|
||||
message_sequence=_build_message_sequence_from_custom_message(message_type, content),
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
processed_plain_text=processed_plain_text,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
@@ -1258,7 +1256,7 @@ async def custom_to_stream(
|
||||
async def custom_reply_set_to_stream(
|
||||
reply_set: MessageSequence,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
processed_plain_text: str = "",
|
||||
typing: bool = False,
|
||||
reply_message: Optional[MaiMessage] = None,
|
||||
set_reply: bool = False,
|
||||
@@ -1272,7 +1270,7 @@ async def custom_reply_set_to_stream(
|
||||
Args:
|
||||
reply_set: 待发送的消息组件序列。
|
||||
stream_id: 目标会话 ID。
|
||||
display_message: 用于展示的文本内容。
|
||||
processed_plain_text: 可选的预处理纯文本内容。
|
||||
typing: 是否显示输入中状态。
|
||||
reply_message: 被回复的消息对象。
|
||||
set_reply: 是否附带引用回复。
|
||||
@@ -1285,7 +1283,7 @@ async def custom_reply_set_to_stream(
|
||||
return await _send_to_target(
|
||||
message_sequence=reply_set,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
processed_plain_text=processed_plain_text,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
|
||||
@@ -134,7 +134,7 @@ def _setup_anti_crawler(app: FastAPI):
|
||||
"basic": t("startup.webui_anti_crawler_mode_basic"),
|
||||
}
|
||||
mode_desc = mode_descriptions.get(anti_crawler_mode, t("startup.webui_anti_crawler_mode_basic"))
|
||||
logger.info(t("startup.webui_anti_crawler_configured", mode_desc=mode_desc))
|
||||
logger.debug(t("startup.webui_anti_crawler_configured", mode_desc=mode_desc))
|
||||
except Exception as e:
|
||||
logger.error(t("startup.webui_anti_crawler_config_failed", error=e), exc_info=True)
|
||||
|
||||
@@ -159,7 +159,7 @@ def _register_api_routes(app: FastAPI):
|
||||
for router in get_all_routers():
|
||||
app.include_router(router)
|
||||
|
||||
logger.info(t("startup.webui_api_routes_registered"))
|
||||
logger.debug(t("startup.webui_api_routes_registered"))
|
||||
except Exception as e:
|
||||
logger.error(t("startup.webui_api_routes_register_failed", error=e), exc_info=True)
|
||||
|
||||
@@ -217,7 +217,7 @@ def _setup_static_files(app: FastAPI):
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
logger.info(t("startup.webui_static_files_configured", static_path=static_path))
|
||||
logger.debug(t("startup.webui_static_files_configured", static_path=static_path))
|
||||
|
||||
|
||||
def _resolve_static_path() -> Path | None:
|
||||
@@ -247,6 +247,5 @@ def show_access_token():
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(t("startup.webui_access_token", token=current_token))
|
||||
logger.info(t("startup.webui_access_token_login_hint"))
|
||||
except Exception as e:
|
||||
logger.error(t("startup.webui_access_token_failed", error=e))
|
||||
|
||||
@@ -39,15 +39,12 @@ class ConfigSchemaGenerator:
|
||||
ui_parent = getattr(config_class, "__ui_parent__", "")
|
||||
ui_label = getattr(config_class, "__ui_label__", "")
|
||||
ui_icon = getattr(config_class, "__ui_icon__", "")
|
||||
ui_merge_children = getattr(config_class, "__ui_merge_children__", [])
|
||||
if ui_parent:
|
||||
schema["uiParent"] = ui_parent
|
||||
if ui_label:
|
||||
schema["uiLabel"] = ui_label
|
||||
if ui_icon:
|
||||
schema["uiIcon"] = ui_icon
|
||||
if ui_merge_children:
|
||||
schema["uiMergeChildren"] = list(ui_merge_children)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ class ChatHistoryManager:
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"content": msg.processed_plain_text or "",
|
||||
"timestamp": msg.timestamp.timestamp(),
|
||||
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
@@ -175,11 +175,7 @@ class ChatHistoryManager:
|
||||
|
||||
user_info = target_msg.message_info.user_info
|
||||
if not has_content:
|
||||
content_text = (
|
||||
target_msg.processed_plain_text
|
||||
or target_msg.display_message
|
||||
or ""
|
||||
)
|
||||
content_text = target_msg.processed_plain_text or ""
|
||||
data["target_message_content"] = content_text
|
||||
if not has_sender:
|
||||
data["target_message_sender_id"] = user_info.user_id or ""
|
||||
|
||||
@@ -2,19 +2,21 @@
|
||||
配置管理API路由
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Tuple, Union, get_args, get_origin
|
||||
import copy
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Tuple
|
||||
import types
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import list_prompt_templates
|
||||
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
|
||||
from src.config.config_base import AttributeData
|
||||
from src.config.config_base import AttributeData, ConfigBase
|
||||
from src.config.model_configs import (
|
||||
APIProvider,
|
||||
ModelInfo,
|
||||
@@ -63,6 +65,9 @@ class PromptFileInfo(BaseModel):
|
||||
name: str = Field(..., description="Prompt 文件名")
|
||||
size: int = Field(..., description="文件大小")
|
||||
modified_at: float = Field(..., description="最后修改时间戳")
|
||||
display_name: str = Field(default="", description="Prompt 展示名称")
|
||||
advanced: bool = Field(default=False, description="是否为高级 Prompt")
|
||||
description: str = Field(default="", description="Prompt 描述")
|
||||
|
||||
|
||||
class PromptCatalogResponse(BaseModel):
|
||||
@@ -129,6 +134,71 @@ def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
return obj
|
||||
|
||||
|
||||
def _coerce_numeric_value(value: Any, target_type: Any) -> Any:
|
||||
"""根据配置字段类型,把旧 WebUI 可能写入的数字字符串还原为数字。"""
|
||||
if target_type is str:
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
if target_type is int:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
parsed_value = float(value.strip())
|
||||
except ValueError:
|
||||
return value
|
||||
if parsed_value.is_integer():
|
||||
return int(parsed_value)
|
||||
return value
|
||||
|
||||
if target_type is float:
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return float(value.strip())
|
||||
except ValueError:
|
||||
return value
|
||||
return value
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_value_by_annotation(value: Any, annotation: Any) -> Any:
|
||||
"""递归按 ConfigBase 字段注解修正数据类型,避免保存时把数字写成字符串。"""
|
||||
value = _coerce_numeric_value(value, annotation)
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin in {Union, types.UnionType}:
|
||||
for candidate_type in args:
|
||||
if candidate_type is type(None):
|
||||
continue
|
||||
coerced_value = _coerce_value_by_annotation(value, candidate_type)
|
||||
if coerced_value != value or type(coerced_value) is not type(value):
|
||||
return coerced_value
|
||||
return value
|
||||
|
||||
if origin in {list, List} and isinstance(value, list) and args:
|
||||
item_type = args[0]
|
||||
return [_coerce_value_by_annotation(item, item_type) for item in value]
|
||||
|
||||
if origin in {dict, Dict} and isinstance(value, dict) and len(args) >= 2:
|
||||
value_type = args[1]
|
||||
return {key: _coerce_value_by_annotation(item, value_type) for key, item in value.items()}
|
||||
|
||||
if isinstance(value, dict) and isinstance(annotation, type) and issubclass(annotation, ConfigBase):
|
||||
return _coerce_config_numeric_values(value, annotation)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _coerce_config_numeric_values(data: Dict[str, Any], config_type: type[ConfigBase]) -> Dict[str, Any]:
|
||||
"""按配置类 schema 统一修正所有数字字段类型。"""
|
||||
for field_name, field_info in config_type.model_fields.items():
|
||||
if field_name in data:
|
||||
data[field_name] = _coerce_value_by_annotation(data[field_name], field_info.annotation)
|
||||
return data
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@@ -147,14 +217,20 @@ async def list_prompt_files():
|
||||
continue
|
||||
|
||||
language = language_dir.name
|
||||
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
|
||||
prompt_files: List[PromptFileInfo] = []
|
||||
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
|
||||
stat = prompt_file.stat()
|
||||
template_info = prompt_template_infos.get(prompt_file.stem)
|
||||
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
|
||||
prompt_files.append(
|
||||
PromptFileInfo(
|
||||
name=prompt_file.name,
|
||||
size=stat.st_size,
|
||||
modified_at=stat.st_mtime,
|
||||
display_name=metadata.display_name if metadata else "",
|
||||
advanced=metadata.advanced if metadata else False,
|
||||
description=metadata.description if metadata else "",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -347,6 +423,8 @@ async def get_model_config():
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
config_data = _coerce_config_numeric_values(config_data, Config)
|
||||
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(AttributeData(), copy.deepcopy(config_data))
|
||||
@@ -370,6 +448,8 @@ async def update_bot_config(config_data: ConfigBody):
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
config_data = _coerce_config_numeric_values(config_data, ModelConfig)
|
||||
|
||||
# 验证配置数据
|
||||
try:
|
||||
ModelConfig.from_dict(AttributeData(), copy.deepcopy(config_data))
|
||||
@@ -422,10 +502,13 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
Config.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
|
||||
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), Config)
|
||||
Config.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
config_data = plain_config_data
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
@@ -520,13 +603,14 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
ModelConfig.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
|
||||
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), ModelConfig)
|
||||
ModelConfig.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
models = plain_config_data.get("models", [])
|
||||
orphaned_models: List[str] = [
|
||||
str(model_name)
|
||||
for m in models
|
||||
@@ -539,6 +623,8 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
config_data = plain_config_data
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
|
||||
@@ -10,11 +10,12 @@ from sqlmodel import col, delete, select
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
EXCLUDE_IDS_QUERY = Query(None, description="需要排除的表达方式 ID")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
|
||||
@@ -28,6 +29,7 @@ class ExpressionResponse(BaseModel):
|
||||
style: str
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
chat_name: Optional[str] = None
|
||||
create_date: Optional[float]
|
||||
checked: bool
|
||||
rejected: bool
|
||||
@@ -90,7 +92,61 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
|
||||
"""从最近消息中解析聊天显示名称。"""
|
||||
|
||||
statement = (
|
||||
select(Messages)
|
||||
.where(col(Messages.session_id) == chat_id)
|
||||
.order_by(col(Messages.timestamp).desc())
|
||||
.limit(1)
|
||||
)
|
||||
message = db_session.exec(statement).first()
|
||||
if not message:
|
||||
return None
|
||||
if message.group_id:
|
||||
return message.group_name or f"群聊{message.group_id}"
|
||||
return message.user_cardname or message.user_nickname or (f"用户{message.user_id}" if message.user_id else None)
|
||||
|
||||
|
||||
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
|
||||
"""从会话记录推断兜底显示名称。"""
|
||||
|
||||
if chat_session.group_id:
|
||||
return f"群聊{chat_session.group_id}"
|
||||
if chat_session.user_id:
|
||||
return f"用户{chat_session.user_id}"
|
||||
return chat_session.session_id
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
db_session: 可选数据库会话,用于从历史消息中解析群名或私聊用户名。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
|
||||
try:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
return name
|
||||
if db_session and (name := get_chat_name_from_latest_message(chat_id, db_session)):
|
||||
return name
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if session:
|
||||
if session.group_id:
|
||||
return f"群聊{session.group_id}"
|
||||
if session.user_id:
|
||||
return f"用户{session.user_id}"
|
||||
return chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
|
||||
"""将表达方式模型转换为响应对象。
|
||||
|
||||
Args:
|
||||
@@ -101,38 +157,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""
|
||||
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||
chat_id = expression.session_id or ""
|
||||
return ExpressionResponse(
|
||||
id=expression.id if expression.id is not None else 0,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=last_active_time,
|
||||
chat_id=expression.session_id or "",
|
||||
chat_id=chat_id,
|
||||
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
|
||||
create_date=create_date,
|
||||
checked=False,
|
||||
rejected=False,
|
||||
modified_by=None,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by.value if expression.modified_by else None,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据聊天 ID 获取聊天名称。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天会话 ID。
|
||||
|
||||
Returns:
|
||||
str: 聊天显示名称,获取失败时返回原始聊天 ID。
|
||||
"""
|
||||
try:
|
||||
session = _chat_manager.get_session_by_session_id(chat_id)
|
||||
if not session:
|
||||
return chat_id
|
||||
name = _chat_manager.get_session_name(chat_id)
|
||||
return name or chat_id
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称。
|
||||
|
||||
@@ -145,8 +184,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
for chat_id in chat_ids:
|
||||
if name := _chat_manager.get_session_name(chat_id):
|
||||
result[chat_id] = name
|
||||
result[chat_id] = get_chat_name(chat_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
@@ -176,19 +214,43 @@ async def get_chat_list() -> ChatListResponse:
|
||||
ChatListResponse: 可用于下拉选择的聊天列表。
|
||||
"""
|
||||
try:
|
||||
chat_list = []
|
||||
chat_by_id: Dict[str, ChatInfo] = {}
|
||||
for session_id, session in _chat_manager.sessions.items():
|
||||
chat_name = _chat_manager.get_session_name(session_id) or session_id
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=chat_name,
|
||||
platform=session.platform,
|
||||
is_group=session.is_group_session,
|
||||
)
|
||||
|
||||
with get_db_session() as session:
|
||||
for chat_session in session.exec(select(ChatSession)).all():
|
||||
if chat_session.session_id in chat_by_id:
|
||||
continue
|
||||
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
|
||||
chat_by_id[chat_session.session_id] = ChatInfo(
|
||||
chat_id=chat_session.session_id,
|
||||
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
|
||||
platform=chat_session.platform,
|
||||
is_group=bool(chat_session.group_id),
|
||||
)
|
||||
|
||||
expression_chat_ids = {
|
||||
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
|
||||
}
|
||||
for session_id in expression_chat_ids:
|
||||
if session_id in chat_by_id:
|
||||
continue
|
||||
chat_by_id[session_id] = ChatInfo(
|
||||
chat_id=session_id,
|
||||
chat_name=get_chat_name(session_id, session),
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list = list(chat_by_id.values())
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
@@ -252,7 +314,7 @@ async def get_expression_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
@@ -281,7 +343,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=data)
|
||||
|
||||
@@ -321,7 +383,7 @@ async def create_expression(
|
||||
session.add(expression)
|
||||
session.flush()
|
||||
expression_id = expression.id
|
||||
data = expression_to_response(expression)
|
||||
data = expression_to_response(expression, session)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
|
||||
|
||||
@@ -375,7 +437,7 @@ async def update_expression(
|
||||
db_expression.session_id = update_data["session_id"]
|
||||
db_expression.last_active_time = update_data["last_active_time"]
|
||||
session.add(db_expression)
|
||||
data = expression_to_response(db_expression)
|
||||
data = expression_to_response(db_expression, session)
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
@@ -524,6 +586,22 @@ class ReviewStatsResponse(BaseModel):
|
||||
user_checked: int
|
||||
|
||||
|
||||
def apply_review_filter(statement: Any, filter_type: str) -> Any:
|
||||
"""按审核状态过滤表达方式查询。"""
|
||||
if filter_type == "unchecked":
|
||||
return statement.where(col(Expression.checked).is_(False))
|
||||
if filter_type == "passed":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(False))
|
||||
if filter_type == "rejected":
|
||||
return statement.where(col(Expression.checked).is_(True), col(Expression.rejected).is_(True))
|
||||
return statement
|
||||
|
||||
|
||||
def count_expressions(session: Any, statement: Any) -> int:
|
||||
"""统计表达方式查询结果数量。"""
|
||||
return len(session.exec(statement).all())
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""获取审核统计数据。
|
||||
@@ -533,12 +611,24 @@ async def get_review_stats() -> ReviewStatsResponse:
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
total = len(session.exec(select(Expression.id)).all())
|
||||
unchecked = 0
|
||||
passed = 0
|
||||
rejected = 0
|
||||
ai_checked = 0
|
||||
user_checked = 0
|
||||
total = count_expressions(session, select(Expression.id))
|
||||
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
|
||||
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
|
||||
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
|
||||
ai_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.AI,
|
||||
),
|
||||
)
|
||||
user_checked = count_expressions(
|
||||
session,
|
||||
select(Expression.id).where(
|
||||
col(Expression.checked).is_(True),
|
||||
col(Expression.modified_by) == ModifiedBy.USER,
|
||||
),
|
||||
)
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
@@ -571,8 +661,10 @@ async def get_review_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
|
||||
order: str = Query("latest", description="排序方式: latest/random"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
exclude_ids: Optional[List[int]] = EXCLUDE_IDS_QUERY,
|
||||
) -> ReviewListResponse:
|
||||
"""获取待审核或已审核的表达方式列表。
|
||||
|
||||
@@ -580,17 +672,16 @@ async def get_review_list(
|
||||
page: 页码。
|
||||
page_size: 每页数量。
|
||||
filter_type: 筛选类型,可选 unchecked、passed、rejected 或 all。
|
||||
order: 排序方式,可选 latest 或 random。
|
||||
search: 搜索关键词。
|
||||
chat_id: 聊天 ID 筛选条件。
|
||||
exclude_ids: 需要排除的表达方式 ID。
|
||||
|
||||
Returns:
|
||||
ReviewListResponse: 审核列表响应。
|
||||
"""
|
||||
try:
|
||||
statement = select(Expression)
|
||||
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
statement = statement.where(col(Expression.id) == -1)
|
||||
statement = apply_review_filter(select(Expression), filter_type)
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
@@ -603,11 +694,17 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
statement = statement.where(col(Expression.session_id) == chat_id)
|
||||
|
||||
# 排序:创建时间倒序
|
||||
statement = statement.order_by(
|
||||
case((col(Expression.create_time).is_(None), 1), else_=0),
|
||||
col(Expression.create_time).desc(),
|
||||
)
|
||||
if exclude_ids:
|
||||
statement = statement.where(~col(Expression.id).in_(exclude_ids))
|
||||
|
||||
if order == "random":
|
||||
statement = statement.order_by(func.random())
|
||||
else:
|
||||
# 排序:创建时间倒序
|
||||
statement = statement.order_by(
|
||||
case((col(Expression.create_time).is_(None), 1), else_=0),
|
||||
col(Expression.create_time).desc(),
|
||||
)
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
statement = statement.offset(offset).limit(page_size)
|
||||
@@ -615,9 +712,7 @@ async def get_review_list(
|
||||
with get_db_session() as session:
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
count_statement = select(Expression.id)
|
||||
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||
count_statement = count_statement.where(col(Expression.id) == -1)
|
||||
count_statement = apply_review_filter(select(Expression.id), filter_type)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||
@@ -625,7 +720,7 @@ async def get_review_list(
|
||||
if chat_id:
|
||||
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||
total = len(session.exec(count_statement).all())
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
data = [expression_to_response(expr, session) for expr in expressions]
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
@@ -647,7 +742,7 @@ class BatchReviewItem(BaseModel):
|
||||
|
||||
id: int
|
||||
rejected: bool
|
||||
require_unchecked: bool = True # 默认要求未检查状态
|
||||
require_unchecked: bool = True # 前端保留的来源标记,人工审核提交时不再阻断覆盖
|
||||
|
||||
|
||||
class BatchReviewRequest(BaseModel):
|
||||
@@ -706,14 +801,6 @@ async def batch_review_expressions(
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked:
|
||||
results.append(
|
||||
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 更新状态
|
||||
with get_db_session() as session:
|
||||
db_expression = session.exec(
|
||||
@@ -727,6 +814,9 @@ async def batch_review_expressions(
|
||||
)
|
||||
failed += 1
|
||||
continue
|
||||
db_expression.checked = True
|
||||
db_expression.rejected = item.rejected
|
||||
db_expression.modified_by = ModifiedBy.USER
|
||||
db_expression.last_active_time = datetime.now()
|
||||
session.add(db_expression)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user