merge: 同步上游 dev 最新内容

This commit is contained in:
DawnARC
2026-05-06 00:53:11 +08:00
125 changed files with 3069 additions and 1271 deletions

View File

@@ -78,7 +78,7 @@ class ChatManager:
"""初始化聊天管理器"""
try:
await self.load_all_sessions_from_db()
logger.info(f"已加载 {len(self.sessions)} 个会话记录到内存中")
logger.debug(f"已加载 {len(self.sessions)} 个会话记录到内存中")
except Exception as e:
logger.error(f"初始化聊天管理器出现错误: {e}")

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
import json
import random
import time
from rich.console import Group, RenderableType
@@ -103,6 +104,24 @@ class BaseMaisakaReplyGenerator:
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
return "你的名字是麦麦。\n是人类。"
@staticmethod
def _select_reply_style() -> str:
"""按配置概率选择本次 replyer 使用的表达风格。"""
personality_config = global_config.personality
reply_style = personality_config.reply_style
candidate_styles = [style.strip() for style in personality_config.multiple_reply_style if style.strip()]
if not candidate_styles:
return reply_style
probability = personality_config.multiple_probability
if probability <= 0:
return reply_style
if random.random() > probability:
return reply_style
return random.choice(candidate_styles)
@staticmethod
def _normalize_content(content: str, limit: int = 500) -> str:
normalized = " ".join((content or "").split())
@@ -293,7 +312,7 @@ class BaseMaisakaReplyGenerator:
group_chat_attention_block=self._build_group_chat_attention_block(session_id),
replyer_at_block=self._build_replyer_at_block(),
identity=self._personality_prompt,
reply_style=global_config.personality.reply_style,
reply_style=self._select_reply_style(),
)
except Exception:
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"

View File

@@ -8,8 +8,6 @@ import random
import re
import time
import jieba
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
@@ -912,110 +910,3 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [keywords_str] if keywords_str else []
def cut_key_words(concept_name: str) -> list[str]:
"""对概念名称进行jieba分词并过滤掉关键词列表中的关键词"""
concept_name_tokens = list(jieba.cut(concept_name))
# 定义常见连词、停用词与标点
conjunctions = {"", "", "", "", "以及", "并且", "而且", "", "或者", ""}
stop_words = {
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"",
"而且",
"或者",
"",
"以及",
}
chinese_punctuations = set(",。!?、;:()【】《》“”‘’—…·-——,.!?;:()[]<>'\"/\\")
# 清理空白并初步过滤纯标点
cleaned_tokens = []
for tok in concept_name_tokens:
t = tok.strip()
if not t:
continue
# 去除纯标点
if all(ch in chinese_punctuations for ch in t):
continue
cleaned_tokens.append(t)
# 合并连词两侧的词(仅当两侧都存在且不是标点/停用词时)
merged_tokens = []
i = 0
n = len(cleaned_tokens)
while i < n:
tok = cleaned_tokens[i]
if tok in conjunctions and merged_tokens and i + 1 < n:
left = merged_tokens[-1]
right = cleaned_tokens[i + 1]
# 左右都需要是有效词
if (
left
and right
and left not in conjunctions
and right not in conjunctions
and left not in stop_words
and right not in stop_words
and not all(ch in chinese_punctuations for ch in left)
and not all(ch in chinese_punctuations for ch in right)
):
# 合并为一个新词,并替换掉左侧与跳过右侧
combined = f"{left}{tok}{right}"
merged_tokens[-1] = combined
i += 2
continue
# 常规推进
merged_tokens.append(tok)
i += 1
# 二次过滤:去除停用词、单字符纯标点与无意义项
result_tokens = []
seen = set()
# ban_words = set(getattr(global_config.memory, "memory_ban_words", []) or [])
for tok in merged_tokens:
if tok in conjunctions:
# 独立连词丢弃
continue
if tok in stop_words:
continue
# if tok in ban_words:
# continue
if all(ch in chinese_punctuations for ch in tok):
continue
if tok.strip() == "":
continue
if tok not in seen:
seen.add(tok)
result_tokens.append(tok)
filtered_concept_name_tokens = result_tokens
return filtered_concept_name_tokens

View File

@@ -79,7 +79,6 @@ class BufferCLI:
)
message.raw_message = MessageSequence([TextComponent(text=user_text)])
message.processed_plain_text = user_text
message.display_message = user_text
message.initialized = True
return message

View File

@@ -59,7 +59,6 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
self.reply_to: Optional[str] = None
self.processed_plain_text: Optional[str] = None
self.display_message: Optional[str] = None
self.raw_message: MessageSequence
@classmethod
@@ -86,7 +85,6 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
obj.reply_to = db_record.reply_to
obj.session_id = db_record.session_id
obj.processed_plain_text = db_record.processed_plain_text
obj.display_message = db_record.display_message
obj.raw_message = MessageUtils.from_db_record_msg_to_MaiSeq(db_record.raw_content)
return obj
@@ -113,7 +111,6 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
is_notify=self.is_notify,
raw_content=MessageUtils.from_MaiSeq_to_db_record_msg(self.raw_message),
processed_plain_text=self.processed_plain_text,
display_message=self.display_message,
additional_config=additional_config,
)

View File

@@ -51,7 +51,6 @@ class Messages(SQLModel, table=True):
# 消息内容
raw_content: bytes = Field(sa_column=Column(LargeBinary)) # msgpack后的原始消息内容
processed_plain_text: Optional[str] = Field(default=None) # 平面化处理后的纯文本消息
display_message: Optional[str] = Field(default=None) # 显示的消息内容被放入Prompt
# 其他配置
additional_config: Optional[str] = Field(default=None) # 额外配置JSON格式存储

View File

@@ -8,12 +8,14 @@ from .registry import MigrationRegistry
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
from .schema import SQLiteSchemaInspector
from .v2_to_v3 import migrate_v2_to_v3
from .v3_to_v4 import migrate_v3_to_v4
from .version_store import SQLiteUserVersionStore
EMPTY_SCHEMA_VERSION = 0
LEGACY_V1_SCHEMA_VERSION = 1
V2_SCHEMA_VERSION = 2
LATEST_SCHEMA_VERSION = 3
V3_SCHEMA_VERSION = 3
LATEST_SCHEMA_VERSION = 4
_LEGACY_V1_EXCLUSIVE_TABLES = (
"chat_streams",
@@ -78,9 +80,46 @@ class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
return None
if not snapshot.has_column("person_info", "user_nickname"):
return None
if snapshot.has_column("mai_messages", "display_message"):
return None
return LATEST_SCHEMA_VERSION
class V3SchemaVersionDetector(BaseSchemaVersionDetector):
"""v3 schema 结构探测器。"""
@property
def name(self) -> str:
return "v3_schema_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""检测数据库是否为 v3 结构。"""
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
return None
if not all(snapshot.has_table(table_name) for table_name in _COMMON_MARKER_TABLES):
return None
if snapshot.has_table("action_records"):
return None
if snapshot.has_table("thinking_questions"):
return None
if snapshot.has_column("images", "emotion"):
return None
if not snapshot.has_column("images", "image_hash"):
return None
if not snapshot.has_column("images", "full_path"):
return None
if not snapshot.has_column("images", "image_type"):
return None
if not snapshot.has_column("chat_history", "session_id"):
return None
if not snapshot.has_column("person_info", "user_nickname"):
return None
if not snapshot.has_column("mai_messages", "display_message"):
return None
return V3_SCHEMA_VERSION
class V2SchemaVersionDetector(BaseSchemaVersionDetector):
"""v2 schema 结构探测器。"""
@@ -174,6 +213,7 @@ def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
return [
LatestSchemaVersionDetector(),
V3SchemaVersionDetector(),
V2SchemaVersionDetector(),
LegacyV1SchemaDetector(),
]
@@ -211,10 +251,17 @@ def build_default_migration_registry() -> MigrationRegistry:
),
MigrationStep(
version_from=V2_SCHEMA_VERSION,
version_to=LATEST_SCHEMA_VERSION,
version_to=V3_SCHEMA_VERSION,
name="v2_to_v3",
description="移除废弃表,并将 emoji 标签统一收敛到 description 字段。",
handler=migrate_v2_to_v3,
),
MigrationStep(
version_from=V3_SCHEMA_VERSION,
version_to=LATEST_SCHEMA_VERSION,
name="v3_to_v4",
description="移除 mai_messages.display_message 弃用列。",
handler=migrate_v3_to_v4,
),
]
)

View File

@@ -489,9 +489,6 @@ def _build_legacy_message_additional_config(row: Mapping[str, Any]) -> Optional[
legacy_fields = {
"intercept_message_level": row.get("intercept_message_level"),
"interest_value": row.get("interest_value"),
"key_words": row.get("key_words"),
"key_words_lite": row.get("key_words_lite"),
"priority_info": row.get("priority_info"),
"priority_mode": row.get("priority_mode"),
"selected_expressions": row.get("selected_expressions"),

View File

@@ -0,0 +1,155 @@
"""v3 schema 升级到 v4 的迁移逻辑。"""
from sqlalchemy import text
from sqlalchemy.engine import Connection
from src.common.logger import get_logger
from .exceptions import DatabaseMigrationExecutionError
from .models import MigrationExecutionContext
from .schema import SQLiteSchemaInspector
logger = get_logger("database_migration")
_V3_MESSAGES_BACKUP_TABLE = "__v3_mai_messages_backup"
_V4_MESSAGES_CREATE_SQL = """
CREATE TABLE mai_messages (
id INTEGER NOT NULL,
message_id VARCHAR(255) NOT NULL,
timestamp DATETIME,
platform VARCHAR(100) NOT NULL,
user_id VARCHAR(255) NOT NULL,
user_nickname VARCHAR(255) NOT NULL,
user_cardname VARCHAR(255),
group_id VARCHAR(255),
group_name VARCHAR(255),
is_mentioned BOOLEAN NOT NULL,
is_at BOOLEAN NOT NULL,
session_id VARCHAR(255) NOT NULL,
reply_to VARCHAR(255),
is_emoji BOOLEAN NOT NULL,
is_picture BOOLEAN NOT NULL,
is_command BOOLEAN NOT NULL,
is_notify BOOLEAN NOT NULL,
raw_content BLOB,
processed_plain_text VARCHAR,
additional_config VARCHAR,
PRIMARY KEY (id)
)
"""
_V4_MESSAGES_INDEX_STATEMENTS = (
"CREATE INDEX ix_mai_messages_group_id ON mai_messages (group_id)",
"CREATE INDEX ix_mai_messages_message_id ON mai_messages (message_id)",
"CREATE INDEX ix_mai_messages_platform ON mai_messages (platform)",
"CREATE INDEX ix_mai_messages_session_id ON mai_messages (session_id)",
"CREATE INDEX ix_mai_messages_user_id ON mai_messages (user_id)",
"CREATE INDEX ix_mai_messages_user_nickname ON mai_messages (user_nickname)",
)
def migrate_v3_to_v4(context: MigrationExecutionContext) -> None:
"""执行 v3 到 v4 的 schema 迁移。"""
connection = context.connection
total_records = _count_table_rows(connection, "mai_messages")
context.start_progress(
total_tables=1,
total_records=total_records,
description="v3 -> v4 迁移进度",
table_unit_name="",
record_unit_name="记录",
)
migrated_message_rows = _migrate_messages_table_to_v4(connection)
context.advance_progress(
records=migrated_message_rows,
completed_tables=1,
item_name="mai_messages",
)
logger.info(f"v3 -> v4 数据库迁移完成: mai_messages重建={migrated_message_rows}")
def _count_table_rows(connection: Connection, table_name: str) -> int:
"""统计表记录数,不存在时返回 0。"""
schema_inspector = SQLiteSchemaInspector()
if not schema_inspector.table_exists(connection, table_name):
return 0
row = connection.execute(text(f'SELECT COUNT(*) FROM "{table_name}"')).first()
return int(row[0]) if row else 0
def _migrate_messages_table_to_v4(connection: Connection) -> int:
"""重建 ``mai_messages`` 表并移除弃用的 ``display_message`` 列。"""
schema_inspector = SQLiteSchemaInspector()
if not schema_inspector.table_exists(connection, "mai_messages"):
return 0
if not schema_inspector.get_table_schema(connection, "mai_messages").has_column("display_message"):
return _count_table_rows(connection, "mai_messages")
if schema_inspector.table_exists(connection, _V3_MESSAGES_BACKUP_TABLE):
raise DatabaseMigrationExecutionError(
f"检测到残留备份表 {_V3_MESSAGES_BACKUP_TABLE},无法安全执行 v3 -> v4 mai_messages 迁移。"
)
connection.exec_driver_sql(f'ALTER TABLE "mai_messages" RENAME TO "{_V3_MESSAGES_BACKUP_TABLE}"')
connection.exec_driver_sql(_V4_MESSAGES_CREATE_SQL)
connection.execute(
text(
f"""
INSERT INTO mai_messages (
id,
message_id,
timestamp,
platform,
user_id,
user_nickname,
user_cardname,
group_id,
group_name,
is_mentioned,
is_at,
session_id,
reply_to,
is_emoji,
is_picture,
is_command,
is_notify,
raw_content,
processed_plain_text,
additional_config
)
SELECT
id,
message_id,
timestamp,
platform,
user_id,
user_nickname,
user_cardname,
group_id,
group_name,
is_mentioned,
is_at,
session_id,
reply_to,
is_emoji,
is_picture,
is_command,
is_notify,
raw_content,
COALESCE(NULLIF(processed_plain_text, ''), display_message),
additional_config
FROM "{_V3_MESSAGES_BACKUP_TABLE}"
ORDER BY id
"""
)
)
migrated_rows = _count_table_rows(connection, "mai_messages")
connection.exec_driver_sql(f'DROP TABLE "{_V3_MESSAGES_BACKUP_TABLE}"')
for statement in _V4_MESSAGES_INDEX_STATEMENTS:
connection.exec_driver_sql(statement)
return migrated_rows

View File

@@ -829,20 +829,19 @@ def initialize_logging(verbose: bool = True):
reconfigure_existing_loggers()
# 启动日志清理任务
start_log_cleanup_task(verbose=verbose)
start_log_cleanup_task()
# 只在 verbose=True 时输出详细的初始化信息
if verbose:
logger = get_logger("logger")
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
logger.info("日志系统已初始化:")
logger.info(f" - 控制台级别: {console_level}")
logger.info(f" - 文件级别: {file_level}")
max_log_files = max(1, int(LOG_CONFIG.get("max_log_files", 30) or 30))
log_cleanup_days = max(1, int(LOG_CONFIG.get("log_cleanup_days", 30) or 30))
logger.info(f" - 轮转份数: {max_log_files}个文件|自动清理: {log_cleanup_days}天前的日志")
logger.info(
f"日志系统已初始化:控制台={console_level},文件={file_level}"
f"轮转={max_log_files}个文件,清理={log_cleanup_days}天前"
)
def cleanup_old_logs():
@@ -875,12 +874,8 @@ def cleanup_old_logs():
logger.error(f"清理旧日志文件时出错: {e}")
def start_log_cleanup_task(verbose: bool = True):
"""启动日志清理任务
Args:
verbose: 是否输出启动信息。默认为 True。
"""
def start_log_cleanup_task():
"""启动日志清理任务"""
global _cleanup_task_started
# 防止重复启动清理任务
@@ -897,12 +892,6 @@ def start_log_cleanup_task(verbose: bool = True):
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
cleanup_thread.start()
if verbose:
logger = get_logger("logger")
max_log_files = max(1, int(LOG_CONFIG.get("max_log_files", 30) or 30))
log_cleanup_days = max(1, int(LOG_CONFIG.get("log_cleanup_days", 30) or 30))
logger.info(f"已启动日志清理任务,将自动清理{log_cleanup_days}天前的日志文件(轮转份数限制: {max_log_files}个文件)")
def shutdown_logging():
"""优雅关闭日志系统,释放所有文件句柄"""

View File

@@ -1,8 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
from tomlkit import parse as parse_toml
import json
import logging
import os
import re
@@ -22,6 +26,19 @@ STRICT_ENV_VALUES = {"1", "true", "yes", "on"}
extract_prompt_placeholders = extract_placeholders
@dataclass(frozen=True)
class PromptMetadata:
display_name: str = ""
advanced: bool = False
description: str = ""
@dataclass(frozen=True)
class PromptTemplateInfo:
path: Path
metadata: PromptMetadata
def get_prompts_root(prompts_root: Path | None = None) -> Path:
return (prompts_root or PROMPTS_ROOT).resolve()
@@ -80,24 +97,86 @@ def iter_prompt_files(directory: Path, recursive: bool = True) -> list[Path]:
def _raise_duplicate_prompt_name(name: str, first_path: Path, second_path: Path, prompts_root: Path) -> None:
path_a = first_path.relative_to(prompts_root).as_posix()
path_b = second_path.relative_to(prompts_root).as_posix()
raise ValueError(
t(
"prompt.duplicate_template_name",
name=name,
path_a=first_path.relative_to(prompts_root),
path_b=second_path.relative_to(prompts_root),
path_a=path_a,
path_b=path_b,
)
)
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, Path]:
prompt_paths: dict[str, Path] = {}
def _coerce_metadata(raw_metadata: Any) -> PromptMetadata:
if not isinstance(raw_metadata, dict):
return PromptMetadata()
display_name = raw_metadata.get("display_name", "")
advanced = raw_metadata.get("advanced", False)
description = raw_metadata.get("description", "")
return PromptMetadata(
display_name=display_name if isinstance(display_name, str) else "",
advanced=advanced if isinstance(advanced, bool) else False,
description=description if isinstance(description, str) else "",
)
def _read_metadata_file(metadata_path: Path) -> dict[str, Any]:
if not metadata_path.is_file():
return {}
try:
if metadata_path.suffix == ".json":
metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
else:
metadata = parse_toml(metadata_path.read_text(encoding="utf-8"))
except Exception as exc:
logger.warning("读取 Prompt 元信息文件 %s 失败:%s", metadata_path, exc)
return {}
return dict(metadata) if isinstance(metadata, dict) else {}
def _extract_template_metadata(metadata: dict[str, Any], prompt_name: str) -> dict[str, Any]:
templates = metadata.get("templates")
if isinstance(templates, dict) and isinstance(templates.get(prompt_name), dict):
return dict(templates[prompt_name])
prompt_metadata = metadata.get(prompt_name)
if isinstance(prompt_metadata, dict):
return dict(prompt_metadata)
return metadata if any(key in metadata for key in ("display_name", "advanced", "description")) else {}
def _load_prompt_metadata(prompt_path: Path) -> PromptMetadata:
prompt_name = prompt_path.stem
metadata_sources = (
prompt_path.with_name(f"{prompt_name}.meta.toml"),
prompt_path.with_name(f"{prompt_name}.meta.json"),
prompt_path.parent / ".meta.toml",
prompt_path.parent / ".meta.json",
)
merged_metadata: dict[str, Any] = {}
for metadata_path in reversed(metadata_sources):
raw_metadata = _read_metadata_file(metadata_path)
merged_metadata.update(_extract_template_metadata(raw_metadata, prompt_name))
return _coerce_metadata(merged_metadata)
def _scan_prompt_directory(directory: Path, prompts_root: Path) -> dict[str, PromptTemplateInfo]:
prompt_paths: dict[str, PromptTemplateInfo] = {}
for prompt_path in iter_prompt_files(directory):
prompt_name = prompt_path.stem
existing_path = prompt_paths.get(prompt_name)
if existing_path is not None:
_raise_duplicate_prompt_name(prompt_name, existing_path, prompt_path, prompts_root)
prompt_paths[prompt_name] = prompt_path
existing_info = prompt_paths.get(prompt_name)
if existing_info is not None:
_raise_duplicate_prompt_name(prompt_name, existing_info.path, prompt_path, prompts_root)
prompt_paths[prompt_name] = PromptTemplateInfo(path=prompt_path, metadata=_load_prompt_metadata(prompt_path))
return prompt_paths
@@ -115,11 +194,11 @@ def _iter_locale_candidates(requested_locale: str) -> list[str]:
return locale_candidates
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, Path]:
def list_prompt_templates(locale: str | None = None, prompts_root: Path | None = None) -> dict[str, PromptTemplateInfo]:
resolved_prompts_root = get_prompts_root(prompts_root)
requested_locale = normalize_locale(locale or get_locale())
prompt_paths: dict[str, Path] = {}
prompt_paths: dict[str, PromptTemplateInfo] = {}
for directory in _iter_prompt_template_layers(resolved_prompts_root, requested_locale):
prompt_paths.update(_scan_prompt_directory(directory, resolved_prompts_root))
@@ -149,7 +228,7 @@ def resolve_prompt_path(
else:
prompt_paths = list_prompt_templates(locale=requested_locale, prompts_root=resolved_prompts_root)
if normalized_name in prompt_paths:
return prompt_paths[normalized_name]
return prompt_paths[normalized_name].path
raise FileNotFoundError(t("prompt.template_not_found", locale=requested_locale, name=normalized_name))

View File

@@ -56,9 +56,9 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0-pre.10"
CONFIG_VERSION: str = "8.10.6"
MODEL_CONFIG_VERSION: str = "1.14.8"
MMC_VERSION: str = "1.0.0-pre.11"
CONFIG_VERSION: str = "8.10.7"
MODEL_CONFIG_VERSION: str = "1.15.3"
logger = get_logger("config")

View File

@@ -135,7 +135,6 @@ class ConfigBase(BaseModel, AttrDocBase):
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
__ui_icon__: ClassVar[str] = "" # Tab 图标名称Lucide 图标名)
__ui_merge_children__: ClassVar[List[str]] = [] # 在 WebUI 中并入当前配置卡片展示的子配置字段名
@classmethod
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):

View File

@@ -343,10 +343,10 @@ def try_migrate_legacy_bot_config_dict(data: dict[str, Any]) -> MigrationResult:
reasons: list[str] = []
bot = _as_dict(data.get("bot"))
if bot is not None and isinstance(bot.get("qq_account"), str) and not bot["qq_account"].strip():
bot["qq_account"] = 0
if bot is not None and isinstance(bot.get("qq_account"), int):
bot["qq_account"] = str(bot["qq_account"]) if bot["qq_account"] > 0 else ""
migrated_any = True
reasons.append("bot.qq_account_empty")
reasons.append("bot.qq_account_int_to_string")
chat = _as_dict(data.get("chat"))
if chat is not None and _migrate_chat_talk_value_rules(chat):

View File

@@ -351,6 +351,7 @@ class ModelInfo(ConfigBase):
Gemini 客户端会按自身支持的字段筛选并映射到 GenerateContentConfig、EmbedContentConfig 或音频请求配置中。"""
def model_post_init(self, context: Any = None):
self.model_identifier = self.model_identifier.strip()
if not self.model_identifier:
raise ValueError(t("config.model_identifier_empty_generic"))
if not self.name:
@@ -402,6 +403,7 @@ class TaskConfig(ConfigBase):
"x-widget": "input",
"x-icon": "alert-circle",
"step": 0.1,
"advanced": True,
},
)
"""慢请求阈值(秒),超过此值会输出警告日志"""
@@ -420,15 +422,6 @@ class TaskConfig(ConfigBase):
class ModelTaskConfig(ConfigBase):
"""模型配置类"""
utils: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "wrench",
},
)
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
replyer: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
@@ -436,17 +429,7 @@ class ModelTaskConfig(ConfigBase):
"x-icon": "message-square",
},
)
"""首要回复模型配置"""
learner: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "graduation-cap",
"advanced": True,
},
)
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
"""回复模型配置"""
planner: TaskConfig = Field(
default_factory=TaskConfig,
@@ -457,6 +440,25 @@ class ModelTaskConfig(ConfigBase):
)
"""规划模型配置"""
utils: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "wrench",
},
)
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
learner: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "graduation-cap",
"advanced": True,
},
)
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
vlm: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
@@ -471,6 +473,7 @@ class ModelTaskConfig(ConfigBase):
json_schema_extra={
"x-widget": "custom",
"x-icon": "volume-2",
"advanced": True,
},
)
"""语音识别模型配置"""

View File

@@ -4,6 +4,17 @@ import re
from .config_base import ConfigBase, Field
RULE_TYPE_OPTION_DESCRIPTIONS = {
"group": "群聊聊天流item_id 填群号或群聊 ID",
"private": "私聊聊天流item_id 填用户 ID",
}
VISUAL_MODE_OPTION_DESCRIPTIONS = {
"auto": "根据模型信息自动选择文本或多模态模式",
"text": "纯文本模式,不向模型发送视觉输入",
"multimodal": "多模态模式,会向模型发送视觉输入",
}
"""
须知:
1. 本文件中记录了所有的配置项
@@ -19,7 +30,7 @@ class ExampleConfig(ConfigBase):
class BotConfig(ConfigBase):
"""机器人配置类"""
__ui_label__ = "本信息"
__ui_label__ = ""
__ui_icon__ = "bot"
platform: str = Field(
@@ -29,17 +40,19 @@ class BotConfig(ConfigBase):
"x-icon": "wifi",
"x-layout": "inline-right",
"x-input-width": "12rem",
"x-row": "bot-platform-account",
},
)
"""平台"""
qq_account: int = Field(
default=0,
qq_account: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "user",
"x-layout": "inline-right",
"x-input-width": "12rem",
"x-row": "bot-platform-account",
},
)
"""QQ账号"""
@@ -76,6 +89,7 @@ class BotConfig(ConfigBase):
class PersonalityConfig(ConfigBase):
"""人格配置类"""
__ui_parent__ = "bot"
__ui_label__ = "人格"
__ui_icon__ = "user-circle"
@@ -84,6 +98,8 @@ class PersonalityConfig(ConfigBase):
json_schema_extra={
"x-widget": "textarea",
"x-icon": "user-circle",
"x-textarea-min-height": 40,
"x-textarea-rows": 1,
},
)
"""人格建议200字以内描述人格特质和身份特征可以写完整设定。要求第二人称"""
@@ -93,6 +109,8 @@ class PersonalityConfig(ConfigBase):
json_schema_extra={
"x-widget": "textarea",
"x-icon": "message-square",
"x-textarea-min-height": 40,
"x-textarea-rows": 1,
},
)
"""默认表达风格描述麦麦说话的表达风格表达习惯如要修改可以酌情新增内容建议1-2行"""
@@ -137,6 +155,8 @@ class VisualConfig(ConfigBase):
json_schema_extra={
"x-widget": "select",
"x-icon": "git-branch",
"x-option-descriptions": VISUAL_MODE_OPTION_DESCRIPTIONS,
"x-row": "visual-modes",
},
)
"""规划器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式"""
@@ -146,6 +166,8 @@ class VisualConfig(ConfigBase):
json_schema_extra={
"x-widget": "select",
"x-icon": "git-branch",
"x-option-descriptions": VISUAL_MODE_OPTION_DESCRIPTIONS,
"x-row": "visual-modes",
},
)
"""回复器模式auto根据模型信息自动选择text为纯文本模式multimodal为多模态模式"""
@@ -158,7 +180,13 @@ class TalkRulesItem(ConfigBase):
item_id: str = ""
"""用户ID与平台一起留空表示全局"""
rule_type: Literal["group", "private"] = "group"
rule_type: Literal["group", "private"] = Field(
default="group",
json_schema_extra={
"x-widget": "select",
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
},
)
"""聊天流类型group群聊或private私聊"""
time: str = ""
@@ -181,6 +209,7 @@ class ChatConfig(ConfigBase):
json_schema_extra={
"x-widget": "slider",
"x-icon": "message-circle",
"x-row": "talk-values",
"step": 0.1,
},
)
@@ -193,6 +222,7 @@ class ChatConfig(ConfigBase):
json_schema_extra={
"x-widget": "slider",
"x-icon": "message-circle",
"x-row": "talk-values",
"step": 0.1,
},
)
@@ -203,11 +233,19 @@ class ChatConfig(ConfigBase):
json_schema_extra={
"x-widget": "switch",
"x-icon": "at-sign",
"x-row": "reply-switches",
},
)
"""是否启用提及必回复"""
inevitable_at_reply: bool = Field(default=True)
inevitable_at_reply: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "at-sign",
"x-row": "reply-switches",
},
)
"""是否启用at必回复"""
enable_at: bool = Field(
@@ -235,6 +273,9 @@ class ChatConfig(ConfigBase):
json_schema_extra={
"x-widget": "input",
"x-icon": "layers",
"x-layout": "inline-right",
"x-input-width": "12rem",
"x-row": "context-sizes",
},
)
"""上下文长度"""
@@ -244,6 +285,9 @@ class ChatConfig(ConfigBase):
json_schema_extra={
"x-widget": "input",
"x-icon": "layers",
"x-layout": "inline-right",
"x-input-width": "12rem",
"x-row": "context-sizes",
},
)
"""私聊上下文长度"""
@@ -321,7 +365,8 @@ class ChatConfig(ConfigBase):
class MessageReceiveConfig(ConfigBase):
"""消息接收配置类"""
__ui_parent__ = "response_post_process"
__ui_label__ = "消息接收"
__ui_icon__ = "message-square-text"
image_parse_threshold: int = Field(
default=5,
@@ -386,6 +431,7 @@ class TargetItem(ConfigBase):
json_schema_extra={
"x-widget": "select",
"x-icon": "users",
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
},
)
"""聊天流类型group群聊或private私聊"""
@@ -394,7 +440,7 @@ class TargetItem(ConfigBase):
class MemoryConfig(ConfigBase):
"""记忆配置类"""
__ui_parent__ = "emoji"
__ui_parent__ = "a_memorix"
global_memory: bool = Field(
@@ -1040,6 +1086,7 @@ class LearningItem(ConfigBase):
json_schema_extra={
"x-widget": "select",
"x-icon": "users",
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
},
)
"""聊天流类型group群聊或private私聊"""
@@ -1051,7 +1098,7 @@ class LearningItem(ConfigBase):
"x-icon": "message-square",
},
)
"""是否用表达学习"""
"""是否使用表达"""
enable_learning: bool = Field(
default=True,
@@ -1060,7 +1107,7 @@ class LearningItem(ConfigBase):
"x-icon": "graduation-cap",
},
)
"""是否启用表达优化学习"""
"""是否学习表达"""
enable_jargon_learning: bool = Field(
default=False,
@@ -1069,7 +1116,7 @@ class LearningItem(ConfigBase):
"x-icon": "book",
},
)
"""是否启用jargon学习"""
"""是否学习黑话"""
class ExpressionGroup(ConfigBase):
"""表达互通组配置类,若列表为空代表全局共享"""
@@ -1177,7 +1224,8 @@ class ExpressionConfig(ConfigBase):
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
__ui_parent__ = "emoji"
__ui_label__ = "语音"
__ui_icon__ = "mic"
enable_asr: bool = Field(
default=False,
@@ -1192,8 +1240,8 @@ class VoiceConfig(ConfigBase):
class EmojiConfig(ConfigBase):
"""表情包配置类"""
__ui_label__ = "功能"
__ui_icon__ = "puzzle"
__ui_label__ = "表情包"
__ui_icon__ = "smile"
emoji_send_num: int = Field(
default=25,
@@ -1254,16 +1302,6 @@ class EmojiConfig(ConfigBase):
)
"""是否启用表情包过滤,只有符合该要求的表情包才会被保存"""
filtration_prompt: str = Field(
default="符合公序良俗",
json_schema_extra={
"advanced": True,
"x-widget": "input",
"x-icon": "shield",
},
)
"""表情包过滤要求,只有符合该要求的表情包才会被保存"""
class KeywordRuleConfig(ConfigBase):
"""关键词规则配置类"""
@@ -1314,7 +1352,7 @@ class KeywordRuleConfig(ConfigBase):
class KeywordReactionConfig(ConfigBase):
"""关键词配置类"""
__ui_parent__ = "response_post_process"
__ui_parent__ = "message_receive"
keyword_rules: list[KeywordRuleConfig] = Field(
default_factory=lambda: [],
@@ -1345,9 +1383,8 @@ class KeywordReactionConfig(ConfigBase):
class ResponsePostProcessConfig(ConfigBase):
"""回复后处理配置类"""
__ui_label__ = "处理"
__ui_label__ = "处理"
__ui_icon__ = "settings"
__ui_merge_children__ = ["chinese_typo", "response_splitter"]
enable_response_post_process: bool = Field(
default=True,
@@ -1742,6 +1779,7 @@ class ExtraPromptItem(ConfigBase):
json_schema_extra={
"x-widget": "select",
"x-icon": "users",
"x-option-descriptions": RULE_TYPE_OPTION_DESCRIPTIONS,
},
)
"""聊天流类型group群聊或private私聊"""
@@ -2387,14 +2425,14 @@ class MCPServerItemConfig(ConfigBase):
)
"""是否启用当前 MCP 服务器"""
transport: Literal["stdio", "streamable_http"] = Field(
transport: Literal["stdio", "streamable_http", "sse"] = Field(
default="stdio",
json_schema_extra={
"x-widget": "select",
"x-icon": "shuffle",
},
)
"""传输方式,可选 `stdio``streamable_http`"""
"""传输方式,可选 `stdio``streamable_http` 或 `sse`"""
command: str = Field(
default="",
@@ -2483,6 +2521,9 @@ class MCPServerItemConfig(ConfigBase):
if self.transport == "streamable_http" and not self.url.strip():
raise ValueError(f"MCP 服务器 {self.name} 使用 streamable_http 时必须填写 url")
if self.transport == "sse" and not self.url.strip():
raise ValueError(f"MCP 服务器 {self.name} 使用 sse 时必须填写 url")
return super().model_post_init(context)

View File

@@ -29,7 +29,7 @@ logger = get_logger("emoji")
install(extra_lines=3)
PROJECT_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
PROJECT_ROOT = Path(__file__).resolve().parents[2]
DATA_DIR = PROJECT_ROOT / "data"
EmojiRegisterStatus = Literal["registered", "skipped", "failed"]
EMOJI_DIR = DATA_DIR / "emoji" # 表情包存储目录
@@ -215,6 +215,17 @@ def _is_available_emoji_record(record: Images) -> bool:
return record_path.exists() and record_path.is_file()
def _resolve_existing_emoji_path(raw_path: str | Path | None) -> Optional[Path]:
"""将表情包路径归一化;路径为空或不存在时返回 ``None``。"""
if not raw_path:
return None
record_path = Path(raw_path).absolute().resolve()
if not record_path.exists() or not record_path.is_file():
return None
return record_path
def _is_vlm_task_configured() -> bool:
"""判断是否配置了可用于表情包识别和审核的视觉模型任务。"""
@@ -244,6 +255,7 @@ class EmojiManager:
self._emoji_num: int = 0
self.emojis: list[MaiEmoji] = []
self._known_emoji_file_paths: set[Path] = set()
self._maintenance_wakeup_event: asyncio.Event = asyncio.Event()
self._pending_description_tasks: dict[str, asyncio.Task[None]] = {}
self._reload_callback_registered: bool = False
@@ -254,9 +266,9 @@ class EmojiManager:
logger.info("启动表情包管理器")
def reload_runtime_config(self) -> None:
"""响应配置热重载,唤醒维护循环以尽快应用最新配置。"""
"""响应配置热重载,重置维护循环等待时间以应用最新配置。"""
self._maintenance_wakeup_event.set()
logger.info("[配置热重载] Emoji 模块配置已更新,将立即应用到维护循环")
logger.info("[配置热重载] Emoji 模块配置已更新,将按新的检查间隔等待后执行维护")
def shutdown(self) -> None:
"""清理 EmojiManager 生命周期资源。"""
@@ -793,6 +805,7 @@ class EmojiManager:
emoji_replace_prompt_template.add_context("emoji_num", str(self._emoji_num))
emoji_replace_prompt_template.add_context("emoji_num_max", str(global_config.emoji.max_reg_num))
emoji_replace_prompt_template.add_context("emoji_list", "\n".join(emoji_info_list))
emoji_replace_prompt_template.add_context("description", new_emoji.description or "无描述")
emoji_replace_prompt = await prompt_manager.render_prompt(emoji_replace_prompt_template)
decision_result = await emoji_manager_emotion_judge_llm.generate_response(
@@ -902,11 +915,10 @@ class EmojiManager:
# 表情包审查
if global_config.emoji.content_filtration:
try:
filtration_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
filtration_prompt_template.add_context("demand", global_config.emoji.filtration_prompt)
filtration_prompt = await prompt_manager.render_prompt(filtration_prompt_template)
review_prompt_template = prompt_manager.get_prompt("emoji_content_filtration")
review_prompt = await prompt_manager.render_prompt(review_prompt_template)
filtration_result = await emoji_manager_vlm.generate_response_for_image(
filtration_prompt,
review_prompt,
image_base64,
image_format,
)
@@ -948,33 +960,73 @@ class EmojiManager:
def check_emoji_file_integrity(self) -> None:
"""
检查表情包完整性,删除文件缺失的表情包记录
检查表情包文件和数据库注册记录的一致性。
数据库记录存在但文件缺失时删除数据库记录;文件存在但没有数据库记录时删除文件。
"""
logger.info("[完整性检查] 开始检查表情包文件完整性...")
to_delete_emojis: list[tuple[MaiEmoji, bool]] = []
removal_count = 0
for emoji in self.emojis:
if not emoji.full_path.exists():
logger.warning(f"[完整性检查] 表情包文件缺失,准备修改记录: {emoji.file_name}")
to_delete_emojis.append((emoji, False))
if not emoji.description:
logger.warning(f"[完整性检查] 表情包记录缺失描述,准备删除记录: {emoji.file_name}")
to_delete_emojis.append((emoji, True))
_ensure_directories()
logger.info("[完整性检查] 开始检查表情包文件和注册记录一致性...")
tracked_paths: set[Path] = set()
record_removal_count = 0
file_removal_count = 0
available_emojis: list[MaiEmoji] = []
for emoji, is_description_empty in to_delete_emojis:
if self.delete_emoji(emoji, is_description_empty):
self.emojis.remove(emoji)
self._emoji_num -= 1
removal_count += 1
logger.info(f"[完整性检查] 成功删除缺失文件的表情包记录: {emoji.file_name}")
else:
logger.error(f"[完整性检查] 删除缺失文件的表情包记录失败: {emoji.file_name}")
with get_db_session() as session:
statement = select(Images).filter_by(image_type=ImageType.EMOJI)
records = session.exec(statement).all()
for record in records:
record_path = _resolve_existing_emoji_path(record.full_path)
if record.no_file_flag or record_path is None:
logger.warning(
f"[完整性检查] 表情包数据库记录缺少实际文件,删除数据库记录: id={record.id}, path={record.full_path}"
)
session.delete(record)
record_removal_count += 1
continue
if not record.is_registered or record.is_banned:
tracked_paths.add(record_path)
continue
try:
available_emojis.append(MaiEmoji.from_db_instance(record))
tracked_paths.add(record_path)
except Exception as exc:
logger.error(
f"[完整性检查] 加载表情包记录时出错,将删除异常记录: {exc}\n记录ID: {record.id}, 路径: {record.full_path}"
)
session.delete(record)
record_removal_count += 1
logger.info(f"[完整性检查] 表情包文件完整性检查完成,删除了 {removal_count} 条记录")
for emoji_file in EMOJI_DIR.iterdir():
if not emoji_file.is_file():
continue
resolved_file = emoji_file.absolute().resolve()
if resolved_file in tracked_paths:
continue
try:
emoji_file.unlink()
file_removal_count += 1
logger.warning(f"[完整性检查] 表情包文件缺少数据库记录,删除文件: {emoji_file}")
except Exception as exc:
logger.error(f"[完整性检查] 删除无注册记录的表情包文件失败: {emoji_file}, error={exc}")
self.emojis = available_emojis
self._known_emoji_file_paths = tracked_paths
self._emoji_num = len(self.emojis)
logger.info(
f"[完整性检查] 表情包完整性检查完成,删除数据库记录 {record_removal_count} 条,删除图片文件 {file_removal_count}"
)
async def periodic_emoji_maintenance(self) -> None:
"""Run emoji maintenance tasks periodically."""
while True:
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
try:
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
self._maintenance_wakeup_event.clear()
continue
except asyncio.TimeoutError:
self._maintenance_wakeup_event.clear()
_ensure_directories()
try:
self.check_emoji_file_integrity()
@@ -989,24 +1041,21 @@ class EmojiManager:
for emoji_file in EMOJI_DIR.iterdir():
if not emoji_file.is_file():
continue
resolved_file = emoji_file.absolute().resolve()
if resolved_file in self._known_emoji_file_paths:
continue
try:
register_status = await self.register_emoji_by_filename(emoji_file)
except Exception as e:
logger.error(f"[emoji_maintenance] Failed to process {emoji_file.name}: {e}")
register_status = "failed"
if register_status == "registered":
self._known_emoji_file_paths.add(resolved_file)
break
if register_status == "skipped":
logger.debug(f"[emoji_maintenance] Emoji already registered, keep file: {emoji_file.name}")
else:
logger.debug(f"[emoji_maintenance] Emoji not registered, keep file: {emoji_file.name}")
wait_seconds = max(global_config.emoji.check_interval * 60, 0)
try:
await asyncio.wait_for(self._maintenance_wakeup_event.wait(), timeout=wait_seconds)
except asyncio.TimeoutError:
pass
finally:
self._maintenance_wakeup_event.clear()
async def register_emoji_by_filename(self, filename: Path | str) -> EmojiRegisterStatus:
"""Register an emoji file from ``data/emoji`` without moving or deleting it."""

View File

@@ -585,6 +585,9 @@ def _parse_tool_arguments(
Raises:
RespParseException: 当参数无法解析为字典时抛出。
"""
if not raw_arguments.strip():
return {}
try:
if parse_mode == ToolArgumentParseMode.STRICT:
arguments: Any = json.loads(raw_arguments)

View File

@@ -207,7 +207,7 @@ async def handle_tool(
sent_message = await send_service._send_to_target_with_message(
message_sequence=reply_sequence,
stream_id=tool_ctx.runtime.session_id,
display_message=segment,
processed_plain_text=segment,
set_reply=segment_set_quote,
reply_message=target_message if segment_set_quote else None,
selected_expressions=reply_result.selected_expression_ids or None,

View File

@@ -13,7 +13,7 @@ from PIL import Image as PILImage
from PIL import ImageDraw, ImageFont
from pydantic import BaseModel, Field as PydanticField
from src.emoji_system.emoji_manager import emoji_manager
from src.emoji_system.emoji_manager import _is_vlm_task_configured, emoji_manager
from src.emoji_system.maisaka_tool import send_emoji_for_maisaka
from src.common.data_models.image_data_model import MaiEmoji
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
@@ -38,6 +38,7 @@ _EMOJI_SUB_AGENT_MAX_TOKENS = 240
_EMOJI_MAX_CANDIDATE_COUNT = 64
_EMOJI_CANDIDATE_TILE_SIZE = 256
_EMOJI_SUCCESS_MESSAGE = "表情包发送成功"
_EMOJI_VLM_NOT_CONFIGURED_MESSAGE = "错误,没有配置视觉模型,无法使用表情包功能"
class EmojiSelectionResult(BaseModel):
@@ -298,6 +299,13 @@ def _resolve_emoji_selector_model_task_name() -> str:
return "vlm"
def _is_missing_visual_model_error(exc: Exception) -> bool:
    """Return ``True`` when the selection failure was caused by a missing
    visual-model (VLM) configuration."""
    message = str(exc)
    markers = (_EMOJI_VLM_NOT_CONFIGURED_MESSAGE, "未找到名为 '' 的模型")
    return any(marker in message for marker in markers)
async def _select_emoji_with_sub_agent(
tool_ctx: BuiltinToolRuntimeContext,
reasoning: str,
@@ -351,13 +359,17 @@ async def _select_emoji_with_sub_agent(
request_messages.append(candidate_llm_message)
serialized_request_messages = serialize_prompt_messages(request_messages)
model_task_name = _resolve_emoji_selector_model_task_name()
if model_task_name == "vlm" and not _is_vlm_task_configured():
raise RuntimeError(_EMOJI_VLM_NOT_CONFIGURED_MESSAGE)
selection_started_at = datetime.now()
response = await tool_ctx.runtime.run_sub_agent(
context_message_limit=_EMOJI_SUB_AGENT_CONTEXT_LIMIT,
system_prompt=system_prompt,
extra_messages=[prompt_message, candidate_message],
max_tokens=_EMOJI_SUB_AGENT_MAX_TOKENS,
model_task_name=_resolve_emoji_selector_model_task_name(),
model_task_name=model_task_name,
)
selection_duration_ms = round((datetime.now() - selection_started_at).total_seconds() * 1000, 2)
@@ -448,7 +460,10 @@ async def handle_tool(
)
except Exception as exc:
logger.exception(f"{tool_ctx.runtime.log_prefix} 发送表情包时发生异常: {exc}")
structured_result["message"] = f"发送表情包时发生异常:{exc}"
if _is_missing_visual_model_error(exc):
structured_result["message"] = _EMOJI_VLM_NOT_CONFIGURED_MESSAGE
else:
structured_result["message"] = f"发送表情包时发生异常:{exc}"
return tool_ctx.build_failure_result(
invocation.tool_name,
structured_result["message"],

View File

@@ -53,7 +53,6 @@ if TYPE_CHECKING:
logger = get_logger("maisaka_reasoning_engine")
TIMING_GATE_CONTEXT_DROP_HEAD_RATIO = 0.7
TIMING_GATE_MAX_TOKENS = 384
TIMING_GATE_MAX_ATTEMPTS = 3
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
HISTORY_SILENT_TOOL_NAMES = {"finish"}
@@ -140,7 +139,6 @@ class MaisakaReasoningEngine:
system_prompt=system_prompt,
request_kind="timing_gate",
interrupt_flag=None,
max_tokens=TIMING_GATE_MAX_TOKENS,
tool_definitions=tool_definitions,
)

View File

@@ -50,7 +50,7 @@ class MCPServerRuntimeConfig:
"""单个 MCP 服务器的运行时配置。"""
name: str
transport: Literal["stdio", "streamable_http"] = "stdio"
transport: Literal["stdio", "streamable_http", "sse"] = "stdio"
command: str = ""
args: list[str] = field(default_factory=list)
env: dict[str, str] = field(default_factory=dict)
@@ -65,13 +65,15 @@ class MCPServerRuntimeConfig:
"""返回当前服务器的传输类型。
Returns:
str: ``stdio``、``streamable_http`` 或 ``unknown``。
str: ``stdio``、``streamable_http``、``sse`` 或 ``unknown``。
"""
if self.transport == "stdio" and self.command:
return "stdio"
if self.transport == "streamable_http" and self.url:
return "streamable_http"
if self.transport == "sse" and self.url:
return "sse"
return "unknown"
def build_http_headers(self) -> dict[str, str]:

View File

@@ -43,16 +43,26 @@ try:
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
try:
from mcp.client.sse import sse_client
SSE_AVAILABLE = True
except ImportError:
SSE_AVAILABLE = False
sse_client = None # type: ignore[assignment]
MCP_AVAILABLE = True
STREAMABLE_HTTP_AVAILABLE = True
except ImportError:
MCP_AVAILABLE = False
STREAMABLE_HTTP_AVAILABLE = False
SSE_AVAILABLE = False
ClientSession = None # type: ignore[assignment,misc]
StdioServerParameters = None # type: ignore[assignment,misc]
mcp_types = None # type: ignore[assignment]
stdio_client = None # type: ignore[assignment]
streamable_http_client = None # type: ignore[assignment]
sse_client = None # type: ignore[assignment]
class MCPConnection:
@@ -139,6 +149,8 @@ class MCPConnection:
return await self._connect_stdio()
if self.config.transport_type == "streamable_http":
return await self._connect_streamable_http()
if self.config.transport_type == "sse":
return await self._connect_sse()
raise ValueError(f"MCP 服务器 '{self.config.name}' 使用了未知传输类型: {self.config.transport}")
@@ -184,16 +196,53 @@ class MCPConnection:
self._session_id_getter = session_id_getter
return read_stream, write_stream
def _build_http_client(self) -> httpx.AsyncClient:
"""构建 Streamable HTTP 使用的 `httpx` 客户端
async def _connect_sse(self) -> tuple[Any, Any]:
    """Establish an SSE transport connection for this MCP server.

    Returns:
        tuple[Any, Any]: The ``(read_stream, write_stream)`` pair produced
        by the SSE client context manager.

    Raises:
        ImportError: If no usable MCP SSE client is installed.
        ValueError: If the server configuration lacks a URL.
    """
    if not SSE_AVAILABLE or sse_client is None:
        raise ImportError("当前环境未安装可用的 MCP SSE 客户端")
    if not self.config.url:
        raise ValueError(f"MCP 服务器 '{self.config.name}' 缺少 SSE url 配置")
    # Enter the SSE client via the exit stack so its teardown is tied to
    # this connection's lifecycle.
    read_stream, write_stream = await self._exit_stack.enter_async_context(
        sse_client(
            url=self.config.url,
            headers=self.config.build_http_headers(),
            timeout=self.config.http_timeout_seconds,
            # NOTE(review): sse_read_timeout presumably bounds the wait
            # between SSE events; confirm read_timeout_seconds is meant
            # for that rather than for the initial request.
            sse_read_timeout=self.config.read_timeout_seconds,
            httpx_client_factory=self._build_http_client,
        )
    )
    return read_stream, write_stream
def _build_http_client(
self,
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
"""构建 httpx 客户端。
Args:
headers: 合并到配置请求头的额外请求头。
timeout: 覆盖的 httpx 超时配置。
auth: 附加认证。
Returns:
httpx.AsyncClient: 预配置的异步 HTTP 客户端。
"""
del auth
merged_headers = self.config.build_http_headers()
if headers:
merged_headers.update(headers)
return httpx.AsyncClient(
headers=self.config.build_http_headers(),
timeout=httpx.Timeout(self.config.http_timeout_seconds),
headers=merged_headers,
timeout=timeout or httpx.Timeout(self.config.http_timeout_seconds),
)
async def _create_client_session(self, read_stream: Any, write_stream: Any) -> Any:

View File

@@ -309,7 +309,7 @@ class MCPManager:
provider_type="mcp",
icons=[build_tool_icon(item) for item in getattr(tool, "icons", []) or []],
annotation=build_tool_annotation(getattr(tool, "annotations", None)),
metadata={"server_name": server_name} | getattr(tool, "meta", {}),
metadata={"server_name": server_name} | (getattr(tool, "meta", {}) or {}),
)
)
return tool_specs

View File

@@ -190,7 +190,7 @@ class RuntimeCoreCapabilityMixin:
content=command,
stream_id=stream_id,
storage_message=bool(args.get("storage_message", True)),
display_message=str(args.get("display_message", "")),
processed_plain_text=str(args.get("processed_plain_text", "")),
sync_to_maisaka_history=sync_to_maisaka_history,
maisaka_source_kind=maisaka_source_kind,
)
@@ -228,7 +228,7 @@ class RuntimeCoreCapabilityMixin:
message_type=message_type,
content=content,
stream_id=stream_id,
display_message=str(args.get("display_message", "")),
processed_plain_text=str(args.get("processed_plain_text", "")),
typing=bool(args.get("typing", False)),
storage_message=bool(args.get("storage_message", True)),
sync_to_maisaka_history=sync_to_maisaka_history,

View File

@@ -296,11 +296,13 @@ class RuntimeDataCapabilityMixin:
return {"success": False, "error": str(e)}
@staticmethod
def _serialize_messages(messages: list) -> List[Any]:
def _serialize_messages(messages: list, include_binary_data: bool = True) -> List[Any]:
result: List[Any] = []
for msg in messages:
if all(hasattr(msg, attr) for attr in ("message_id", "timestamp", "platform", "message_info", "raw_message")):
result.append(dict(PluginMessageUtils._session_message_to_dict(msg)))
result.append(
dict(PluginMessageUtils._session_message_to_dict(msg, include_binary_data=include_binary_data))
)
elif hasattr(msg, "model_dump"):
result.append(msg.model_dump())
elif hasattr(msg, "__dict__"):
@@ -321,7 +323,12 @@ class RuntimeDataCapabilityMixin:
message_id=message_id,
chat_id=str(args.get("chat_id") or args.get("stream_id") or "").strip() or None,
)
serialized_message = self._serialize_messages([message])[0] if message is not None else None
include_binary_data = bool(args.get("include_binary_data", False))
serialized_message = (
self._serialize_messages([message], include_binary_data=include_binary_data)[0]
if message is not None
else None
)
return {"success": True, "message": serialized_message}
except Exception as e:
logger.error(f"[cap.message.get_by_id] 执行失败: {e}", exc_info=True)
@@ -338,7 +345,13 @@ class RuntimeDataCapabilityMixin:
limit_mode=args.get("limit_mode", "latest"),
filter_mai=args.get("filter_mai", False),
)
return {"success": True, "messages": self._serialize_messages(messages)}
return {
"success": True,
"messages": self._serialize_messages(
messages,
include_binary_data=bool(args.get("include_binary_data", False)),
),
}
except Exception as e:
logger.error(f"[cap.message.get_by_time] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
@@ -360,7 +373,13 @@ class RuntimeDataCapabilityMixin:
filter_mai=args.get("filter_mai", False),
filter_command=args.get("filter_command", False),
)
return {"success": True, "messages": self._serialize_messages(messages)}
return {
"success": True,
"messages": self._serialize_messages(
messages,
include_binary_data=bool(args.get("include_binary_data", False)),
),
}
except Exception as e:
logger.error(f"[cap.message.get_by_time_in_chat] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
@@ -385,7 +404,13 @@ class RuntimeDataCapabilityMixin:
limit_mode=args.get("limit_mode", "latest"),
filter_mai=args.get("filter_mai", False),
)
return {"success": True, "messages": self._serialize_messages(messages)}
return {
"success": True,
"messages": self._serialize_messages(
messages,
include_binary_data=bool(args.get("include_binary_data", False)),
),
}
except Exception as e:
logger.error(f"[cap.message.get_recent] 执行失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}

View File

@@ -56,12 +56,14 @@ class MessageDict(TypedDict, total=False):
session_id: str
reply_to: Optional[str]
processed_plain_text: Optional[str]
display_message: Optional[str]
class PluginMessageUtils:
@staticmethod
def _message_sequence_to_dict(message_sequence: MessageSequence) -> List[Dict[str, Any]]:
def _message_sequence_to_dict(
message_sequence: MessageSequence,
include_binary_data: bool = True,
) -> List[Dict[str, Any]]:
"""将消息组件序列转换为插件运行时使用的字典结构。
Args:
@@ -70,10 +72,16 @@ class PluginMessageUtils:
Returns:
List[Dict[str, Any]]: 供插件运行时协议使用的消息段字典列表。
"""
return [PluginMessageUtils._component_to_dict(component) for component in message_sequence.components]
return [
PluginMessageUtils._component_to_dict(component, include_binary_data=include_binary_data)
for component in message_sequence.components
]
@staticmethod
def _component_to_dict(component: StandardMessageComponents) -> Dict[str, Any]:
def _component_to_dict(
component: StandardMessageComponents,
include_binary_data: bool = True,
) -> Dict[str, Any]:
"""将单个消息组件转换为插件运行时字典结构。
Args:
@@ -91,8 +99,10 @@ class PluginMessageUtils:
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
if include_binary_data and (
binary_data_base64 := PluginMessageUtils._binary_component_to_base64(component, "image")
):
serialized["binary_data_base64"] = binary_data_base64
return serialized
if isinstance(component, EmojiComponent):
@@ -101,8 +111,10 @@ class PluginMessageUtils:
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
if include_binary_data and (
binary_data_base64 := PluginMessageUtils._binary_component_to_base64(component, "emoji")
):
serialized["binary_data_base64"] = binary_data_base64
return serialized
if isinstance(component, VoiceComponent):
@@ -111,7 +123,7 @@ class PluginMessageUtils:
"data": component.content,
"hash": component.binary_hash,
}
if component.binary_data:
if include_binary_data and component.binary_data:
serialized["binary_data_base64"] = base64.b64encode(component.binary_data).decode("utf-8")
return serialized
@@ -140,13 +152,53 @@ class PluginMessageUtils:
if isinstance(component, ForwardNodeComponent):
return {
"type": "forward",
"data": [PluginMessageUtils._forward_component_to_dict(item) for item in component.forward_components],
"data": [
PluginMessageUtils._forward_component_to_dict(item, include_binary_data=include_binary_data)
for item in component.forward_components
],
}
return {"type": "dict", "data": component.data}
@staticmethod
def _forward_component_to_dict(component: ForwardComponent) -> Dict[str, Any]:
def _binary_component_to_base64(component: Any, image_type: str) -> str:
    """Convert an image or emoji component to a Base64 string, loading the
    file from the image library (looked up by content hash) when needed.

    Args:
        component: Message component exposing ``binary_data`` and
            ``binary_hash`` attributes.
        image_type: ``"image"`` selects ``ImageType.IMAGE``; any other
            value selects ``ImageType.EMOJI``.

    Returns:
        str: Base64-encoded payload, or ``""`` when nothing could be loaded.
    """
    # Fast path: the component already carries inline bytes.
    if component.binary_data:
        return base64.b64encode(component.binary_data).decode("utf-8")
    binary_hash = str(component.binary_hash or "").strip()
    if not binary_hash:
        return ""
    try:
        # Imported locally to keep the DB dependency off the module import path.
        from pathlib import Path
        from sqlmodel import select
        from src.common.database.database import get_db_session
        from src.common.database.database_model import Images, ImageType

        target_image_type = ImageType.IMAGE if image_type == "image" else ImageType.EMOJI
        with get_db_session(auto_commit=False) as db:
            statement = select(Images).filter_by(image_hash=binary_hash, image_type=target_image_type).limit(1)
            image_record = db.exec(statement).first()
            # Records flagged as having no backing file are treated as absent.
            if image_record is None or image_record.no_file_flag:
                return ""
            image_path = Path(image_record.full_path)
            if not image_path.is_file():
                return ""
            return base64.b64encode(image_path.read_bytes()).decode("utf-8")
    except Exception as exc:
        # Best-effort lookup: any DB/filesystem failure degrades to "".
        logger.debug("通过 hash 加载历史媒体失败: type=%s hash=%s error=%s", image_type, binary_hash, exc)
        return ""
@staticmethod
def _forward_component_to_dict(
component: ForwardComponent,
include_binary_data: bool = True,
) -> Dict[str, Any]:
"""将单个转发节点组件转换为字典结构。
Args:
@@ -160,7 +212,10 @@ class PluginMessageUtils:
"user_nickname": component.user_nickname,
"user_cardname": component.user_cardname,
"message_id": component.message_id,
"content": [PluginMessageUtils._component_to_dict(item) for item in component.content],
"content": [
PluginMessageUtils._component_to_dict(item, include_binary_data=include_binary_data)
for item in component.content
],
}
@staticmethod
@@ -341,7 +396,10 @@ class PluginMessageUtils:
)
@staticmethod
def _session_message_to_dict(session_message: SessionMessage) -> MessageDict:
def _session_message_to_dict(
session_message: SessionMessage,
include_binary_data: bool = True,
) -> MessageDict:
"""
将 SessionMessage 对象转换为字典格式(复用 MessageSequence.to_dict 方法)
@@ -357,7 +415,10 @@ class PluginMessageUtils:
timestamp=str(session_message.timestamp.timestamp()), # 转换为时间戳字符串
platform=session_message.platform,
message_info=PluginMessageUtils._message_info_to_dict(session_message.message_info),
raw_message=PluginMessageUtils._message_sequence_to_dict(session_message.raw_message),
raw_message=PluginMessageUtils._message_sequence_to_dict(
session_message.raw_message,
include_binary_data=include_binary_data,
),
is_mentioned=session_message.is_mentioned,
is_at=session_message.is_at,
is_emoji=session_message.is_emoji,
@@ -372,8 +433,6 @@ class PluginMessageUtils:
message_dict["reply_to"] = session_message.reply_to
if session_message.processed_plain_text is not None:
message_dict["processed_plain_text"] = session_message.processed_plain_text
if session_message.display_message is not None:
message_dict["display_message"] = session_message.display_message
return message_dict
@@ -485,8 +544,5 @@ class PluginMessageUtils:
session_message.processed_plain_text, str
):
session_message.processed_plain_text = None
session_message.display_message = message_dict.get("display_message")
if session_message.display_message is not None and not isinstance(session_message.display_message, str):
session_message.display_message = None
return session_message

View File

@@ -9,7 +9,7 @@ import os
import sys
from src.common.logger import get_logger
from src.config.config import config_manager, global_config
from src.config.config import global_config
from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry
from src.llm_models.model_client.plugin_client import PluginLLMClient
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
@@ -18,7 +18,6 @@ from src.platform_io.route_key_factory import RouteKeyFactory
from src.plugin_runtime import (
ENV_BLOCKED_PLUGIN_REASONS,
ENV_EXTERNAL_PLUGIN_IDS,
ENV_GLOBAL_CONFIG_SNAPSHOT,
ENV_HOST_VERSION,
ENV_IPC_ADDRESS,
ENV_PLUGIN_DIRS,
@@ -1409,12 +1408,9 @@ class PluginRunnerSupervisor:
Returns:
Dict[str, str]: 传递给 Runner 进程的环境变量映射。
"""
global_config_snapshot = config_manager.get_global_config().model_dump(mode="json")
global_config_snapshot["model"] = config_manager.get_model_config().model_dump(mode="json")
return {
ENV_BLOCKED_PLUGIN_REASONS: json.dumps(self._blocked_plugin_reasons, ensure_ascii=False),
ENV_EXTERNAL_PLUGIN_IDS: json.dumps(self._external_available_plugins, ensure_ascii=False),
ENV_GLOBAL_CONFIG_SNAPSHOT: json.dumps(global_config_snapshot, ensure_ascii=False),
ENV_HOST_VERSION: PROTOCOL_VERSION,
ENV_IPC_ADDRESS: self._transport.get_address(),
ENV_PLUGIN_DIRS: os.pathsep.join(str(path) for path in self._plugin_dirs),

View File

@@ -274,12 +274,12 @@ class PromptManager:
Exception: 如果在加载过程中出现任何文件操作错误则引发该异常
"""
prompt_templates = list_prompt_templates(prompts_root=PROMPTS_DIR)
for prompt_name, prompt_file in prompt_templates.items():
for prompt_name, prompt_template in prompt_templates.items():
try:
template, need_save = self._load_prompt_template(prompt_name)
self.add_prompt(Prompt(prompt_name=prompt_name, template=template), need_save=need_save)
except Exception as exc:
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {exc}")
logger.error(f"加载 Prompt 文件 '{prompt_template.path}' 时出错,错误信息: {exc}")
raise
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
if prompt_file.stem in prompt_templates:

View File

@@ -43,8 +43,6 @@ def _build_readable_line(
def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]:
normalized: List[SessionMessage] = []
for message in messages:
if not message.processed_plain_text:
message.processed_plain_text = message.display_message or ""
normalized.append(message)
return normalized

View File

@@ -73,9 +73,9 @@ def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpe
"type": "string",
"description": "目标会话 ID。",
},
"display_message": {
"processed_plain_text": {
"type": "string",
"description": "展示层文本",
"description": "可选的预处理纯文本内容",
},
"typing": {
"type": "boolean",
@@ -97,7 +97,7 @@ def register_send_service_hook_specs(registry: HookSpecRegistry) -> List[HookSpe
required=[
"message",
"stream_id",
"display_message",
"processed_plain_text",
"typing",
"set_reply",
"storage_message",
@@ -494,7 +494,7 @@ def _build_outbound_log_preview(message: SessionMessage, max_length: int = 160)
Returns:
str: 适用于日志展示的消息摘要。
"""
preview_text = (message.processed_plain_text or message.display_message or "").strip()
preview_text = (message.processed_plain_text or "").strip()
if not preview_text:
preview_text = f"[{_describe_message_sequence(message.raw_message)}]"
@@ -507,7 +507,7 @@ def _build_outbound_log_preview(message: SessionMessage, max_length: int = 160)
def _build_outbound_session_message(
message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
processed_plain_text: str = "",
reply_message: Optional[MaiMessage] = None,
selected_expressions: Optional[List[int]] = None,
) -> Optional[SessionMessage]:
@@ -516,7 +516,7 @@ def _build_outbound_session_message(
Args:
message_sequence: 待发送的消息组件序列。
stream_id: 目标会话 ID。
display_message: 用于界面展示的文本内容。
processed_plain_text: 可选的预处理纯文本内容。
reply_message: 被回复的锚点消息。
selected_expressions: 可选的表情候选索引列表。
@@ -571,7 +571,7 @@ def _build_outbound_session_message(
)
outbound_message.raw_message = _clone_message_sequence(message_sequence)
outbound_message.session_id = target_stream.session_id
outbound_message.display_message = display_message
outbound_message.processed_plain_text = processed_plain_text.strip() or _build_processed_plain_text(outbound_message)
outbound_message.reply_to = anchor_message.message_id if anchor_message is not None else None
message_flags = _detect_outbound_message_flags(outbound_message.raw_message)
outbound_message.is_emoji = message_flags["is_emoji"]
@@ -619,7 +619,8 @@ async def _prepare_message_for_platform_io(
raise ValueError("set_reply=True 时必须提供 reply_message_id")
_ensure_reply_component(message, reply_message_id)
message.processed_plain_text = _build_processed_plain_text(message)
if set_reply or not message.processed_plain_text:
message.processed_plain_text = _build_processed_plain_text(message)
if typing:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text or "",
@@ -935,7 +936,7 @@ async def send_session_message(
async def _send_to_target(
message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
processed_plain_text: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[MaiMessage] = None,
@@ -950,7 +951,7 @@ async def _send_to_target(
await _send_to_target_with_message(
message_sequence=message_sequence,
stream_id=stream_id,
display_message=display_message,
processed_plain_text=processed_plain_text,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
@@ -967,7 +968,7 @@ async def _send_to_target(
async def _send_to_target_with_message(
message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
processed_plain_text: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[MaiMessage] = None,
@@ -982,7 +983,7 @@ async def _send_to_target_with_message(
Args:
message_sequence: 待发送的消息组件序列。
stream_id: 目标会话 ID。
display_message: 用于界面展示的文本内容。
processed_plain_text: 可选的预处理纯文本内容。
typing: 是否显示输入中状态。
set_reply: 是否在发送时附带引用回复。
reply_message: 被回复的消息对象。
@@ -1004,7 +1005,7 @@ async def _send_to_target_with_message(
outbound_message = _build_outbound_session_message(
message_sequence=message_sequence,
stream_id=stream_id,
display_message=display_message,
processed_plain_text=processed_plain_text,
reply_message=reply_message,
selected_expressions=selected_expressions,
)
@@ -1015,7 +1016,7 @@ async def _send_to_target_with_message(
"send_service.after_build_message",
outbound_message,
stream_id=stream_id,
display_message=display_message,
processed_plain_text=processed_plain_text,
typing=typing,
set_reply=set_reply,
storage_message=storage_message,
@@ -1068,7 +1069,6 @@ async def text_to_stream_with_message(
return await _send_to_target_with_message(
message_sequence=MessageSequence(components=[TextComponent(text=text)]),
stream_id=stream_id,
display_message="",
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
@@ -1133,7 +1133,6 @@ async def emoji_to_stream_with_message(
return await _send_to_target_with_message(
message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
stream_id=stream_id,
display_message="",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
@@ -1202,7 +1201,6 @@ async def image_to_stream(
return await _send_to_target(
message_sequence=_build_message_sequence_from_custom_message("image", image_base64),
stream_id=stream_id,
display_message="",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
@@ -1216,7 +1214,7 @@ async def custom_to_stream(
message_type: str,
content: str | Dict[str, Any],
stream_id: str,
display_message: str = "",
processed_plain_text: str = "",
typing: bool = False,
reply_message: Optional[MaiMessage] = None,
set_reply: bool = False,
@@ -1231,7 +1229,7 @@ async def custom_to_stream(
message_type: 自定义消息类型。
content: 自定义消息内容。
stream_id: 目标会话 ID。
display_message: 用于展示的文本内容。
processed_plain_text: 可选的预处理纯文本内容。
typing: 是否显示输入中状态。
reply_message: 被回复的消息对象。
set_reply: 是否附带引用回复。
@@ -1244,7 +1242,7 @@ async def custom_to_stream(
return await _send_to_target(
message_sequence=_build_message_sequence_from_custom_message(message_type, content),
stream_id=stream_id,
display_message=display_message,
processed_plain_text=processed_plain_text,
typing=typing,
reply_message=reply_message,
set_reply=set_reply,
@@ -1258,7 +1256,7 @@ async def custom_to_stream(
async def custom_reply_set_to_stream(
reply_set: MessageSequence,
stream_id: str,
display_message: str = "",
processed_plain_text: str = "",
typing: bool = False,
reply_message: Optional[MaiMessage] = None,
set_reply: bool = False,
@@ -1272,7 +1270,7 @@ async def custom_reply_set_to_stream(
Args:
reply_set: 待发送的消息组件序列。
stream_id: 目标会话 ID。
display_message: 用于展示的文本内容。
processed_plain_text: 可选的预处理纯文本内容。
typing: 是否显示输入中状态。
reply_message: 被回复的消息对象。
set_reply: 是否附带引用回复。
@@ -1285,7 +1283,7 @@ async def custom_reply_set_to_stream(
return await _send_to_target(
message_sequence=reply_set,
stream_id=stream_id,
display_message=display_message,
processed_plain_text=processed_plain_text,
typing=typing,
reply_message=reply_message,
set_reply=set_reply,

View File

@@ -134,7 +134,7 @@ def _setup_anti_crawler(app: FastAPI):
"basic": t("startup.webui_anti_crawler_mode_basic"),
}
mode_desc = mode_descriptions.get(anti_crawler_mode, t("startup.webui_anti_crawler_mode_basic"))
logger.info(t("startup.webui_anti_crawler_configured", mode_desc=mode_desc))
logger.debug(t("startup.webui_anti_crawler_configured", mode_desc=mode_desc))
except Exception as e:
logger.error(t("startup.webui_anti_crawler_config_failed", error=e), exc_info=True)
@@ -159,7 +159,7 @@ def _register_api_routes(app: FastAPI):
for router in get_all_routers():
app.include_router(router)
logger.info(t("startup.webui_api_routes_registered"))
logger.debug(t("startup.webui_api_routes_registered"))
except Exception as e:
logger.error(t("startup.webui_api_routes_register_failed", error=e), exc_info=True)
@@ -217,7 +217,7 @@ def _setup_static_files(app: FastAPI):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
logger.info(t("startup.webui_static_files_configured", static_path=static_path))
logger.debug(t("startup.webui_static_files_configured", static_path=static_path))
def _resolve_static_path() -> Path | None:
@@ -247,6 +247,5 @@ def show_access_token():
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(t("startup.webui_access_token", token=current_token))
logger.info(t("startup.webui_access_token_login_hint"))
except Exception as e:
logger.error(t("startup.webui_access_token_failed", error=e))

View File

@@ -39,15 +39,12 @@ class ConfigSchemaGenerator:
ui_parent = getattr(config_class, "__ui_parent__", "")
ui_label = getattr(config_class, "__ui_label__", "")
ui_icon = getattr(config_class, "__ui_icon__", "")
ui_merge_children = getattr(config_class, "__ui_merge_children__", [])
if ui_parent:
schema["uiParent"] = ui_parent
if ui_label:
schema["uiLabel"] = ui_label
if ui_icon:
schema["uiIcon"] = ui_icon
if ui_merge_children:
schema["uiMergeChildren"] = list(ui_merge_children)
return schema

View File

@@ -112,7 +112,7 @@ class ChatHistoryManager:
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"content": msg.processed_plain_text or "",
"timestamp": msg.timestamp.timestamp(),
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
@@ -175,11 +175,7 @@ class ChatHistoryManager:
user_info = target_msg.message_info.user_info
if not has_content:
content_text = (
target_msg.processed_plain_text
or target_msg.display_message
or ""
)
content_text = target_msg.processed_plain_text or ""
data["target_message_content"] = content_text
if not has_sender:
data["target_message_sender_id"] = user_info.user_id or ""

View File

@@ -2,19 +2,21 @@
配置管理API路由
"""
from pathlib import Path
from typing import Annotated, Any, Dict, List, Tuple, Union, get_args, get_origin
import copy
import os
from pathlib import Path
from typing import Annotated, Any, Dict, List, Tuple
import types
import tomlkit
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
import tomlkit
from src.common.logger import get_logger
from src.common.prompt_i18n import list_prompt_templates
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
from src.config.config_base import AttributeData
from src.config.config_base import AttributeData, ConfigBase
from src.config.model_configs import (
APIProvider,
ModelInfo,
@@ -63,6 +65,9 @@ class PromptFileInfo(BaseModel):
name: str = Field(..., description="Prompt 文件名")
size: int = Field(..., description="文件大小")
modified_at: float = Field(..., description="最后修改时间戳")
display_name: str = Field(default="", description="Prompt 展示名称")
advanced: bool = Field(default=False, description="是否为高级 Prompt")
description: str = Field(default="", description="Prompt 描述")
class PromptCatalogResponse(BaseModel):
@@ -129,6 +134,71 @@ def _toml_to_plain_dict(obj: Any) -> Any:
return obj
def _coerce_numeric_value(value: Any, target_type: Any) -> Any:
"""根据配置字段类型,把旧 WebUI 可能写入的数字字符串还原为数字。"""
if target_type is str:
if isinstance(value, (int, float)):
return str(value)
return value
if target_type is int:
if isinstance(value, str):
try:
parsed_value = float(value.strip())
except ValueError:
return value
if parsed_value.is_integer():
return int(parsed_value)
return value
if target_type is float:
if isinstance(value, str):
try:
return float(value.strip())
except ValueError:
return value
return value
return value
def _coerce_value_by_annotation(value: Any, annotation: Any) -> Any:
    """Recursively repair value types against a ConfigBase field annotation.

    Prevents numbers that an older WebUI wrote as strings from being saved
    back to TOML as strings again. Handles unions, lists, dicts and nested
    ConfigBase sections; anything else is returned as-is.
    """
    value = _coerce_numeric_value(value, annotation)
    origin = get_origin(annotation)
    type_args = get_args(annotation)

    if origin is Union or origin is types.UnionType:
        for candidate in type_args:
            if candidate is type(None):
                continue
            attempt = _coerce_value_by_annotation(value, candidate)
            # Accept the first candidate that actually changed something —
            # either the value itself or its concrete type (e.g. 1 vs 1.0).
            if attempt != value or type(attempt) is not type(value):
                return attempt
        return value

    if origin in (list, List) and type_args and isinstance(value, list):
        element_type = type_args[0]
        return [_coerce_value_by_annotation(element, element_type) for element in value]

    if origin in (dict, Dict) and len(type_args) >= 2 and isinstance(value, dict):
        mapped_type = type_args[1]
        return {k: _coerce_value_by_annotation(v, mapped_type) for k, v in value.items()}

    is_nested_config = (
        isinstance(value, dict) and isinstance(annotation, type) and issubclass(annotation, ConfigBase)
    )
    if is_nested_config:
        return _coerce_config_numeric_values(value, annotation)

    return value
def _coerce_config_numeric_values(data: Dict[str, Any], config_type: type[ConfigBase]) -> Dict[str, Any]:
    """Normalize numeric field types of a config dict per the class schema.

    Mutates and returns *data*; keys not declared in the schema are left
    untouched.
    """
    schema_fields = config_type.model_fields
    for name, info in schema_fields.items():
        if name not in data:
            continue
        data[name] = _coerce_value_by_annotation(data[name], info.annotation)
    return data
# ===== 架构获取接口 =====
@@ -147,14 +217,20 @@ async def list_prompt_files():
continue
language = language_dir.name
prompt_template_infos = list_prompt_templates(locale=language, prompts_root=PROMPTS_DIR)
prompt_files: List[PromptFileInfo] = []
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
stat = prompt_file.stat()
template_info = prompt_template_infos.get(prompt_file.stem)
metadata = template_info.metadata if template_info and template_info.path == prompt_file else None
prompt_files.append(
PromptFileInfo(
name=prompt_file.name,
size=stat.st_size,
modified_at=stat.st_mtime,
display_name=metadata.display_name if metadata else "",
advanced=metadata.advanced if metadata else False,
description=metadata.description if metadata else "",
)
)
@@ -347,6 +423,8 @@ async def get_model_config():
async def update_bot_config(config_data: ConfigBody):
"""更新麦麦主程序配置"""
try:
config_data = _coerce_config_numeric_values(config_data, Config)
# 验证配置数据
try:
Config.from_dict(AttributeData(), copy.deepcopy(config_data))
@@ -370,6 +448,8 @@ async def update_bot_config(config_data: ConfigBody):
async def update_model_config(config_data: ConfigBody):
"""更新模型配置"""
try:
config_data = _coerce_config_numeric_values(config_data, ModelConfig)
# 验证配置数据
try:
ModelConfig.from_dict(AttributeData(), copy.deepcopy(config_data))
@@ -422,10 +502,13 @@ async def update_bot_config_section(section_name: str, section_data: SectionBody
# 验证完整配置
try:
Config.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), Config)
Config.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
config_data = plain_config_data
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)
@@ -520,13 +603,14 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
# 验证完整配置
try:
ModelConfig.from_dict(AttributeData(), _toml_to_plain_dict(config_data))
plain_config_data = _coerce_config_numeric_values(_toml_to_plain_dict(config_data), ModelConfig)
ModelConfig.from_dict(AttributeData(), copy.deepcopy(plain_config_data))
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
models = plain_config_data.get("models", [])
orphaned_models: List[str] = [
str(model_name)
for m in models
@@ -539,6 +623,8 @@ async def update_model_config_section(section_name: str, section_data: SectionBo
raise HTTPException(status_code=400, detail=error_msg) from e
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
config_data = plain_config_data
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)

View File

@@ -10,11 +10,12 @@ from sqlmodel import col, delete, select
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.database.database_model import ChatSession, Expression, Messages, ModifiedBy
from src.common.logger import get_logger
from src.webui.dependencies import require_auth
logger = get_logger("webui.expression")
EXCLUDE_IDS_QUERY = Query(None, description="需要排除的表达方式 ID")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"], dependencies=[Depends(require_auth)])
@@ -28,6 +29,7 @@ class ExpressionResponse(BaseModel):
style: str
last_active_time: float
chat_id: str
chat_name: Optional[str] = None
create_date: Optional[float]
checked: bool
rejected: bool
@@ -90,7 +92,61 @@ class ExpressionCreateResponse(BaseModel):
data: ExpressionResponse
def expression_to_response(expression: Expression) -> ExpressionResponse:
def get_chat_name_from_latest_message(chat_id: str, db_session: Any) -> Optional[str]:
    """Resolve a chat display name from the most recent message of a session.

    Args:
        chat_id: Session identifier to look up.
        db_session: Active database session used to execute the query.

    Returns:
        The group or user display name, or None when the session has no
        messages (or no usable name can be derived).
    """
    latest_query = (
        select(Messages)
        .where(col(Messages.session_id) == chat_id)
        .order_by(col(Messages.timestamp).desc())
        .limit(1)
    )
    latest = db_session.exec(latest_query).first()
    if not latest:
        return None
    if latest.group_id:
        return latest.group_name or f"群聊{latest.group_id}"
    private_fallback = f"用户{latest.user_id}" if latest.user_id else None
    return latest.user_cardname or latest.user_nickname or private_fallback
def get_chat_name_from_session_record(chat_session: ChatSession) -> str:
    """Derive a fallback display name from a stored session record.

    Preference order: group label, then user label, then the raw session id.
    """
    group_id = chat_session.group_id
    if group_id:
        return f"群聊{group_id}"
    user_id = chat_session.user_id
    return f"用户{user_id}" if user_id else chat_session.session_id
def get_chat_name(chat_id: str, db_session: Optional[Any] = None) -> str:
    """Look up the display name for a chat id.

    Resolution order: the in-memory chat manager, the latest stored message
    (only when a db session is supplied), and finally the session record.

    Args:
        chat_id: Chat session id.
        db_session: Optional database session enabling message-based lookup.

    Returns:
        str: The resolved display name; falls back to the raw chat id when
        nothing better is available or any lookup raises.
    """
    try:
        manager_name = _chat_manager.get_session_name(chat_id)
        if manager_name:
            return manager_name
        if db_session:
            message_name = get_chat_name_from_latest_message(chat_id, db_session)
            if message_name:
                return message_name
        record = _chat_manager.get_session_by_session_id(chat_id)
        if record:
            if record.group_id:
                return f"群聊{record.group_id}"
            if record.user_id:
                return f"用户{record.user_id}"
        return chat_id
    except Exception:
        # Best-effort helper: never let a lookup failure break the caller.
        return chat_id
def expression_to_response(expression: Expression, db_session: Optional[Any] = None) -> ExpressionResponse:
"""将表达方式模型转换为响应对象。
Args:
@@ -101,38 +157,21 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
"""
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
create_date = expression.create_time.timestamp() if expression.create_time else None
chat_id = expression.session_id or ""
return ExpressionResponse(
id=expression.id if expression.id is not None else 0,
situation=expression.situation,
style=expression.style,
last_active_time=last_active_time,
chat_id=expression.session_id or "",
chat_id=chat_id,
chat_name=get_chat_name(chat_id, db_session) if chat_id else None,
create_date=create_date,
checked=False,
rejected=False,
modified_by=None,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by.value if expression.modified_by else None,
)
def get_chat_name(chat_id: str) -> str:
"""根据聊天 ID 获取聊天名称。
Args:
chat_id: 聊天会话 ID。
Returns:
str: 聊天显示名称,获取失败时返回原始聊天 ID。
"""
try:
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
return chat_id
name = _chat_manager.get_session_name(chat_id)
return name or chat_id
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称。
@@ -145,8 +184,7 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
for chat_id in chat_ids:
if name := _chat_manager.get_session_name(chat_id):
result[chat_id] = name
result[chat_id] = get_chat_name(chat_id)
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
@@ -176,19 +214,43 @@ async def get_chat_list() -> ChatListResponse:
ChatListResponse: 可用于下拉选择的聊天列表。
"""
try:
chat_list = []
chat_by_id: Dict[str, ChatInfo] = {}
for session_id, session in _chat_manager.sessions.items():
chat_name = _chat_manager.get_session_name(session_id) or session_id
chat_list.append(
ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=chat_name,
platform=session.platform,
is_group=session.is_group_session,
)
with get_db_session() as session:
for chat_session in session.exec(select(ChatSession)).all():
if chat_session.session_id in chat_by_id:
continue
chat_name = get_chat_name_from_latest_message(chat_session.session_id, session)
chat_by_id[chat_session.session_id] = ChatInfo(
chat_id=chat_session.session_id,
chat_name=chat_name or get_chat_name_from_session_record(chat_session),
platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
expression_chat_ids = {
chat_id for chat_id in session.exec(select(Expression.session_id)).all() if chat_id
}
for session_id in expression_chat_ids:
if session_id in chat_by_id:
continue
chat_by_id[session_id] = ChatInfo(
chat_id=session_id,
chat_name=get_chat_name(session_id, session),
platform=None,
is_group=False,
)
# 按名称排序
chat_list = list(chat_by_id.values())
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
@@ -252,7 +314,7 @@ async def get_expression_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@@ -281,7 +343,7 @@ async def get_expression_detail(expression_id: int) -> ExpressionDetailResponse:
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
data = expression_to_response(expression)
data = expression_to_response(expression, session)
return ExpressionDetailResponse(success=True, data=data)
@@ -321,7 +383,7 @@ async def create_expression(
session.add(expression)
session.flush()
expression_id = expression.id
data = expression_to_response(expression)
data = expression_to_response(expression, session)
logger.info(f"表达方式已创建: ID={expression_id}, situation={request.situation}")
@@ -375,7 +437,7 @@ async def update_expression(
db_expression.session_id = update_data["session_id"]
db_expression.last_active_time = update_data["last_active_time"]
session.add(db_expression)
data = expression_to_response(db_expression)
data = expression_to_response(db_expression, session)
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
@@ -524,6 +586,22 @@ class ReviewStatsResponse(BaseModel):
user_checked: int
def apply_review_filter(statement: Any, filter_type: str) -> Any:
    """Narrow an expression query by review state.

    Args:
        statement: Base SQL statement to extend.
        filter_type: One of "unchecked", "passed" or "rejected"; any other
            value (e.g. "all") leaves the statement untouched.

    Returns:
        The statement with the matching where-clauses applied.
    """
    checked_col = col(Expression.checked)
    rejected_col = col(Expression.rejected)
    conditions = {
        "unchecked": (checked_col.is_(False),),
        "passed": (checked_col.is_(True), rejected_col.is_(False)),
        "rejected": (checked_col.is_(True), rejected_col.is_(True)),
    }
    extra = conditions.get(filter_type)
    return statement.where(*extra) if extra else statement
def count_expressions(session: Any, statement: Any) -> int:
    """Count how many rows a statement yields.

    Note: materializes every matching row in memory; acceptable for the
    modest table sizes the review UI deals with.
    """
    matched_rows = session.exec(statement).all()
    return len(matched_rows)
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats() -> ReviewStatsResponse:
"""获取审核统计数据。
@@ -533,12 +611,24 @@ async def get_review_stats() -> ReviewStatsResponse:
"""
try:
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
passed = 0
rejected = 0
ai_checked = 0
user_checked = 0
total = count_expressions(session, select(Expression.id))
unchecked = count_expressions(session, apply_review_filter(select(Expression.id), "unchecked"))
passed = count_expressions(session, apply_review_filter(select(Expression.id), "passed"))
rejected = count_expressions(session, apply_review_filter(select(Expression.id), "rejected"))
ai_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.AI,
),
)
user_checked = count_expressions(
session,
select(Expression.id).where(
col(Expression.checked).is_(True),
col(Expression.modified_by) == ModifiedBy.USER,
),
)
return ReviewStatsResponse(
total=total,
@@ -571,8 +661,10 @@ async def get_review_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
order: str = Query("latest", description="排序方式: latest/random"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
exclude_ids: Optional[List[int]] = EXCLUDE_IDS_QUERY,
) -> ReviewListResponse:
"""获取待审核或已审核的表达方式列表。
@@ -580,17 +672,16 @@ async def get_review_list(
page: 页码。
page_size: 每页数量。
filter_type: 筛选类型,可选 unchecked、passed、rejected 或 all。
order: 排序方式,可选 latest 或 random。
search: 搜索关键词。
chat_id: 聊天 ID 筛选条件。
exclude_ids: 需要排除的表达方式 ID。
Returns:
ReviewListResponse: 审核列表响应。
"""
try:
statement = select(Expression)
if filter_type in {"unchecked", "passed", "rejected"}:
statement = statement.where(col(Expression.id) == -1)
statement = apply_review_filter(select(Expression), filter_type)
# all 不需要额外过滤
# 搜索过滤
@@ -603,11 +694,17 @@ async def get_review_list(
if chat_id:
statement = statement.where(col(Expression.session_id) == chat_id)
# 排序:创建时间倒序
statement = statement.order_by(
case((col(Expression.create_time).is_(None), 1), else_=0),
col(Expression.create_time).desc(),
)
if exclude_ids:
statement = statement.where(~col(Expression.id).in_(exclude_ids))
if order == "random":
statement = statement.order_by(func.random())
else:
# 排序:创建时间倒序
statement = statement.order_by(
case((col(Expression.create_time).is_(None), 1), else_=0),
col(Expression.create_time).desc(),
)
offset = (page - 1) * page_size
statement = statement.offset(offset).limit(page_size)
@@ -615,9 +712,7 @@ async def get_review_list(
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if filter_type in {"unchecked", "passed", "rejected"}:
count_statement = count_statement.where(col(Expression.id) == -1)
count_statement = apply_review_filter(select(Expression.id), filter_type)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
@@ -625,7 +720,7 @@ async def get_review_list(
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
data = [expression_to_response(expr) for expr in expressions]
data = [expression_to_response(expr, session) for expr in expressions]
return ReviewListResponse(
success=True,
@@ -647,7 +742,7 @@ class BatchReviewItem(BaseModel):
id: int
rejected: bool
require_unchecked: bool = True # 默认要求未检查状态
require_unchecked: bool = True # 前端保留的来源标记,人工审核提交时不再阻断覆盖
class BatchReviewRequest(BaseModel):
@@ -706,14 +801,6 @@ async def batch_review_expressions(
failed += 1
continue
# 冲突检测
if item.require_unchecked:
results.append(
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
)
failed += 1
continue
# 更新状态
with get_db_session() as session:
db_expression = session.exec(
@@ -727,6 +814,9 @@ async def batch_review_expressions(
)
failed += 1
continue
db_expression.checked = True
db_expression.rejected = item.rejected
db_expression.modified_by = ModifiedBy.USER
db_expression.last_active_time = datetime.now()
session.add(db_expression)