merge: 同步上游 dev 并增强人物画像查询
This commit is contained in:
@@ -74,7 +74,7 @@
|
||||
"enabled": {
|
||||
"name": "enabled",
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"default": false,
|
||||
"description": "是否启用 A_Memorix",
|
||||
"label": "启用 A_Memorix",
|
||||
"ui_type": "switch",
|
||||
@@ -82,7 +82,7 @@
|
||||
"hidden": false,
|
||||
"disabled": false,
|
||||
"order": 1,
|
||||
"hint": "关闭后 A_Memorix 不会参与长期记忆写入、检索与运维。",
|
||||
"hint": "默认关闭以简化首次配置;开启前请先配置可用的 embedding 模型。关闭后 A_Memorix 不会参与长期记忆写入、检索与运维。",
|
||||
"choices": null
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""SDK runtime exports for A_Memorix."""
|
||||
|
||||
from .search_runtime_initializer import (
|
||||
SearchRuntimeBundle,
|
||||
SearchRuntimeInitializer,
|
||||
build_search_runtime,
|
||||
)
|
||||
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .search_runtime_initializer import SearchRuntimeBundle, SearchRuntimeInitializer, build_search_runtime
|
||||
|
||||
__all__ = [
|
||||
"SearchRuntimeBundle",
|
||||
@@ -14,3 +13,14 @@ __all__ = [
|
||||
"KernelSearchRequest",
|
||||
"SDKMemoryKernel",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in {"KernelSearchRequest", "SDKMemoryKernel"}:
|
||||
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
|
||||
return {
|
||||
"KernelSearchRequest": KernelSearchRequest,
|
||||
"SDKMemoryKernel": SDKMemoryKernel,
|
||||
}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -3259,7 +3259,6 @@ class ImportTaskManager:
|
||||
for task_name in [
|
||||
"lpmm_entity_extract",
|
||||
"lpmm_rdf_build",
|
||||
"embedding",
|
||||
"replyer",
|
||||
"utils",
|
||||
"planner",
|
||||
|
||||
@@ -4,20 +4,35 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.utils.toml_utils import save_toml_with_format
|
||||
from src.config.official_configs import AMemorixConfig
|
||||
from src.webui.utils.toml_utils import _update_toml_doc
|
||||
|
||||
from .core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from .paths import config_path, repo_root, schema_path
|
||||
from .paths import repo_root, schema_path
|
||||
from .runtime_registry import set_runtime_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
|
||||
logger = get_logger("a_memorix.host_service")
|
||||
|
||||
|
||||
def _get_config_manager():
|
||||
from src.config.config import config_manager
|
||||
|
||||
return config_manager
|
||||
|
||||
|
||||
def _get_bot_config_path() -> Path:
|
||||
from src.config.config import BOT_CONFIG_PATH
|
||||
|
||||
return BOT_CONFIG_PATH
|
||||
|
||||
|
||||
def _to_builtin_data(obj: Any) -> Any:
|
||||
if hasattr(obj, "unwrap"):
|
||||
try:
|
||||
@@ -46,8 +61,12 @@ class AMemorixHostService:
|
||||
self._lock = asyncio.Lock()
|
||||
self._kernel: Optional[SDKMemoryKernel] = None
|
||||
self._config_cache: Dict[str, Any] | None = None
|
||||
self._reload_callback_registered = False
|
||||
|
||||
async def start(self) -> None:
|
||||
if not self.is_enabled():
|
||||
logger.info("A_Memorix 未启用,跳过长期记忆运行时初始化")
|
||||
return
|
||||
await self._ensure_kernel()
|
||||
|
||||
async def stop(self) -> None:
|
||||
@@ -57,12 +76,16 @@ class AMemorixHostService:
|
||||
async def reload(self) -> None:
|
||||
async with self._lock:
|
||||
await self._shutdown_locked()
|
||||
self._config_cache = self._read_config()
|
||||
self._config_cache = None
|
||||
config = self._read_config()
|
||||
|
||||
await self._ensure_kernel()
|
||||
if self._is_enabled_config(config):
|
||||
await self._ensure_kernel()
|
||||
else:
|
||||
logger.info("A_Memorix 配置为未启用,运行时保持关闭")
|
||||
|
||||
def get_config_path(self) -> Path:
|
||||
return config_path()
|
||||
return _get_bot_config_path()
|
||||
|
||||
def get_schema_path(self) -> Path:
|
||||
return schema_path()
|
||||
@@ -88,54 +111,28 @@ class AMemorixHostService:
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
return dict(self._read_config())
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
return self._is_enabled_config(self._read_config())
|
||||
|
||||
@staticmethod
|
||||
def _is_enabled_config(config: Dict[str, Any]) -> bool:
|
||||
plugin_config = config.get("plugin") if isinstance(config, dict) else None
|
||||
if not isinstance(plugin_config, dict):
|
||||
return True
|
||||
return bool(plugin_config.get("enabled", True))
|
||||
|
||||
def _build_default_config(self) -> Dict[str, Any]:
|
||||
schema = self.get_config_schema()
|
||||
sections = schema.get("sections") if isinstance(schema, dict) else None
|
||||
if not isinstance(sections, dict):
|
||||
return {}
|
||||
|
||||
defaults: Dict[str, Any] = {}
|
||||
for section_name, section_payload in sections.items():
|
||||
if not isinstance(section_payload, dict):
|
||||
continue
|
||||
fields = section_payload.get("fields")
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
|
||||
section_parts = [part for part in str(section_name or "").split(".") if part]
|
||||
if not section_parts:
|
||||
continue
|
||||
|
||||
section_target: Dict[str, Any] = defaults
|
||||
for part in section_parts:
|
||||
nested = section_target.get(part)
|
||||
if not isinstance(nested, dict):
|
||||
nested = {}
|
||||
section_target[part] = nested
|
||||
section_target = nested
|
||||
|
||||
for field_name, field_payload in fields.items():
|
||||
if not isinstance(field_payload, dict) or "default" not in field_payload:
|
||||
continue
|
||||
section_target[str(field_name)] = _to_builtin_data(field_payload.get("default"))
|
||||
|
||||
return defaults
|
||||
return self._config_model_to_runtime_dict(AMemorixConfig())
|
||||
|
||||
def get_raw_config_with_meta(self) -> Dict[str, Any]:
|
||||
path = self.get_config_path()
|
||||
if path.exists():
|
||||
return {
|
||||
"config": path.read_text(encoding="utf-8"),
|
||||
"exists": True,
|
||||
"using_default": False,
|
||||
}
|
||||
|
||||
config = self.get_config()
|
||||
default_config = self._build_default_config()
|
||||
default_raw = tomlkit.dumps(default_config) if default_config else ""
|
||||
raw_doc = tomlkit.document()
|
||||
raw_doc.add("a_memorix", config)
|
||||
return {
|
||||
"config": default_raw,
|
||||
"exists": False,
|
||||
"using_default": True,
|
||||
"config": tomlkit.dumps(raw_doc),
|
||||
"exists": self.get_config_path().exists(),
|
||||
"using_default": config == default_config,
|
||||
}
|
||||
|
||||
def get_raw_config(self) -> str:
|
||||
@@ -143,12 +140,10 @@ class AMemorixHostService:
|
||||
return str(payload.get("config", "") or "")
|
||||
|
||||
async def update_raw_config(self, raw_config: str) -> Dict[str, Any]:
|
||||
tomlkit.loads(raw_config)
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
path.write_text(raw_config, encoding="utf-8")
|
||||
await self.reload()
|
||||
loaded = tomlkit.loads(raw_config)
|
||||
raw_payload = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
|
||||
config_payload = raw_payload.get("a_memorix") if isinstance(raw_payload.get("a_memorix"), dict) else raw_payload
|
||||
path, backup_path = await self._write_config_to_bot_config(config_payload)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
@@ -157,11 +152,7 @@ class AMemorixHostService:
|
||||
}
|
||||
|
||||
async def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
save_toml_with_format(config, str(path), preserve_comments=True)
|
||||
await self.reload()
|
||||
path, backup_path = await self._write_config_to_bot_config(config)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
@@ -172,9 +163,13 @@ class AMemorixHostService:
|
||||
async def invoke(self, component_name: str, args: Dict[str, Any] | None = None, *, timeout_ms: int = 30000) -> Any:
|
||||
del timeout_ms
|
||||
payload = args or {}
|
||||
if not self.is_enabled():
|
||||
return self._disabled_response(component_name)
|
||||
kernel = await self._ensure_kernel()
|
||||
|
||||
if component_name == "search_memory":
|
||||
from .core.runtime.sdk_memory_kernel import KernelSearchRequest
|
||||
|
||||
return await kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
@@ -278,7 +273,11 @@ class AMemorixHostService:
|
||||
async def _ensure_kernel(self) -> SDKMemoryKernel:
|
||||
async with self._lock:
|
||||
if self._kernel is None:
|
||||
from .core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
|
||||
config = self._read_config()
|
||||
if not self._is_enabled_config(config):
|
||||
raise RuntimeError("A_Memorix 未启用")
|
||||
kernel = SDKMemoryKernel(plugin_root=repo_root(), config=config)
|
||||
try:
|
||||
await kernel.initialize()
|
||||
@@ -293,24 +292,149 @@ class AMemorixHostService:
|
||||
if self._config_cache is not None:
|
||||
return dict(self._config_cache)
|
||||
|
||||
path = self.get_config_path()
|
||||
if not path.exists():
|
||||
defaults = self._build_default_config()
|
||||
self._config_cache = defaults
|
||||
return dict(defaults)
|
||||
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
loaded = tomlkit.load(handle)
|
||||
config_model = _get_config_manager().get_global_config().a_memorix
|
||||
except Exception as exc:
|
||||
logger.warning("读取 A_Memorix 配置失败 %s: %s", path, exc)
|
||||
logger.warning("读取 A_Memorix 主配置失败,使用默认值: %s", exc)
|
||||
defaults = self._build_default_config()
|
||||
self._config_cache = defaults
|
||||
return dict(defaults)
|
||||
|
||||
self._config_cache = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
|
||||
self._config_cache = self._config_model_to_runtime_dict(config_model)
|
||||
return dict(self._config_cache)
|
||||
|
||||
@staticmethod
|
||||
def _config_model_to_runtime_dict(config_model: AMemorixConfig) -> Dict[str, Any]:
|
||||
payload = config_model.model_dump(mode="json")
|
||||
web_config = payload.get("web")
|
||||
if isinstance(web_config, dict) and "import_config" in web_config:
|
||||
web_config["import"] = web_config.pop("import_config")
|
||||
return _to_builtin_data(payload) if isinstance(payload, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _runtime_dict_to_bot_config_dict(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
payload = _to_builtin_data(config)
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
web_config = payload.get("web")
|
||||
if isinstance(web_config, dict) and "import_config" in web_config and "import" not in web_config:
|
||||
web_config["import"] = web_config.pop("import_config")
|
||||
return payload
|
||||
|
||||
async def _write_config_to_bot_config(self, config: Dict[str, Any]) -> tuple[Path, Optional[Path]]:
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
if path.exists():
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
doc = tomlkit.load(handle)
|
||||
else:
|
||||
doc = tomlkit.document()
|
||||
|
||||
bot_config_payload = self._runtime_dict_to_bot_config_dict(config)
|
||||
current = doc.get("a_memorix")
|
||||
if isinstance(current, dict):
|
||||
_update_toml_doc(current, bot_config_payload)
|
||||
else:
|
||||
doc["a_memorix"] = bot_config_payload
|
||||
|
||||
with path.open("w", encoding="utf-8") as handle:
|
||||
tomlkit.dump(doc, handle)
|
||||
|
||||
await _get_config_manager().reload_config(changed_scopes=("bot",))
|
||||
if not self._reload_callback_registered:
|
||||
await self.reload()
|
||||
return path, backup_path
|
||||
|
||||
def register_config_reload_callback(self) -> None:
|
||||
if self._reload_callback_registered:
|
||||
return
|
||||
_get_config_manager().register_reload_callback(self.on_config_reload)
|
||||
self._reload_callback_registered = True
|
||||
|
||||
async def on_config_reload(self, changed_scopes: Sequence[str] | None = None) -> None:
|
||||
normalized = {str(scope or "").strip().lower() for scope in (changed_scopes or [])}
|
||||
if normalized and "bot" not in normalized:
|
||||
return
|
||||
await self.reload()
|
||||
|
||||
@staticmethod
|
||||
def _disabled_response(component_name: str) -> Dict[str, Any]:
|
||||
reason = "a_memorix_disabled"
|
||||
message = "A_Memorix 未启用,请在长期记忆配置中开启后再使用。"
|
||||
|
||||
if component_name == "search_memory":
|
||||
return {
|
||||
"success": True,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
"summary": "",
|
||||
"hits": [],
|
||||
"filtered": False,
|
||||
}
|
||||
|
||||
if component_name in {"ingest_summary", "ingest_text"}:
|
||||
return {
|
||||
"success": True,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
"stored_ids": [],
|
||||
"skipped_ids": [reason],
|
||||
"detail": reason,
|
||||
}
|
||||
|
||||
if component_name == "get_person_profile":
|
||||
return {
|
||||
"success": True,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
"summary": "",
|
||||
"traits": [],
|
||||
"evidence": [],
|
||||
}
|
||||
|
||||
if component_name == "memory_stats":
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": False,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
"message": message,
|
||||
"paragraph_count": 0,
|
||||
"relation_count": 0,
|
||||
"episode_count": 0,
|
||||
}
|
||||
|
||||
if component_name == "memory_runtime_admin":
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": False,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
"message": message,
|
||||
"runtime_ready": False,
|
||||
"embedding_degraded": False,
|
||||
"embedding_dimension": 0,
|
||||
"auto_save": False,
|
||||
"data_dir": "",
|
||||
}
|
||||
|
||||
if component_name == "enqueue_feedback_task":
|
||||
return {
|
||||
"success": True,
|
||||
"queued": False,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
}
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"enabled": False,
|
||||
"disabled": True,
|
||||
"reason": reason,
|
||||
"error": message,
|
||||
}
|
||||
|
||||
async def _shutdown_locked(self) -> None:
|
||||
if self._kernel is None:
|
||||
return
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
|
||||
_NATIVE_THREAD_ENV_DEFAULTS = {
|
||||
"OMP_NUM_THREADS": "1",
|
||||
"OPENBLAS_NUM_THREADS": "1",
|
||||
"MKL_NUM_THREADS": "1",
|
||||
"NUMEXPR_NUM_THREADS": "1",
|
||||
}
|
||||
|
||||
for _name, _value in _NATIVE_THREAD_ENV_DEFAULTS.items():
|
||||
os.environ.setdefault(_name, _value)
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.data_models.image_data_model import MaiImage
|
||||
from src.config.config import config_manager
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
@@ -30,6 +31,17 @@ def _ensure_image_dir_exists() -> None:
|
||||
IMAGE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def _is_vlm_task_configured() -> bool:
|
||||
"""判断是否配置了可用于图片识别的视觉模型任务。"""
|
||||
|
||||
try:
|
||||
vlm_models = config_manager.get_model_config().model_task_config.vlm.model_list
|
||||
return any(str(model_name).strip() for model_name in vlm_models)
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取 VLM 模型配置失败,跳过图片识别: {exc}")
|
||||
return False
|
||||
|
||||
|
||||
vlm = LLMServiceClient(task_name="vlm", request_type="image")
|
||||
|
||||
|
||||
@@ -111,6 +123,9 @@ class ImageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件时发生错误: {e}")
|
||||
return ""
|
||||
if not _is_vlm_task_configured():
|
||||
logger.info("未配置 VLM 模型,跳过图片识别")
|
||||
return ""
|
||||
if not wait_for_build:
|
||||
self._schedule_description_build(hash_str, image_bytes)
|
||||
return ""
|
||||
@@ -129,6 +144,10 @@ class ImageManager:
|
||||
image_hash: 图片哈希值。
|
||||
image_bytes: 图片字节数据。
|
||||
"""
|
||||
if not _is_vlm_task_configured():
|
||||
logger.info("未配置 VLM 模型,跳过图片后台识别任务")
|
||||
return
|
||||
|
||||
if image_hash in self._pending_description_tasks:
|
||||
return
|
||||
|
||||
@@ -303,6 +322,9 @@ class ImageManager:
|
||||
await mai_image.calculate_hash_format()
|
||||
if mai_image.vlm_processed and mai_image.description:
|
||||
return mai_image
|
||||
if not _is_vlm_task_configured():
|
||||
logger.info(f"未配置 VLM 模型,跳过图片识别: {mai_image.file_hash}")
|
||||
return mai_image
|
||||
|
||||
desc = await self._generate_image_description(image_bytes, mai_image.image_format)
|
||||
mai_image.description = desc
|
||||
|
||||
@@ -245,7 +245,7 @@ class SessionMessage(MaiMessage):
|
||||
except Exception:
|
||||
desc = None # 失败置空
|
||||
|
||||
content = f"[图片:{desc}]" if desc else "[图片]"
|
||||
content = f"[图片:{desc}]" if desc else ""
|
||||
component.content = content
|
||||
component.binary_data = b"" # 处理完就丢掉二进制数据,节省内存
|
||||
return content
|
||||
|
||||
@@ -174,7 +174,7 @@ class BaseMaisakaReplyGenerator:
|
||||
continue
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
rendered_parts.append(component.content.strip() or "[图片]")
|
||||
rendered_parts.append(component.content.strip() or "[图片,识别中.....]")
|
||||
continue
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
|
||||
@@ -348,7 +348,7 @@ class MessageSequence:
|
||||
if isinstance(item, TextComponent):
|
||||
return {"type": "text", "data": item.text}
|
||||
elif isinstance(item, ImageComponent):
|
||||
return {"type": "image", "data": self._ensure_binary_component_content(item, "[图片]"), "hash": item.binary_hash}
|
||||
return {"type": "image", "data": item.content.strip(), "hash": item.binary_hash}
|
||||
elif isinstance(item, EmojiComponent):
|
||||
return {"type": "emoji", "data": self._ensure_binary_component_content(item, "[表情包]"), "hash": item.binary_hash}
|
||||
elif isinstance(item, VoiceComponent):
|
||||
@@ -387,10 +387,8 @@ class MessageSequence:
|
||||
"""确保二进制组件在序列化时带有稳定的文本占位。"""
|
||||
normalized_content = item.content.strip()
|
||||
if normalized_content:
|
||||
item.content = normalized_content
|
||||
return item.content
|
||||
item.content = fallback_text
|
||||
return item.content
|
||||
return normalized_content
|
||||
return fallback_text
|
||||
|
||||
@classmethod
|
||||
def _dict_2_item(cls, item: Dict[str, Any]) -> StandardMessageComponents:
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any, Callable, Mapping, Sequence, TypeVar, cast
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import sys
|
||||
|
||||
import tomlkit
|
||||
|
||||
@@ -16,6 +15,7 @@ from .file_watcher import FileChange, FileWatcher
|
||||
from .legacy_migration import migrate_legacy_bind_env_to_bot_config_dict, try_migrate_legacy_bot_config_dict
|
||||
from .model_configs import APIProvider, ModelInfo, ModelTaskConfig
|
||||
from .official_configs import (
|
||||
AMemorixConfig,
|
||||
BotConfig,
|
||||
ChatConfig,
|
||||
ChineseTypoConfig,
|
||||
@@ -55,9 +55,10 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.9.20"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.5"
|
||||
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0-pre.10"
|
||||
CONFIG_VERSION: str = "8.10.6"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.8"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
@@ -86,6 +87,9 @@ class Config(ConfigBase):
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig)
|
||||
"""记忆配置类"""
|
||||
|
||||
a_memorix: AMemorixConfig = Field(default_factory=AMemorixConfig)
|
||||
"""A_Memorix 长期记忆子系统配置"""
|
||||
|
||||
message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
|
||||
"""消息接收配置类"""
|
||||
|
||||
@@ -176,9 +180,50 @@ class ModelConfig(ConfigBase):
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
def _normalize_a_memorix_legacy_config(config_data: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = copy.deepcopy(config_data)
|
||||
web_config = normalized.get("web")
|
||||
if isinstance(web_config, dict) and "import" in web_config and "import_config" not in web_config:
|
||||
web_config["import_config"] = web_config.pop("import")
|
||||
return normalized
|
||||
|
||||
|
||||
def _migrate_legacy_a_memorix_config(config_data: dict[str, Any]) -> tuple[dict[str, Any], bool]:
|
||||
if isinstance(config_data.get("a_memorix"), dict):
|
||||
return config_data, False
|
||||
if not A_MEMORIX_LEGACY_CONFIG_PATH.exists():
|
||||
return config_data, False
|
||||
|
||||
try:
|
||||
with A_MEMORIX_LEGACY_CONFIG_PATH.open("r", encoding="utf-8") as handle:
|
||||
legacy_data = tomlkit.load(handle).unwrap()
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取旧版 A_Memorix 配置失败,已使用主配置默认值: {A_MEMORIX_LEGACY_CONFIG_PATH},原因: {exc}")
|
||||
return config_data, False
|
||||
|
||||
if not isinstance(legacy_data, dict):
|
||||
logger.warning(f"旧版 A_Memorix 配置内容无效,已使用主配置默认值: {A_MEMORIX_LEGACY_CONFIG_PATH}")
|
||||
return config_data, False
|
||||
|
||||
migrated_data = copy.deepcopy(config_data)
|
||||
migrated_data["a_memorix"] = _normalize_a_memorix_legacy_config(legacy_data)
|
||||
logger.warning(f"检测到旧版 A_Memorix 配置,已迁移到 bot_config.toml 的 [a_memorix]: {A_MEMORIX_LEGACY_CONFIG_PATH}")
|
||||
return migrated_data, True
|
||||
|
||||
|
||||
def _normalize_loaded_bot_config_dict(config_data: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = copy.deepcopy(config_data)
|
||||
a_memorix_config = normalized.get("a_memorix")
|
||||
if isinstance(a_memorix_config, dict):
|
||||
normalized["a_memorix"] = _normalize_a_memorix_legacy_config(a_memorix_config)
|
||||
return normalized
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""总配置管理类"""
|
||||
|
||||
VLM_NOT_CONFIGURED_WARNING: str = "未配置视觉识图模型,部分图片理解可能受限,请在webui或model_config中配置"
|
||||
|
||||
def __init__(self):
|
||||
self.bot_config_path: Path = BOT_CONFIG_PATH
|
||||
self.model_config_path: Path = MODEL_CONFIG_PATH
|
||||
@@ -204,19 +249,26 @@ class ConfigManager:
|
||||
True,
|
||||
)
|
||||
if global_updated or model_updated:
|
||||
sys.exit(0) # 配置已自动升级,退出一次让用户确认新配置后再启动
|
||||
logger.info("配置已自动升级,将继续使用更新后的配置启动")
|
||||
self._warn_if_vlm_not_configured(self.model_config)
|
||||
logger.info(t("config.loaded"))
|
||||
|
||||
@classmethod
|
||||
def _warn_if_vlm_not_configured(cls, model_config: ModelConfig) -> None:
|
||||
if any(model_name.strip() for model_name in model_config.model_task_config.vlm.model_list):
|
||||
return
|
||||
logger.warning(cls.VLM_NOT_CONFIGURED_WARNING)
|
||||
|
||||
def load_global_config(self) -> Config:
|
||||
config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
|
||||
if updated:
|
||||
sys.exit(0) # 先直接退出
|
||||
logger.info("bot_config.toml 已自动升级,将继续使用更新后的配置")
|
||||
return config
|
||||
|
||||
def load_model_config(self) -> ModelConfig:
|
||||
config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
|
||||
if updated:
|
||||
sys.exit(0) # 先直接退出
|
||||
logger.info("model_config.toml 已自动升级,将继续使用更新后的配置")
|
||||
return config
|
||||
|
||||
def get_global_config(self) -> Config:
|
||||
@@ -498,6 +550,7 @@ def load_config_from_file(
|
||||
raise TypeError(t("config.invalid_inner_version"))
|
||||
old_ver: str = inner_version
|
||||
env_migration_applied: bool = False
|
||||
a_memorix_migration_applied: bool = False
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
|
||||
@@ -510,6 +563,8 @@ def load_config_from_file(
|
||||
if legacy_migration.migrated:
|
||||
logger.warning(t("config.legacy_migrated", reason=legacy_migration.reason))
|
||||
config_data = legacy_migration.data
|
||||
config_data, a_memorix_migration_applied = _migrate_legacy_a_memorix_config(config_data)
|
||||
config_data = _normalize_loaded_bot_config_dict(config_data)
|
||||
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
|
||||
original_data: dict[str, Any] = copy.deepcopy(config_data)
|
||||
try:
|
||||
@@ -529,7 +584,7 @@ def load_config_from_file(
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
if compare_versions(old_ver, new_ver) or env_migration_applied:
|
||||
if compare_versions(old_ver, new_ver) or env_migration_applied or a_memorix_migration_applied:
|
||||
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
|
||||
write_config_to_file(target_config, config_path, new_ver, override_repr)
|
||||
if env_migration_applied:
|
||||
@@ -578,6 +633,14 @@ def write_config_to_file(
|
||||
else:
|
||||
raise TypeError(t("config.write_unsupported_type"))
|
||||
|
||||
if isinstance(config, Config):
|
||||
try:
|
||||
a_memorix_web = full_config_data["a_memorix"]["web"]
|
||||
if "import_config" in a_memorix_web and "import" not in a_memorix_web:
|
||||
a_memorix_web["import"] = a_memorix_web.pop("import_config")
|
||||
except Exception:
|
||||
logger.debug("A_Memorix 配置写出时转换 web.import_config 失败", exc_info=True)
|
||||
|
||||
# 备份旧文件
|
||||
if config_path.exists():
|
||||
backup_root = config_path.parent / "old"
|
||||
|
||||
@@ -135,6 +135,7 @@ class ConfigBase(BaseModel, AttrDocBase):
|
||||
__ui_parent__: ClassVar[str] = "" # 父配置类在 Config 中的字段名,空表示独立 Tab
|
||||
__ui_label__: ClassVar[str] = "" # Tab 显示名称(仅做 Tab 主人时使用),空则使用 classDoc
|
||||
__ui_icon__: ClassVar[str] = "" # Tab 图标名称(Lucide 图标名)
|
||||
__ui_merge_children__: ClassVar[List[str]] = [] # 在 WebUI 中并入当前配置卡片展示的子配置字段名
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, attribute_data: AttributeData, data: dict[str, Any]):
|
||||
|
||||
@@ -39,27 +39,6 @@ DEFAULT_TASK_CONFIG_TEMPLATES: dict[str, dict[str, Any]] = {
|
||||
"slow_threshold": 12.0,
|
||||
"selection_strategy": "random",
|
||||
},
|
||||
"vlm": {
|
||||
"model_list": ["qwen3.5-flash"],
|
||||
"max_tokens": 512,
|
||||
"temperature": 0.3,
|
||||
"slow_threshold": 15.0,
|
||||
"selection_strategy": "random",
|
||||
},
|
||||
"voice": {
|
||||
"model_list": [""],
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.3,
|
||||
"slow_threshold": 12.0,
|
||||
"selection_strategy": "random",
|
||||
},
|
||||
"embedding": {
|
||||
"model_list": ["qwen3-embedding"],
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.3,
|
||||
"slow_threshold": 5.0,
|
||||
"selection_strategy": "random",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_TEMPLATES: list[dict[str, Any]] = [
|
||||
@@ -89,24 +68,6 @@ DEFAULT_MODEL_TEMPLATES: list[dict[str, Any]] = [
|
||||
"price_out": 2.0,
|
||||
"visual": False,
|
||||
"extra_params": {"enable_thinking": "false"},
|
||||
},
|
||||
{
|
||||
"model_identifier": "qwen3.5-flash",
|
||||
"name": "qwen3.5-flash",
|
||||
"api_provider": "BaiLian",
|
||||
"price_in": 0.2,
|
||||
"price_out": 2.0,
|
||||
"visual": True,
|
||||
"extra_params": {"enable_thinking": "false"},
|
||||
},
|
||||
{
|
||||
"model_identifier": "text-embedding-v4",
|
||||
"name": "qwen3-embedding",
|
||||
"api_provider": "BaiLian",
|
||||
"price_in": 0.5,
|
||||
"price_out": 0.5,
|
||||
"visual": False,
|
||||
"extra_params": {},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@@ -343,7 +343,12 @@ class ModelInfo(ConfigBase):
|
||||
"x-icon": "sliders",
|
||||
},
|
||||
)
|
||||
"""额外参数 (用于API调用时的额外配置)"""
|
||||
"""额外参数 (用于API调用时的额外配置)。
|
||||
OpenAI 兼容客户端会将该字典拆分为请求附加项:headers 会作为请求头传入,query 会作为 URL 查询参数传入,body 会合并到请求体。
|
||||
未放入 headers/query/body 的普通键,也会作为请求体额外字段传入;例如 {enable_thinking = "false"} 会传为请求体字段 enable_thinking。
|
||||
该字段不会以 extra_params 这个键整体发送给模型服务商。
|
||||
temperature 和 max_tokens 也可写在此处作为模型级默认值,但更推荐使用同名独立配置项。
|
||||
Gemini 客户端会按自身支持的字段筛选并映射到 GenerateContentConfig、EmbedContentConfig 或音频请求配置中。"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
if not self.model_identifier:
|
||||
@@ -431,8 +436,18 @@ class ModelTaskConfig(ConfigBase):
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""首要回复模型配置, 还用于表达器和表达方式学习"""
|
||||
"""首要回复模型配置"""
|
||||
|
||||
learner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "graduation-cap",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""学习模型配置,用于表达方式学习和黑话学习;留空时自动继用 utils 模型"""
|
||||
|
||||
planner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -27,6 +27,8 @@ class BotConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "wifi",
|
||||
"x-layout": "inline-right",
|
||||
"x-input-width": "12rem",
|
||||
},
|
||||
)
|
||||
"""平台"""
|
||||
@@ -36,6 +38,8 @@ class BotConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "user",
|
||||
"x-layout": "inline-right",
|
||||
"x-input-width": "12rem",
|
||||
},
|
||||
)
|
||||
"""QQ账号"""
|
||||
@@ -63,6 +67,7 @@ class BotConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "tags",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""别名列表"""
|
||||
@@ -101,6 +106,7 @@ class PersonalityConfig(ConfigBase):
|
||||
"带点翻译腔,但不要太长",
|
||||
],
|
||||
json_schema_extra={
|
||||
"advanced": True,
|
||||
"x-widget": "custom",
|
||||
"x-icon": "list",
|
||||
},
|
||||
@@ -108,10 +114,11 @@ class PersonalityConfig(ConfigBase):
|
||||
"""可选的多种表达风格列表,当配置不为空时可按概率随机替换 reply_style"""
|
||||
|
||||
multiple_probability: float = Field(
|
||||
default=0.2,
|
||||
default=0,
|
||||
ge=0,
|
||||
le=1,
|
||||
json_schema_extra={
|
||||
"advanced": True,
|
||||
"x-widget": "slider",
|
||||
"x-icon": "percent",
|
||||
"step": 0.1,
|
||||
@@ -208,6 +215,7 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "at-sign",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否允许 replyer 使用 at[msg_id] 标记来发送真正的 at 消息"""
|
||||
@@ -217,6 +225,7 @@ class ChatConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "quote",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否启用回复时附带引用回复"""
|
||||
@@ -240,11 +249,12 @@ class ChatConfig(ConfigBase):
|
||||
"""私聊上下文长度"""
|
||||
|
||||
planner_interrupt_max_consecutive_count: int = Field(
|
||||
default=2,
|
||||
default=0,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "pause-circle",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
@@ -405,6 +415,7 @@ class MemoryConfig(ConfigBase):
|
||||
)
|
||||
"""_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索"""
|
||||
|
||||
|
||||
enable_memory_query_tool: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
@@ -449,6 +460,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "messages-square",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要的消息窗口阈值"""
|
||||
@@ -460,6 +472,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "rows-3",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""自动写回聊天摘要时,从聊天流中回看的消息条数"""
|
||||
@@ -469,6 +482,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "message-circle-warning",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否启用反馈驱动的延迟记忆纠错任务"""
|
||||
@@ -479,6 +493,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "clock-4",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""反馈窗口时长(小时),以 query_memory 执行时间为起点"""
|
||||
@@ -489,6 +504,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "timer",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""反馈纠错定时任务轮询间隔(分钟)"""
|
||||
@@ -500,6 +516,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-ordered",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""反馈纠错每轮最大处理任务数"""
|
||||
@@ -512,6 +529,7 @@ class MemoryConfig(ConfigBase):
|
||||
"x-widget": "slider",
|
||||
"x-icon": "gauge",
|
||||
"step": 0.01,
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""自动应用纠错动作的最低置信度阈值"""
|
||||
@@ -523,6 +541,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "messages-square",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""每个纠错任务最多使用的窗口内用户反馈消息数"""
|
||||
@@ -532,6 +551,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "filter",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否启用纠错前置预筛(用于减少不必要的模型调用)"""
|
||||
@@ -541,6 +561,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "sticky-note",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否为受影响 paragraph 写入已纠正旧事实标记"""
|
||||
@@ -550,6 +571,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "eye-off",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否在用户侧查询中硬过滤带有 stale 标记的 paragraph"""
|
||||
@@ -559,6 +581,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "user-round-search",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否在反馈纠错后将受影响人物画像加入刷新队列"""
|
||||
@@ -568,6 +591,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "refresh-ccw",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""人物画像处于脏队列时,读取是否强制刷新而不直接复用旧快照"""
|
||||
@@ -577,6 +601,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "clapperboard",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否在反馈纠错后将受影响 source 加入 episode 重建队列"""
|
||||
@@ -586,6 +611,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "ban",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""episode source 处于重建队列时,是否对用户侧查询做屏蔽"""
|
||||
@@ -596,6 +622,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "repeat",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""反馈纠错二阶段一致性后台协调任务轮询间隔(分钟)"""
|
||||
@@ -607,6 +634,7 @@ class MemoryConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "list-restart",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""反馈纠错二阶段一致性每轮处理 profile/episode 队列的批大小"""
|
||||
@@ -649,6 +677,345 @@ class MemoryConfig(ConfigBase):
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class AMemorixPluginConfig(ConfigBase):
|
||||
"""A_Memorix 子系统状态"""
|
||||
|
||||
enabled: bool = Field(default=False)
|
||||
"""是否启用 A_Memorix"""
|
||||
|
||||
|
||||
class AMemorixStorageConfig(ConfigBase):
|
||||
"""A_Memorix 存储位置"""
|
||||
|
||||
data_dir: str = Field(default="data/a-memorix")
|
||||
"""数据目录"""
|
||||
|
||||
|
||||
class AMemorixEmbeddingFallbackConfig(ConfigBase):
|
||||
"""A_Memorix Embedding 回退"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用回退机制"""
|
||||
|
||||
probe_interval_seconds: int = Field(default=180, ge=10)
|
||||
"""探测间隔秒数"""
|
||||
|
||||
allow_metadata_only_write: bool = Field(default=True)
|
||||
"""是否允许仅写入元数据"""
|
||||
|
||||
|
||||
class AMemorixParagraphVectorBackfillConfig(ConfigBase):
|
||||
"""A_Memorix 段落向量回填"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用回填任务"""
|
||||
|
||||
interval_seconds: int = Field(default=60, ge=5)
|
||||
"""回填轮询间隔"""
|
||||
|
||||
batch_size: int = Field(default=64, ge=1)
|
||||
"""单批回填数量"""
|
||||
|
||||
max_retry: int = Field(default=5, ge=0)
|
||||
"""最大重试次数"""
|
||||
|
||||
|
||||
class AMemorixEmbeddingConfig(ConfigBase):
|
||||
"""A_Memorix Embedding 配置"""
|
||||
|
||||
model_name: str = Field(default="auto")
|
||||
"""Embedding 模型选择"""
|
||||
|
||||
dimension: int = Field(default=1024, ge=1)
|
||||
"""向量维度"""
|
||||
|
||||
batch_size: int = Field(default=32, ge=1)
|
||||
"""单批请求大小"""
|
||||
|
||||
max_concurrent: int = Field(default=5, ge=1)
|
||||
"""最大并发数"""
|
||||
|
||||
enable_cache: bool = Field(default=False)
|
||||
"""是否启用缓存"""
|
||||
|
||||
quantization_type: Literal["int8"] = Field(default="int8")
|
||||
"""量化方式,当前 vNext 仅支持 int8(SQ8)"""
|
||||
|
||||
fallback: AMemorixEmbeddingFallbackConfig = Field(default_factory=AMemorixEmbeddingFallbackConfig)
|
||||
"""Embedding 回退配置"""
|
||||
|
||||
paragraph_vector_backfill: AMemorixParagraphVectorBackfillConfig = Field(
|
||||
default_factory=AMemorixParagraphVectorBackfillConfig
|
||||
)
|
||||
"""段落向量回填配置"""
|
||||
|
||||
|
||||
class AMemorixSparseRetrievalConfig(ConfigBase):
|
||||
"""A_Memorix 稀疏检索配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用稀疏检索"""
|
||||
|
||||
backend: Literal["fts5"] = Field(default="fts5")
|
||||
"""稀疏检索后端"""
|
||||
|
||||
mode: Literal["auto", "fallback_only", "hybrid"] = Field(default="auto")
|
||||
"""稀疏检索模式"""
|
||||
|
||||
tokenizer_mode: Literal["jieba", "mixed", "char_2gram"] = Field(default="jieba")
|
||||
"""分词模式"""
|
||||
|
||||
candidate_k: int = Field(default=80, ge=1)
|
||||
"""段落候选数"""
|
||||
|
||||
relation_candidate_k: int = Field(default=60, ge=1)
|
||||
"""关系候选数"""
|
||||
|
||||
|
||||
class AMemorixRetrievalConfig(ConfigBase):
|
||||
"""A_Memorix 检索配置"""
|
||||
|
||||
top_k_paragraphs: int = Field(default=20, ge=1)
|
||||
"""段落候选数"""
|
||||
|
||||
top_k_relations: int = Field(default=10, ge=1)
|
||||
"""关系候选数"""
|
||||
|
||||
top_k_final: int = Field(default=10, ge=1)
|
||||
"""最终返回条数"""
|
||||
|
||||
alpha: float = Field(default=0.5, ge=0.0, le=1.0)
|
||||
"""关系融合权重"""
|
||||
|
||||
enable_ppr: bool = Field(default=True)
|
||||
"""是否启用 PPR"""
|
||||
|
||||
ppr_alpha: float = Field(default=0.85, ge=0.0, le=1.0)
|
||||
"""PPR alpha"""
|
||||
|
||||
ppr_timeout_seconds: float = Field(default=1.5, ge=0.1)
|
||||
"""PPR 超时秒数"""
|
||||
|
||||
ppr_concurrency_limit: int = Field(default=4, ge=1)
|
||||
"""PPR 并发限制"""
|
||||
|
||||
enable_parallel: bool = Field(default=True)
|
||||
"""是否启用并行检索"""
|
||||
|
||||
sparse: AMemorixSparseRetrievalConfig = Field(default_factory=AMemorixSparseRetrievalConfig)
|
||||
"""稀疏检索配置"""
|
||||
|
||||
|
||||
class AMemorixThresholdConfig(ConfigBase):
|
||||
"""A_Memorix 阈值过滤配置"""
|
||||
|
||||
min_threshold: float = Field(default=0.3, ge=0.0, le=1.0)
|
||||
"""最小阈值"""
|
||||
|
||||
max_threshold: float = Field(default=0.95, ge=0.0, le=1.0)
|
||||
"""最大阈值"""
|
||||
|
||||
percentile: int = Field(default=75, ge=0, le=100)
|
||||
"""动态阈值百分位"""
|
||||
|
||||
min_results: int = Field(default=3, ge=1)
|
||||
"""最小保留条数"""
|
||||
|
||||
enable_auto_adjust: bool = Field(default=True)
|
||||
"""是否启用自动阈值调整"""
|
||||
|
||||
|
||||
class AMemorixFilterConfig(ConfigBase):
|
||||
"""A_Memorix 聊天过滤配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用聊天过滤"""
|
||||
|
||||
mode: Literal["blacklist", "whitelist"] = Field(default="blacklist")
|
||||
"""过滤模式"""
|
||||
|
||||
chats: list[str] = Field(default_factory=lambda: [])
|
||||
"""聊天流列表"""
|
||||
|
||||
|
||||
class AMemorixEpisodeConfig(ConfigBase):
|
||||
"""A_Memorix Episode 配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用 Episode"""
|
||||
|
||||
generation_enabled: bool = Field(default=True)
|
||||
"""是否启用自动生成"""
|
||||
|
||||
pending_batch_size: int = Field(default=20, ge=1)
|
||||
"""待处理批大小"""
|
||||
|
||||
pending_max_retry: int = Field(default=3, ge=0)
|
||||
"""待处理最大重试次数"""
|
||||
|
||||
max_paragraphs_per_call: int = Field(default=20, ge=1)
|
||||
"""单次最大段落数"""
|
||||
|
||||
max_chars_per_call: int = Field(default=6000, ge=100)
|
||||
"""单次最大字符数"""
|
||||
|
||||
source_time_window_hours: float = Field(default=24.0, ge=0.0)
|
||||
"""时间窗口小时数"""
|
||||
|
||||
segmentation_model: str = Field(default="auto")
|
||||
"""分段模型选择"""
|
||||
|
||||
|
||||
class AMemorixPersonProfileConfig(ConfigBase):
|
||||
"""A_Memorix 人物画像配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用画像"""
|
||||
|
||||
refresh_interval_minutes: int = Field(default=30, ge=1)
|
||||
"""刷新间隔分钟数"""
|
||||
|
||||
active_window_hours: float = Field(default=72.0, ge=1.0)
|
||||
"""活跃窗口小时数"""
|
||||
|
||||
max_refresh_per_cycle: int = Field(default=50, ge=1)
|
||||
"""单轮最大刷新数"""
|
||||
|
||||
top_k_evidence: int = Field(default=12, ge=1)
|
||||
"""证据条数"""
|
||||
|
||||
|
||||
class AMemorixMemoryEvolutionConfig(ConfigBase):
|
||||
"""A_Memorix 记忆演化配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用记忆演化"""
|
||||
|
||||
half_life_hours: float = Field(default=24.0, ge=0.1)
|
||||
"""半衰期小时数"""
|
||||
|
||||
prune_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
|
||||
"""裁剪阈值"""
|
||||
|
||||
freeze_duration_hours: float = Field(default=24.0, ge=0.0)
|
||||
"""冻结时长小时数"""
|
||||
|
||||
|
||||
class AMemorixAdvancedConfig(ConfigBase):
|
||||
"""A_Memorix 高级运行时配置"""
|
||||
|
||||
enable_auto_save: bool = Field(default=True)
|
||||
"""是否启用自动保存"""
|
||||
|
||||
auto_save_interval_minutes: int = Field(default=5, ge=1)
|
||||
"""自动保存间隔"""
|
||||
|
||||
debug: bool = Field(default=False)
|
||||
"""是否启用调试"""
|
||||
|
||||
|
||||
class AMemorixWebImportConfig(ConfigBase):
|
||||
"""A_Memorix 导入中心配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用导入中心"""
|
||||
|
||||
max_queue_size: int = Field(default=20, ge=1)
|
||||
"""最大队列长度"""
|
||||
|
||||
max_files_per_task: int = Field(default=200, ge=1)
|
||||
"""单任务最大文件数"""
|
||||
|
||||
max_file_size_mb: int = Field(default=20, ge=1)
|
||||
"""单文件大小上限 MB"""
|
||||
|
||||
max_paste_chars: int = Field(default=200000, ge=100)
|
||||
"""粘贴字符数上限"""
|
||||
|
||||
default_file_concurrency: int = Field(default=2, ge=1)
|
||||
"""默认文件并发"""
|
||||
|
||||
default_chunk_concurrency: int = Field(default=4, ge=1)
|
||||
"""默认分块并发"""
|
||||
|
||||
|
||||
class AMemorixWebTuningConfig(ConfigBase):
|
||||
"""A_Memorix 调优中心配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用调优中心"""
|
||||
|
||||
max_queue_size: int = Field(default=8, ge=1)
|
||||
"""最大队列长度"""
|
||||
|
||||
poll_interval_ms: int = Field(default=1200, ge=200)
|
||||
"""轮询间隔毫秒数"""
|
||||
|
||||
default_intensity: Literal["quick", "standard", "deep"] = Field(default="standard")
|
||||
"""默认调优强度"""
|
||||
|
||||
default_objective: Literal["precision_priority", "balanced", "recall_priority"] = Field(
|
||||
default="precision_priority"
|
||||
)
|
||||
"""默认调优目标"""
|
||||
|
||||
default_top_k_eval: int = Field(default=20, ge=1)
|
||||
"""默认评估 Top-K"""
|
||||
|
||||
default_sample_size: int = Field(default=24, ge=1)
|
||||
"""默认样本数"""
|
||||
|
||||
|
||||
class AMemorixWebConfig(ConfigBase):
|
||||
"""A_Memorix Web 运维配置"""
|
||||
|
||||
import_config: AMemorixWebImportConfig = Field(default_factory=AMemorixWebImportConfig)
|
||||
"""导入中心配置"""
|
||||
|
||||
tuning: AMemorixWebTuningConfig = Field(default_factory=AMemorixWebTuningConfig)
|
||||
"""调优中心配置"""
|
||||
|
||||
|
||||
class AMemorixConfig(ConfigBase):
|
||||
"""A_Memorix 长期记忆子系统配置"""
|
||||
|
||||
__ui_label__ = "长期记忆"
|
||||
__ui_icon__ = "brain"
|
||||
|
||||
plugin: AMemorixPluginConfig = Field(default_factory=AMemorixPluginConfig)
|
||||
"""子系统状态"""
|
||||
|
||||
storage: AMemorixStorageConfig = Field(default_factory=AMemorixStorageConfig)
|
||||
"""存储位置"""
|
||||
|
||||
embedding: AMemorixEmbeddingConfig = Field(default_factory=AMemorixEmbeddingConfig)
|
||||
"""Embedding 配置"""
|
||||
|
||||
retrieval: AMemorixRetrievalConfig = Field(default_factory=AMemorixRetrievalConfig)
|
||||
"""检索配置"""
|
||||
|
||||
threshold: AMemorixThresholdConfig = Field(default_factory=AMemorixThresholdConfig)
|
||||
"""阈值过滤配置"""
|
||||
|
||||
filter: AMemorixFilterConfig = Field(default_factory=AMemorixFilterConfig)
|
||||
"""聊天过滤配置"""
|
||||
|
||||
episode: AMemorixEpisodeConfig = Field(default_factory=AMemorixEpisodeConfig)
|
||||
"""Episode 配置"""
|
||||
|
||||
person_profile: AMemorixPersonProfileConfig = Field(default_factory=AMemorixPersonProfileConfig)
|
||||
"""人物画像配置"""
|
||||
|
||||
memory: AMemorixMemoryEvolutionConfig = Field(default_factory=AMemorixMemoryEvolutionConfig)
|
||||
"""记忆演化配置"""
|
||||
|
||||
advanced: AMemorixAdvancedConfig = Field(default_factory=AMemorixAdvancedConfig)
|
||||
"""高级运行时配置"""
|
||||
|
||||
web: AMemorixWebConfig = Field(default_factory=AMemorixWebConfig)
|
||||
"""Web 运维配置"""
|
||||
|
||||
|
||||
class LearningItem(ConfigBase):
|
||||
platform: str = Field(
|
||||
default="",
|
||||
@@ -769,19 +1136,21 @@ class ExpressionConfig(ConfigBase):
|
||||
"""是否启用自动表达优化"""
|
||||
|
||||
expression_auto_check_interval: int = Field(
|
||||
default=600,
|
||||
default=900,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "clock",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""表达方式自动检查的间隔时间(秒)"""
|
||||
|
||||
expression_auto_check_count: int = Field(
|
||||
default=20,
|
||||
default=5,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hash",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""每次自动检查时随机选取的表达方式数量"""
|
||||
@@ -791,6 +1160,7 @@ class ExpressionConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "file-text",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""表达方式自动检查的额外自定义评估标准"""
|
||||
@@ -832,6 +1202,7 @@ class EmojiConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "grid",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""一次从多少个表情包中选择发送,最大为 64"""
|
||||
@@ -850,6 +1221,7 @@ class EmojiConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "refresh-cw",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""达到最大注册数量时替换旧表情包,关闭则达到最大数量时不会继续收集表情包"""
|
||||
@@ -875,6 +1247,7 @@ class EmojiConfig(ConfigBase):
|
||||
content_filtration: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"advanced": True,
|
||||
"x-widget": "switch",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
@@ -884,6 +1257,7 @@ class EmojiConfig(ConfigBase):
|
||||
filtration_prompt: str = Field(
|
||||
default="符合公序良俗",
|
||||
json_schema_extra={
|
||||
"advanced": True,
|
||||
"x-widget": "input",
|
||||
"x-icon": "shield",
|
||||
},
|
||||
@@ -973,6 +1347,7 @@ class ResponsePostProcessConfig(ConfigBase):
|
||||
|
||||
__ui_label__ = "处理"
|
||||
__ui_icon__ = "settings"
|
||||
__ui_merge_children__ = ["chinese_typo", "response_splitter"]
|
||||
|
||||
enable_response_post_process: bool = Field(
|
||||
default=True,
|
||||
@@ -1006,6 +1381,7 @@ class ChineseTypoConfig(ConfigBase):
|
||||
"x-widget": "slider",
|
||||
"x-icon": "percent",
|
||||
"step": 0.01,
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""单字替换概率"""
|
||||
@@ -1015,6 +1391,7 @@ class ChineseTypoConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "hash",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""最小字频阈值"""
|
||||
@@ -1027,6 +1404,7 @@ class ChineseTypoConfig(ConfigBase):
|
||||
"x-widget": "slider",
|
||||
"x-icon": "percent",
|
||||
"step": 0.1,
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""声调错误概率"""
|
||||
@@ -1039,6 +1417,7 @@ class ChineseTypoConfig(ConfigBase):
|
||||
"x-widget": "slider",
|
||||
"x-icon": "percent",
|
||||
"step": 0.001,
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""整词替换概率"""
|
||||
@@ -1081,6 +1460,7 @@ class ResponseSplitterConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "smile",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否启用颜文字保护"""
|
||||
@@ -1090,6 +1470,7 @@ class ResponseSplitterConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "maximize",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""是否在句子数量超出回复允许的最大句子数时一次性返回全部内容"""
|
||||
@@ -1098,7 +1479,7 @@ class ResponseSplitterConfig(ConfigBase):
|
||||
class LogConfig(ConfigBase):
|
||||
"""日志配置类"""
|
||||
|
||||
__ui_label__ = "日志"
|
||||
__ui_label__ = "调试与日志"
|
||||
__ui_icon__ = "file-text"
|
||||
|
||||
date_style: str = Field(
|
||||
@@ -1226,6 +1607,7 @@ class LogConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "volume-x",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""完全屏蔽日志的第三方库列表"""
|
||||
@@ -1235,6 +1617,7 @@ class LogConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "sliders-horizontal",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""特定第三方库的日志级别"""
|
||||
@@ -1258,6 +1641,7 @@ class TelemetryConfig(ConfigBase):
|
||||
class DebugConfig(ConfigBase):
|
||||
"""调试配置类"""
|
||||
|
||||
__ui_parent__ = "log"
|
||||
__ui_label__ = "其他"
|
||||
__ui_icon__ = "more-horizontal"
|
||||
|
||||
@@ -1752,6 +2136,7 @@ class DatabaseConfig(ConfigBase):
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "save",
|
||||
"advanced": True,
|
||||
},
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -215,6 +215,17 @@ def _is_available_emoji_record(record: Images) -> bool:
|
||||
return record_path.exists() and record_path.is_file()
|
||||
|
||||
|
||||
def _is_vlm_task_configured() -> bool:
|
||||
"""判断是否配置了可用于表情包识别和审核的视觉模型任务。"""
|
||||
|
||||
try:
|
||||
vlm_models = config_manager.get_model_config().model_task_config.vlm.model_list
|
||||
return any(str(model_name).strip() for model_name in vlm_models)
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取 VLM 模型配置失败,跳过表情包识别和审核: {exc}")
|
||||
return False
|
||||
|
||||
|
||||
# TODO: 修改这个vlm为获取的vlm client,暂时使用这个VLM方法
|
||||
emoji_manager_vlm = LLMServiceClient(task_name="vlm", request_type="emoji.see")
|
||||
emoji_manager_emotion_judge_llm = LLMServiceClient(
|
||||
@@ -316,6 +327,10 @@ class EmojiManager:
|
||||
# 如果提供了字节数据但数据库中没有找到,尝试构建
|
||||
if not emoji_bytes:
|
||||
return None
|
||||
if not _is_vlm_task_configured():
|
||||
await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
logger.info("未配置 VLM 模型,跳过表情包识别、打标签和审核")
|
||||
return None
|
||||
if not wait_for_build:
|
||||
await self.ensure_emoji_saved(emoji_bytes, emoji_hash=emoji_hash)
|
||||
self._schedule_description_build(emoji_hash, emoji_bytes)
|
||||
@@ -386,6 +401,10 @@ class EmojiManager:
|
||||
emoji_hash: 表情包哈希值。
|
||||
emoji_bytes: 表情包字节数据。
|
||||
"""
|
||||
if not _is_vlm_task_configured():
|
||||
logger.info("未配置 VLM 模型,跳过表情包后台识别任务")
|
||||
return
|
||||
|
||||
if emoji_hash in self._pending_description_tasks:
|
||||
return
|
||||
|
||||
@@ -826,6 +845,12 @@ class EmojiManager:
|
||||
Returns:
|
||||
return (Tuple[bool, MaiEmoji]): 返回是否成功构建描述,及表情包对象
|
||||
"""
|
||||
if not _is_vlm_task_configured():
|
||||
logger.info(
|
||||
f"[构建描述] 未配置 VLM 模型,跳过表情包识别、打标签和审核: {target_emoji.file_name}"
|
||||
)
|
||||
return False, target_emoji
|
||||
|
||||
if not target_emoji.file_hash or not target_emoji.image_format:
|
||||
# Should not happen, but just in case
|
||||
await target_emoji.calculate_hash_format()
|
||||
|
||||
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("expressor")
|
||||
|
||||
express_learn_model = LLMServiceClient(
|
||||
task_name="replyer", request_type="expression.learner"
|
||||
task_name="learner", request_type="expression.learner"
|
||||
)
|
||||
summary_model = LLMServiceClient(task_name="utils", request_type="expression.summary")
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from .expression_utils import is_single_char_jargon
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
llm_inference = LLMServiceClient(task_name="utils", request_type="jargon.inference")
|
||||
llm_inference = LLMServiceClient(task_name="learner", request_type="jargon.inference")
|
||||
|
||||
|
||||
class JargonEntry(TypedDict):
|
||||
|
||||
@@ -26,12 +26,17 @@ class OpenAICompatibleRequestOverrides:
|
||||
def normalize_openai_base_url(base_url: str) -> str:
|
||||
"""规范化 OpenAI 兼容接口的基础地址。
|
||||
|
||||
去掉尾部斜杠,且如果缺少协议前缀则自动补全 http://。
|
||||
|
||||
Args:
|
||||
base_url: 原始基础地址。
|
||||
|
||||
Returns:
|
||||
str: 去掉尾部斜杠后的地址。
|
||||
str: 规范化后的地址。
|
||||
"""
|
||||
base_url = base_url.strip()
|
||||
if base_url and "://" not in base_url:
|
||||
base_url = "http://" + base_url
|
||||
return base_url.rstrip("/")
|
||||
|
||||
|
||||
|
||||
@@ -111,6 +111,10 @@ class LLMOrchestrator:
|
||||
task_config = getattr(model_task_config, self.task_name, None)
|
||||
if not isinstance(task_config, TaskConfig):
|
||||
raise ValueError(f"未找到名为 '{self.task_name}' 的任务配置")
|
||||
if self.task_name == "learner" and not any(str(model_name).strip() for model_name in task_config.model_list):
|
||||
fallback_task_config = getattr(model_task_config, "utils", None)
|
||||
if isinstance(fallback_task_config, TaskConfig):
|
||||
return fallback_task_config
|
||||
return task_config
|
||||
|
||||
def _refresh_task_config(self) -> TaskConfig:
|
||||
|
||||
@@ -80,6 +80,7 @@ class MainSystem:
|
||||
init_start_time = time.time()
|
||||
|
||||
await config_manager.start_file_watcher()
|
||||
a_memorix_host_service.register_config_reload_callback()
|
||||
|
||||
# 添加在线时间统计任务
|
||||
await async_task_manager.add_task(OnlineTimeRecordTask())
|
||||
|
||||
@@ -91,7 +91,7 @@ def _should_refresh_image_component(component: ImageComponent) -> bool:
|
||||
"""判断图片组件当前是否仍处于待补全文本的占位状态。"""
|
||||
|
||||
normalized_content = component.content.strip()
|
||||
return not normalized_content or normalized_content == "[图片]"
|
||||
return not normalized_content or normalized_content == "[图片,识别中.....]"
|
||||
|
||||
|
||||
def _should_refresh_emoji_component(component: EmojiComponent) -> bool:
|
||||
|
||||
@@ -63,6 +63,7 @@ class ChatResponse:
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_section: Optional[RenderableType] = None
|
||||
prompt_html_uri: Optional[str] = None
|
||||
|
||||
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
@@ -585,8 +586,9 @@ class MaisakaChatLoopService:
|
||||
all_tools = [item for item in raw_tool_definitions if isinstance(item, dict)]
|
||||
|
||||
prompt_section: RenderableType | None = None
|
||||
prompt_html_uri: str | None = None
|
||||
if global_config.debug.show_maisaka_thinking:
|
||||
prompt_section = PromptCLIVisualizer.build_prompt_section(
|
||||
prompt_section_result = PromptCLIVisualizer.build_prompt_section_result(
|
||||
built_messages,
|
||||
category="planner" if request_kind != "timing_gate" else "timing_gate",
|
||||
chat_id=self._session_id,
|
||||
@@ -595,6 +597,9 @@ class MaisakaChatLoopService:
|
||||
folded=global_config.debug.fold_maisaka_thinking,
|
||||
tool_definitions=list(all_tools),
|
||||
)
|
||||
prompt_section = prompt_section_result.panel
|
||||
if prompt_section_result.preview_access is not None:
|
||||
prompt_html_uri = prompt_section_result.preview_access.viewer_web_uri
|
||||
|
||||
llm_chat = self._get_llm_chat_client(request_kind)
|
||||
generation_result = await llm_chat.generate_response_with_messages(
|
||||
@@ -660,6 +665,7 @@ class MaisakaChatLoopService:
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
prompt_section=prompt_section,
|
||||
prompt_html_uri=prompt_html_uri,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -83,7 +83,7 @@ def _append_image_component(
|
||||
builder.add_text_content(normalized_content)
|
||||
return True
|
||||
|
||||
builder.add_text_content("[图片]")
|
||||
builder.add_text_content("[图片,识别中.....]")
|
||||
return True
|
||||
|
||||
|
||||
@@ -147,7 +147,7 @@ def _render_component_for_prompt(component: StandardMessageComponents) -> str:
|
||||
return (component.text or "").strip()
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
return component.content.strip() if component.content else "[图片]"
|
||||
return component.content.strip() if component.content else "[图片,识别中.....]"
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
return component.content.strip() if component.content else "[表情包]"
|
||||
|
||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal
|
||||
from urllib.parse import quote
|
||||
|
||||
import hashlib
|
||||
import html
|
||||
@@ -32,6 +33,36 @@ from .prompt_preview_logger import PromptPreviewLogger
|
||||
DATA_IMAGE_DIR = REPO_ROOT / "data" / "images"
|
||||
|
||||
|
||||
def _build_prompt_preview_web_uri(file_path: Path) -> str:
|
||||
"""构建 WebUI 可访问的 Prompt 预览地址。"""
|
||||
|
||||
try:
|
||||
relative_path = file_path.resolve().relative_to(PromptPreviewLogger._BASE_DIR.resolve())
|
||||
except ValueError:
|
||||
return build_file_uri(file_path)
|
||||
return f"/api/webui/config/maisaka-prompt-preview?path={quote(relative_path.as_posix(), safe='')}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptPreviewAccess:
|
||||
"""Prompt 预览文件的展示入口和可直接打开的路径。"""
|
||||
|
||||
body: RenderableType
|
||||
viewer_path: Path
|
||||
viewer_uri: str
|
||||
viewer_web_uri: str
|
||||
dump_path: Path
|
||||
dump_uri: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptSectionResult:
|
||||
"""Prompt 面板及其可选 HTML 预览入口。"""
|
||||
|
||||
panel: Panel
|
||||
preview_access: PromptPreviewAccess | None = None
|
||||
|
||||
|
||||
class PromptImageDisplayMode(str, Enum):
|
||||
"""图片在终端中的展示模式。"""
|
||||
|
||||
@@ -470,6 +501,77 @@ class PromptCLIVisualizer:
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_prompt_preview_access(
|
||||
cls,
|
||||
messages: list[Any],
|
||||
*,
|
||||
category: str,
|
||||
chat_id: str,
|
||||
request_kind: str,
|
||||
selection_reason: str,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> PromptPreviewAccess:
|
||||
"""保存 Prompt 预览文件,并返回 CLI 展示入口与浏览器可打开的 URI。"""
|
||||
|
||||
viewer_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
if isinstance(message, dict):
|
||||
viewer_messages.append(dict(message))
|
||||
continue
|
||||
|
||||
normalized_message = {
|
||||
"content": getattr(message, "content", None),
|
||||
"role": getattr(getattr(message, "role", "unknown"), "value", getattr(message, "role", "unknown")),
|
||||
}
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
if tool_call_id:
|
||||
normalized_message["tool_call_id"] = tool_call_id
|
||||
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
if tool_calls:
|
||||
normalized_message["tool_calls"] = [
|
||||
cls.format_tool_call_for_display(tool_call) for tool_call in tool_calls
|
||||
]
|
||||
viewer_messages.append(normalized_message)
|
||||
|
||||
prompt_dump_text = cls._build_prompt_dump_text(messages)
|
||||
tool_definition_dump_text = cls._build_tool_definition_dump_text(tool_definitions)
|
||||
if tool_definition_dump_text:
|
||||
prompt_dump_text = f"{prompt_dump_text}\n\n{'=' * 80}\n\n{tool_definition_dump_text}"
|
||||
viewer_html_text = cls._build_prompt_viewer_html(
|
||||
viewer_messages,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
saved_paths = PromptPreviewLogger.save_preview_files(
|
||||
chat_id,
|
||||
category,
|
||||
{
|
||||
".html": viewer_html_text,
|
||||
".txt": prompt_dump_text,
|
||||
},
|
||||
)
|
||||
viewer_html_path = saved_paths[".html"]
|
||||
prompt_dump_path = saved_paths[".txt"]
|
||||
body = cls._build_preview_access_body(
|
||||
viewer_label="html预览",
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_link_text="在浏览器打开 Prompt",
|
||||
dump_label="原始文本",
|
||||
dump_path=prompt_dump_path,
|
||||
dump_link_text="点击打开 Prompt 文本",
|
||||
)
|
||||
return PromptPreviewAccess(
|
||||
body=body,
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_uri=build_file_uri(viewer_html_path),
|
||||
viewer_web_uri=_build_prompt_preview_web_uri(viewer_html_path),
|
||||
dump_path=prompt_dump_path,
|
||||
dump_uri=build_file_uri(prompt_dump_path),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _build_html_role_class(cls, role: str) -> str:
|
||||
return {
|
||||
@@ -804,56 +906,14 @@ class PromptCLIVisualizer:
|
||||
) -> RenderableType:
|
||||
"""构建用于查看完整 prompt 的折叠入口内容。"""
|
||||
|
||||
viewer_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
if isinstance(message, dict):
|
||||
viewer_messages.append(dict(message))
|
||||
continue
|
||||
|
||||
normalized_message = {
|
||||
"content": getattr(message, "content", None),
|
||||
"role": getattr(getattr(message, "role", "unknown"), "value", getattr(message, "role", "unknown")),
|
||||
}
|
||||
tool_call_id = getattr(message, "tool_call_id", None)
|
||||
if tool_call_id:
|
||||
normalized_message["tool_call_id"] = tool_call_id
|
||||
|
||||
tool_calls = getattr(message, "tool_calls", None)
|
||||
if tool_calls:
|
||||
normalized_message["tool_calls"] = [
|
||||
cls.format_tool_call_for_display(tool_call) for tool_call in tool_calls
|
||||
]
|
||||
viewer_messages.append(normalized_message)
|
||||
|
||||
prompt_dump_text = cls._build_prompt_dump_text(messages)
|
||||
tool_definition_dump_text = cls._build_tool_definition_dump_text(tool_definitions)
|
||||
if tool_definition_dump_text:
|
||||
prompt_dump_text = f"{prompt_dump_text}\n\n{'=' * 80}\n\n{tool_definition_dump_text}"
|
||||
viewer_html_text = cls._build_prompt_viewer_html(
|
||||
viewer_messages,
|
||||
return cls.build_prompt_preview_access(
|
||||
messages,
|
||||
category=category,
|
||||
chat_id=chat_id,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
saved_paths = PromptPreviewLogger.save_preview_files(
|
||||
chat_id,
|
||||
category,
|
||||
{
|
||||
".html": viewer_html_text,
|
||||
".txt": prompt_dump_text,
|
||||
},
|
||||
)
|
||||
viewer_html_path = saved_paths[".html"]
|
||||
prompt_dump_path = saved_paths[".txt"]
|
||||
body = cls._build_preview_access_body(
|
||||
viewer_label="html预览",
|
||||
viewer_path=viewer_html_path,
|
||||
viewer_link_text="在浏览器打开 Prompt",
|
||||
dump_label="原始文本",
|
||||
dump_path=prompt_dump_path,
|
||||
dump_link_text="点击打开 Prompt 文本",
|
||||
)
|
||||
return body
|
||||
).body
|
||||
|
||||
@classmethod
|
||||
def build_prompt_section(
|
||||
@@ -870,26 +930,56 @@ class PromptCLIVisualizer:
|
||||
) -> Panel:
|
||||
"""构建用于嵌入结果面板中的 Prompt 区块。"""
|
||||
|
||||
return cls.build_prompt_section_result(
|
||||
messages,
|
||||
category=category,
|
||||
chat_id=chat_id,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
image_display_mode=image_display_mode,
|
||||
folded=folded,
|
||||
tool_definitions=tool_definitions,
|
||||
).panel
|
||||
|
||||
@classmethod
|
||||
def build_prompt_section_result(
|
||||
cls,
|
||||
messages: list[Any],
|
||||
*,
|
||||
category: str,
|
||||
chat_id: str,
|
||||
request_kind: str,
|
||||
selection_reason: str,
|
||||
image_display_mode: Literal["legacy", "path_link"] = "path_link",
|
||||
folded: bool,
|
||||
tool_definitions: list[dict[str, Any]] | None = None,
|
||||
) -> PromptSectionResult:
|
||||
"""构建 Prompt 面板,并在折叠模式下返回对应的 HTML 预览入口。"""
|
||||
|
||||
panel_title, panel_border_style = cls.get_request_panel_style(request_kind)
|
||||
preview_access = cls.build_prompt_preview_access(
|
||||
messages,
|
||||
category=category,
|
||||
chat_id=chat_id,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
if folded:
|
||||
prompt_renderable = cls.build_prompt_access_panel(
|
||||
messages,
|
||||
category=category,
|
||||
chat_id=chat_id,
|
||||
request_kind=request_kind,
|
||||
selection_reason=selection_reason,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
prompt_renderable = preview_access.body
|
||||
else:
|
||||
ordered_panels = cls.build_prompt_panels(messages)
|
||||
prompt_renderable = Group(*ordered_panels)
|
||||
|
||||
return Panel(
|
||||
prompt_renderable,
|
||||
title=panel_title,
|
||||
subtitle=selection_reason,
|
||||
border_style=panel_border_style,
|
||||
padding=(0, 1),
|
||||
return PromptSectionResult(
|
||||
panel=Panel(
|
||||
prompt_renderable,
|
||||
title=panel_title,
|
||||
subtitle=selection_reason,
|
||||
border_style=panel_border_style,
|
||||
padding=(0, 1),
|
||||
),
|
||||
preview_access=preview_access,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -95,7 +95,7 @@ def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str:
|
||||
continue
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
append_visible_part(component.content.strip() or "[图片]")
|
||||
append_visible_part(component.content.strip() or "[图片,识别中.....]")
|
||||
continue
|
||||
|
||||
if isinstance(component, AtComponent):
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
import json
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -57,7 +58,7 @@ def _extract_text_content(content: Any) -> Optional[str]:
|
||||
if block_type == "text":
|
||||
text_parts.append(str(block.get("text", "")))
|
||||
elif block_type == "image_url":
|
||||
text_parts.append("[图片]")
|
||||
text_parts.append("[图片,识别中.....]")
|
||||
else:
|
||||
text_parts.append(f"[{block_type}]")
|
||||
elif isinstance(block, str):
|
||||
@@ -66,43 +67,65 @@ def _extract_text_content(content: Any) -> Optional[str]:
|
||||
return str(content)
|
||||
|
||||
|
||||
def _normalize_tool_call_arguments(arguments: Any) -> tuple[Any, Optional[str]]:
|
||||
"""标准化工具调用参数,兼容 JSON 字符串和对象。"""
|
||||
|
||||
if isinstance(arguments, str):
|
||||
raw_arguments = arguments
|
||||
try:
|
||||
parsed_arguments = json.loads(arguments) if arguments.strip() else {}
|
||||
except json.JSONDecodeError:
|
||||
return {}, raw_arguments
|
||||
return _normalize_payload_value(parsed_arguments), raw_arguments
|
||||
return _normalize_payload_value(arguments or {}), None
|
||||
|
||||
|
||||
def _serialize_single_tool_call(tool_call: Any) -> Dict[str, Any]:
|
||||
"""将不同来源的 tool_call 标准化为前端可直接展示的结构。"""
|
||||
|
||||
if isinstance(tool_call, dict):
|
||||
function_info = tool_call.get("function")
|
||||
if isinstance(function_info, dict):
|
||||
raw_arguments = function_info.get("arguments", tool_call.get("arguments", tool_call.get("args", {})))
|
||||
name = function_info.get("name", tool_call.get("name", tool_call.get("func_name", "unknown")))
|
||||
else:
|
||||
raw_arguments = tool_call.get("arguments", tool_call.get("args", {}))
|
||||
name = tool_call.get("name", tool_call.get("func_name", "unknown"))
|
||||
|
||||
arguments, arguments_raw = _normalize_tool_call_arguments(raw_arguments)
|
||||
serialized: Dict[str, Any] = {
|
||||
"id": str(tool_call.get("id", tool_call.get("call_id", ""))),
|
||||
"name": str(name or "unknown"),
|
||||
"arguments": arguments,
|
||||
}
|
||||
if arguments_raw is not None:
|
||||
serialized["arguments_raw"] = arguments_raw
|
||||
return serialized
|
||||
|
||||
raw_arguments = getattr(tool_call, "args", None)
|
||||
if raw_arguments is None:
|
||||
raw_arguments = getattr(tool_call, "arguments", None)
|
||||
arguments, arguments_raw = _normalize_tool_call_arguments(raw_arguments)
|
||||
serialized = {
|
||||
"id": str(getattr(tool_call, "id", None) or getattr(tool_call, "call_id", "")),
|
||||
"name": str(getattr(tool_call, "func_name", None) or getattr(tool_call, "name", "unknown")),
|
||||
"arguments": arguments,
|
||||
}
|
||||
if arguments_raw is not None:
|
||||
serialized["arguments_raw"] = arguments_raw
|
||||
return serialized
|
||||
|
||||
|
||||
def _serialize_tool_calls_from_objects(tool_calls: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""将工具调用对象列表序列化为字典列表。"""
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
serialized: Dict[str, Any] = {
|
||||
"id": getattr(tool_call, "id", None) or getattr(tool_call, "call_id", ""),
|
||||
"name": getattr(tool_call, "func_name", None) or getattr(tool_call, "name", "unknown"),
|
||||
}
|
||||
args = getattr(tool_call, "args", None) or getattr(tool_call, "arguments", None)
|
||||
if isinstance(args, dict):
|
||||
serialized["arguments"] = _normalize_payload_value(args)
|
||||
elif isinstance(args, str):
|
||||
serialized["arguments_raw"] = args
|
||||
result.append(serialized)
|
||||
return result
|
||||
return [_serialize_single_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
def _serialize_tool_calls_from_dicts(tool_calls: List[Any]) -> List[Dict[str, Any]]:
|
||||
"""将工具调用字典列表标准化为可传输格式。"""
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
result.append({
|
||||
"id": str(tool_call.get("id", "")),
|
||||
"name": str(tool_call.get("name", tool_call.get("func_name", "unknown"))),
|
||||
"arguments": _normalize_payload_value(tool_call.get("arguments", tool_call.get("args", {}))),
|
||||
})
|
||||
continue
|
||||
|
||||
result.append({
|
||||
"id": str(getattr(tool_call, "id", getattr(tool_call, "call_id", ""))),
|
||||
"name": str(getattr(tool_call, "func_name", getattr(tool_call, "name", "unknown"))),
|
||||
"arguments": _normalize_payload_value(getattr(tool_call, "args", getattr(tool_call, "arguments", {}))),
|
||||
})
|
||||
return result
|
||||
return [_serialize_single_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
def _serialize_message(message: Any) -> Dict[str, Any]:
|
||||
@@ -143,6 +166,33 @@ def _serialize_messages(messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
return [_serialize_message(message) for message in messages]
|
||||
|
||||
|
||||
def _enrich_session_identity(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""为监控事件补充会话展示所需的群/用户标识。"""
|
||||
|
||||
session_id = data.get("session_id")
|
||||
if not session_id:
|
||||
return data
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
chat_stream = chat_manager.get_session_by_session_id(str(session_id))
|
||||
except Exception:
|
||||
return data
|
||||
|
||||
if chat_stream is None:
|
||||
return data
|
||||
|
||||
session_name = chat_manager.get_session_name(str(session_id))
|
||||
if session_name:
|
||||
data.setdefault("session_name", session_name)
|
||||
data.setdefault("is_group_chat", chat_stream.is_group_session)
|
||||
data.setdefault("group_id", chat_stream.group_id)
|
||||
data.setdefault("user_id", chat_stream.user_id)
|
||||
data.setdefault("platform", chat_stream.platform)
|
||||
return data
|
||||
|
||||
|
||||
def _serialize_tool_results(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""标准化最终 planner 卡中的工具结果列表。"""
|
||||
|
||||
@@ -187,6 +237,7 @@ def _serialize_planner_block(
|
||||
completion_tokens: Optional[int],
|
||||
total_tokens: Optional[int],
|
||||
duration_ms: Optional[float],
|
||||
prompt_html_uri: Optional[str] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""标准化 planner 结果区块。"""
|
||||
|
||||
@@ -197,6 +248,7 @@ def _serialize_planner_block(
|
||||
and completion_tokens is None
|
||||
and total_tokens is None
|
||||
and duration_ms is None
|
||||
and prompt_html_uri is None
|
||||
):
|
||||
return None
|
||||
|
||||
@@ -207,6 +259,7 @@ def _serialize_planner_block(
|
||||
"completion_tokens": int(completion_tokens or 0),
|
||||
"total_tokens": int(total_tokens or 0),
|
||||
"duration_ms": float(duration_ms or 0.0),
|
||||
"prompt_html_uri": str(prompt_html_uri or ""),
|
||||
}
|
||||
|
||||
|
||||
@@ -266,6 +319,7 @@ async def _broadcast(event: str, data: Dict[str, Any]) -> None:
|
||||
try:
|
||||
from src.webui.routers.websocket.manager import websocket_manager
|
||||
|
||||
data = _enrich_session_identity(data)
|
||||
subscription_key = f"{MONITOR_DOMAIN}:{MONITOR_TOPIC}"
|
||||
total_connections = len(websocket_manager.connections)
|
||||
subscriber_count = sum(
|
||||
@@ -291,12 +345,24 @@ async def _broadcast(event: str, data: Dict[str, Any]) -> None:
|
||||
logger.warning(f"MaiSaka 监控事件广播失败: {exc}", exc_info=True)
|
||||
|
||||
|
||||
async def emit_session_start(session_id: str, session_name: str) -> None:
|
||||
async def emit_session_start(
|
||||
session_id: str,
|
||||
session_name: str,
|
||||
*,
|
||||
is_group_chat: bool,
|
||||
group_id: Optional[str],
|
||||
user_id: Optional[str],
|
||||
platform: str,
|
||||
) -> None:
|
||||
"""广播会话开始事件。"""
|
||||
|
||||
await _broadcast("session.start", {
|
||||
"session_id": session_id,
|
||||
"session_name": session_name,
|
||||
"is_group_chat": is_group_chat,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"platform": platform,
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
@@ -389,6 +455,7 @@ async def emit_planner_finalized(
|
||||
planner_completion_tokens: Optional[int],
|
||||
planner_total_tokens: Optional[int],
|
||||
planner_duration_ms: Optional[float],
|
||||
planner_prompt_html_uri: Optional[str],
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
time_records: Dict[str, float],
|
||||
agent_state: str,
|
||||
@@ -424,6 +491,7 @@ async def emit_planner_finalized(
|
||||
planner_completion_tokens,
|
||||
planner_total_tokens,
|
||||
planner_duration_ms,
|
||||
planner_prompt_html_uri,
|
||||
),
|
||||
"tools": _serialize_tool_results(list(tools or [])),
|
||||
"final_state": {
|
||||
|
||||
@@ -709,6 +709,7 @@ class MaisakaReasoningEngine:
|
||||
),
|
||||
planner_total_tokens=response.total_tokens if response is not None else None,
|
||||
planner_duration_ms=planner_duration_ms if response is not None else None,
|
||||
planner_prompt_html_uri=response.prompt_html_uri if response is not None else None,
|
||||
tools=tool_monitor_results,
|
||||
time_records=dict(completed_cycle.time_records),
|
||||
agent_state=self._runtime._agent_state,
|
||||
|
||||
@@ -46,6 +46,7 @@ from .display.display_utils import build_tool_call_summary_lines, format_token_c
|
||||
from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.stage_status_board import remove_stage_status, update_stage_status
|
||||
from .history_utils import drop_leading_orphan_tool_results
|
||||
from .monitor_events import emit_session_start
|
||||
from .reasoning_engine import MaisakaReasoningEngine
|
||||
from .reply_effect import ReplyEffectTracker
|
||||
from .reply_effect.image_utils import extract_visual_attachments_from_sequence
|
||||
@@ -126,7 +127,7 @@ class MaisakaHeartFlowChatting:
|
||||
int(global_config.chat.planner_interrupt_max_consecutive_count),
|
||||
)
|
||||
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||
expr_use, expr_learn, jargon_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||
self._enable_expression_use = expr_use
|
||||
self._enable_expression_learning = expr_learn
|
||||
self._enable_jargon_learning = jargon_learn
|
||||
@@ -136,6 +137,7 @@ class MaisakaHeartFlowChatting:
|
||||
self._jargon_miner = JargonMiner(session_id, session_name=session_name)
|
||||
|
||||
self._reasoning_engine = MaisakaReasoningEngine(self)
|
||||
self._monitor_session_start_task: Optional[asyncio.Task[None]] = None
|
||||
self._tool_registry = ToolRegistry()
|
||||
self._reply_effect_tracker = ReplyEffectTracker(
|
||||
session_id=self.session_id,
|
||||
@@ -144,6 +146,24 @@ class MaisakaHeartFlowChatting:
|
||||
judge_runner=self._run_reply_effect_judge,
|
||||
)
|
||||
self._register_tool_providers()
|
||||
self._emit_monitor_session_start()
|
||||
|
||||
def _emit_monitor_session_start(self) -> None:
|
||||
"""向 WebUI 监控面板同步当前会话的展示标识。"""
|
||||
|
||||
try:
|
||||
self._monitor_session_start_task = asyncio.create_task(
|
||||
emit_session_start(
|
||||
session_id=self.session_id,
|
||||
session_name=self.session_name,
|
||||
is_group_chat=self.chat_stream.is_group_session,
|
||||
group_id=self.chat_stream.group_id,
|
||||
user_id=self.chat_stream.user_id,
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
)
|
||||
except RuntimeError:
|
||||
logger.debug("MaiSaka 监控会话开始事件未发送:当前没有运行中的事件循环")
|
||||
|
||||
@staticmethod
|
||||
def _is_reply_effect_tracking_enabled() -> bool:
|
||||
|
||||
@@ -216,6 +216,20 @@ class PluginRunnerSupervisor:
|
||||
"""
|
||||
return {plugin_id: registration.plugin_version for plugin_id, registration in self._registered_plugins.items()}
|
||||
|
||||
def get_plugin_load_statuses(self) -> Dict[str, str]:
|
||||
"""返回 Runner 最近一次上报的插件加载状态。"""
|
||||
|
||||
statuses: Dict[str, str] = {}
|
||||
for plugin_id in self._runner_ready_payloads.loaded_plugins:
|
||||
statuses[plugin_id] = "success"
|
||||
for plugin_id in self._runner_ready_payloads.failed_plugins:
|
||||
statuses[plugin_id] = "failed"
|
||||
for plugin_id in self._runner_ready_payloads.inactive_plugins:
|
||||
statuses.setdefault(plugin_id, "inactive")
|
||||
for plugin_id in self._registered_plugins:
|
||||
statuses[plugin_id] = "success"
|
||||
return statuses
|
||||
|
||||
def set_blocked_plugin_reasons(self, blocked_plugin_reasons: Dict[str, str]) -> None:
|
||||
"""设置当前 Runner 启动时应拒绝加载的插件列表。
|
||||
|
||||
|
||||
@@ -657,6 +657,14 @@ class PluginRuntimeManager(
|
||||
plugin_id: supervisor for supervisor in self.supervisors for plugin_id in supervisor.get_loaded_plugin_ids()
|
||||
}
|
||||
|
||||
def get_plugin_load_statuses(self) -> Dict[str, str]:
|
||||
"""汇总所有 Supervisor 上报的插件加载状态。"""
|
||||
|
||||
statuses: Dict[str, str] = {}
|
||||
for supervisor in self.supervisors:
|
||||
statuses.update(supervisor.get_plugin_load_statuses())
|
||||
return statuses
|
||||
|
||||
def _build_external_available_plugins_for_supervisor(self, target_supervisor: "PluginSupervisor") -> Dict[str, str]:
|
||||
"""收集某个 Supervisor 可用的外部插件版本映射。"""
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("plugin_runtime.runner.manifest_validator")
|
||||
|
||||
_SEMVER_PATTERN = re.compile(r"^\d+\.\d+\.\d+$")
|
||||
_PLUGIN_ID_PATTERN = re.compile(r"^[a-z0-9]+(?:[.-][a-z0-9]+)+$")
|
||||
_PLUGIN_ID_PATTERN = re.compile(r"^[A-Za-z0-9_]+(?:[.-][A-Za-z0-9_]+)+$")
|
||||
_PACKAGE_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
_HTTP_URL_PATTERN = re.compile(r"^https?://.+$")
|
||||
|
||||
@@ -379,7 +379,7 @@ class PluginDependencyDefinition(_StrictManifestModel):
|
||||
ValueError: 当 ID 不符合规则时抛出。
|
||||
"""
|
||||
if not _PLUGIN_ID_PATTERN.fullmatch(value):
|
||||
raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin")
|
||||
raise ValueError("必须使用字母/数字/下划线,并以点号或横线分隔,例如 github.author.plugin")
|
||||
return value
|
||||
|
||||
@field_validator("version_spec")
|
||||
@@ -548,7 +548,7 @@ class PluginManifest(_StrictManifestModel):
|
||||
if not value:
|
||||
raise ValueError("不能为空")
|
||||
if info.field_name == "id" and not _PLUGIN_ID_PATTERN.fullmatch(value):
|
||||
raise ValueError("必须使用小写字母/数字,并以点号或横线分隔,例如 github.author.plugin")
|
||||
raise ValueError("必须使用字母/数字/下划线,并以点号或横线分隔,例如 github.author.plugin")
|
||||
return value
|
||||
|
||||
@field_validator("capabilities")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
|
||||
|
||||
from importlib import import_module
|
||||
from os import getenv
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
@@ -16,6 +17,7 @@ from src.common.logger import get_logger
|
||||
logger = get_logger("webui.app")
|
||||
|
||||
_DASHBOARD_PACKAGE_NAME = "maibot-dashboard"
|
||||
_LOCAL_DASHBOARD_ENV = "MAIBOT_WEBUI_USE_LOCAL_DASHBOARD"
|
||||
_MANUAL_INSTALL_COMMAND = f"pip install {_DASHBOARD_PACKAGE_NAME}"
|
||||
|
||||
|
||||
@@ -36,6 +38,10 @@ def _get_project_root() -> Path:
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
def _is_local_dashboard_enabled() -> bool:
|
||||
return getenv(_LOCAL_DASHBOARD_ENV, "").strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _validate_static_path(static_path: Path | None) -> Tuple[str, Dict[str, Any]] | None:
|
||||
if static_path is None:
|
||||
return "startup.webui_static_dir_missing", {}
|
||||
@@ -179,6 +185,16 @@ def _setup_static_files(app: FastAPI):
|
||||
logger.warning(t("startup.webui_dashboard_package_hint", command=_MANUAL_INSTALL_COMMAND))
|
||||
return
|
||||
|
||||
@app.get("/maibot_statistics.html", include_in_schema=False)
|
||||
async def serve_statistics_report():
|
||||
report_path = (_get_project_root() / "maibot_statistics.html").resolve()
|
||||
if not report_path.exists() or not report_path.is_file():
|
||||
raise HTTPException(status_code=404, detail=t("core.not_found"))
|
||||
|
||||
response = FileResponse(report_path, media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
@app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_spa(full_path: str):
|
||||
if not full_path or full_path == "/":
|
||||
@@ -205,12 +221,10 @@ def _setup_static_files(app: FastAPI):
|
||||
|
||||
|
||||
def _resolve_static_path() -> Path | None:
|
||||
# 临时仅允许使用已安装的 maibot-dashboard 包,不使用仓库本地 dashboard/dist。
|
||||
# 如需恢复本地回退逻辑,可取消下方注释。
|
||||
# base_dir = _get_project_root()
|
||||
# static_path = base_dir / "dashboard" / "dist"
|
||||
# if static_path.is_dir() and (static_path / "index.html").exists():
|
||||
# return static_path
|
||||
if _is_local_dashboard_enabled():
|
||||
static_path = _get_project_root() / "dashboard" / "dist"
|
||||
if static_path.is_dir() and (static_path / "index.html").exists():
|
||||
return static_path
|
||||
|
||||
try:
|
||||
module = import_module("maibot_dashboard")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any, Dict, List, get_args, get_origin
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, get_args, get_origin
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
@@ -40,12 +39,15 @@ class ConfigSchemaGenerator:
|
||||
ui_parent = getattr(config_class, "__ui_parent__", "")
|
||||
ui_label = getattr(config_class, "__ui_label__", "")
|
||||
ui_icon = getattr(config_class, "__ui_icon__", "")
|
||||
ui_merge_children = getattr(config_class, "__ui_merge_children__", [])
|
||||
if ui_parent:
|
||||
schema["uiParent"] = ui_parent
|
||||
if ui_label:
|
||||
schema["uiLabel"] = ui_label
|
||||
if ui_icon:
|
||||
schema["uiIcon"] = ui_icon
|
||||
if ui_merge_children:
|
||||
schema["uiMergeChildren"] = list(ui_merge_children)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@@ -8,7 +8,9 @@ from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, List, Tuple
|
||||
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR, PROJECT_ROOT, Config, ModelConfig
|
||||
@@ -19,6 +21,7 @@ from src.config.model_configs import (
|
||||
ModelTaskConfig,
|
||||
)
|
||||
from src.config.official_configs import (
|
||||
AMemorixConfig,
|
||||
BotConfig,
|
||||
ChatConfig,
|
||||
ChineseTypoConfig,
|
||||
@@ -46,9 +49,76 @@ ConfigBody = Annotated[Dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[Dict[str, str], Body()]
|
||||
PromptContentBody = Annotated[str, Body(embed=True)]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"], dependencies=[Depends(require_auth)])
|
||||
|
||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||
MAISAKA_PROMPT_PREVIEW_DIR = (PROJECT_ROOT / "logs" / "maisaka_prompt").resolve()
|
||||
|
||||
|
||||
class PromptFileInfo(BaseModel):
|
||||
"""Prompt 文件信息。"""
|
||||
|
||||
name: str = Field(..., description="Prompt 文件名")
|
||||
size: int = Field(..., description="文件大小")
|
||||
modified_at: float = Field(..., description="最后修改时间戳")
|
||||
|
||||
|
||||
class PromptCatalogResponse(BaseModel):
|
||||
"""Prompt 目录响应。"""
|
||||
|
||||
success: bool = True
|
||||
languages: List[str]
|
||||
files: Dict[str, List[PromptFileInfo]]
|
||||
|
||||
|
||||
class PromptFileResponse(BaseModel):
|
||||
"""Prompt 文件内容响应。"""
|
||||
|
||||
success: bool = True
|
||||
language: str
|
||||
filename: str
|
||||
content: str
|
||||
|
||||
|
||||
def _safe_prompt_path(language: str, filename: str) -> Path:
|
||||
"""校验并解析 prompts 下的文件路径。"""
|
||||
|
||||
normalized_language = language.strip()
|
||||
normalized_filename = filename.strip()
|
||||
|
||||
if not normalized_language or any(part in normalized_language for part in ("..", "/", "\\")):
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 语言目录")
|
||||
if not normalized_filename.endswith(".prompt") or any(part in normalized_filename for part in ("..", "/", "\\")):
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 文件名")
|
||||
|
||||
prompt_path = (PROMPTS_DIR / normalized_language / normalized_filename).resolve()
|
||||
prompts_root = PROMPTS_DIR.resolve()
|
||||
try:
|
||||
prompt_path.relative_to(prompts_root)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Prompt 路径越界") from exc
|
||||
return prompt_path
|
||||
|
||||
|
||||
def _safe_maisaka_prompt_preview_path(relative_path: str) -> Path:
|
||||
"""校验并解析 MaiSaka Prompt HTML 预览路径。"""
|
||||
|
||||
normalized_path = relative_path.strip().replace("\\", "/")
|
||||
if not normalized_path or normalized_path.startswith("/") or ".." in Path(normalized_path).parts:
|
||||
raise HTTPException(status_code=400, detail="无效的 Prompt 预览路径")
|
||||
|
||||
preview_path = (MAISAKA_PROMPT_PREVIEW_DIR / normalized_path).resolve()
|
||||
try:
|
||||
preview_path.relative_to(MAISAKA_PROMPT_PREVIEW_DIR)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail="Prompt 预览路径越界") from exc
|
||||
|
||||
if preview_path.suffix.lower() != ".html":
|
||||
raise HTTPException(status_code=400, detail="只允许打开 HTML Prompt 预览")
|
||||
return preview_path
|
||||
|
||||
|
||||
def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
"""递归转换 tomlkit 文档/Table 为纯 Python 字典,避免 from_dict 触发 tomlkit __setitem__"""
|
||||
@@ -62,6 +132,87 @@ def _toml_to_plain_dict(obj: Any) -> Any:
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/prompts", response_model=PromptCatalogResponse)
|
||||
async def list_prompt_files():
|
||||
"""列出 prompts 目录下的语言和 Prompt 文件。"""
|
||||
|
||||
try:
|
||||
if not PROMPTS_DIR.exists():
|
||||
return PromptCatalogResponse(languages=[], files={})
|
||||
|
||||
languages: List[str] = []
|
||||
files: Dict[str, List[PromptFileInfo]] = {}
|
||||
for language_dir in sorted(PROMPTS_DIR.iterdir(), key=lambda item: item.name):
|
||||
if not language_dir.is_dir():
|
||||
continue
|
||||
|
||||
language = language_dir.name
|
||||
prompt_files: List[PromptFileInfo] = []
|
||||
for prompt_file in sorted(language_dir.glob("*.prompt"), key=lambda item: item.name):
|
||||
stat = prompt_file.stat()
|
||||
prompt_files.append(
|
||||
PromptFileInfo(
|
||||
name=prompt_file.name,
|
||||
size=stat.st_size,
|
||||
modified_at=stat.st_mtime,
|
||||
)
|
||||
)
|
||||
|
||||
languages.append(language)
|
||||
files[language] = prompt_files
|
||||
|
||||
return PromptCatalogResponse(languages=languages, files=files)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"列出 Prompt 文件失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"列出 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/prompts/{language}/{filename}", response_model=PromptFileResponse)
|
||||
async def get_prompt_file(language: str, filename: str):
|
||||
"""读取指定语言下的 Prompt 文件内容。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
content = prompt_path.read_text(encoding="utf-8")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"读取 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"读取 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/prompts/{language}/{filename}", response_model=PromptFileResponse)
|
||||
async def update_prompt_file(language: str, filename: str, content: PromptContentBody):
|
||||
"""更新指定语言下的 Prompt 文件内容。"""
|
||||
|
||||
prompt_path = _safe_prompt_path(language, filename)
|
||||
if not prompt_path.parent.exists() or not prompt_path.parent.is_dir():
|
||||
raise HTTPException(status_code=404, detail="Prompt 语言目录不存在")
|
||||
if not prompt_path.exists() or not prompt_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 文件不存在")
|
||||
|
||||
try:
|
||||
prompt_path.write_text(content, encoding="utf-8", newline="\n")
|
||||
return PromptFileResponse(language=language, filename=filename, content=content)
|
||||
except Exception as e:
|
||||
logger.error(f"保存 Prompt 文件失败: {prompt_path} {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"保存 Prompt 文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/maisaka-prompt-preview", response_class=FileResponse)
|
||||
async def get_maisaka_prompt_preview(path: str = Query(..., description="logs/maisaka_prompt 下的相对 HTML 路径")):
|
||||
"""打开 MaiSaka 监控中生成的 Prompt HTML 预览。"""
|
||||
|
||||
preview_path = _safe_maisaka_prompt_preview_path(path)
|
||||
if not preview_path.exists() or not preview_path.is_file():
|
||||
raise HTTPException(status_code=404, detail="Prompt 预览文件不存在")
|
||||
return FileResponse(preview_path, media_type="text/html")
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema():
|
||||
"""获取麦麦主程序配置架构"""
|
||||
@@ -128,6 +279,7 @@ async def get_config_section_schema(section_name: str):
|
||||
"telemetry": TelemetryConfig,
|
||||
"maim_message": MaimMessageConfig,
|
||||
"memory": MemoryConfig,
|
||||
"a_memorix": AMemorixConfig,
|
||||
"debug": DebugConfig,
|
||||
"voice": VoiceConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi import APIRouter, Body, Depends, File, Form, HTTPException, Query,
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.A_memorix.host_service import a_memorix_host_service
|
||||
from src.person_info.person_info import resolve_person_id_for_memory
|
||||
from src.services.memory_service import MemorySearchResult, memory_service
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
@@ -336,10 +337,25 @@ async def _episode_process_pending(payload: EpisodeProcessPendingRequest) -> dic
|
||||
)
|
||||
|
||||
|
||||
async def _profile_query(*, person_id: str, person_keyword: str, limit: int, force_refresh: bool) -> dict:
|
||||
async def _profile_query(
|
||||
*,
|
||||
person_id: str,
|
||||
person_keyword: str,
|
||||
platform: str,
|
||||
user_id: str,
|
||||
limit: int,
|
||||
force_refresh: bool,
|
||||
) -> dict:
|
||||
clean_person_id = str(person_id or "").strip()
|
||||
if not clean_person_id and str(platform or "").strip() and str(user_id or "").strip():
|
||||
clean_person_id = resolve_person_id_for_memory(
|
||||
platform=str(platform or "").strip(),
|
||||
user_id=str(user_id or "").strip(),
|
||||
strict_known=False,
|
||||
)
|
||||
return await memory_service.profile_admin(
|
||||
action="query",
|
||||
person_id=person_id,
|
||||
person_id=clean_person_id,
|
||||
person_keyword=person_keyword,
|
||||
limit=limit,
|
||||
force_refresh=force_refresh,
|
||||
@@ -834,12 +850,16 @@ async def process_memory_episode_pending(payload: EpisodeProcessPendingRequest):
|
||||
async def query_memory_profile(
|
||||
person_id: str = Query(""),
|
||||
person_keyword: str = Query(""),
|
||||
platform: str = Query(""),
|
||||
user_id: str = Query(""),
|
||||
limit: int = Query(12, ge=1, le=100),
|
||||
force_refresh: bool = Query(False),
|
||||
):
|
||||
return await _profile_query(
|
||||
person_id=person_id,
|
||||
person_keyword=person_keyword,
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
@@ -1306,12 +1326,16 @@ async def compat_process_episode_pending(payload: EpisodeProcessPendingRequest):
|
||||
async def compat_profile_query(
|
||||
person_id: str = Query(""),
|
||||
person_keyword: str = Query(""),
|
||||
platform: str = Query(""),
|
||||
user_id: str = Query(""),
|
||||
limit: int = Query(12, ge=1, le=100),
|
||||
force_refresh: bool = Query(False),
|
||||
):
|
||||
return await _profile_query(
|
||||
person_id=person_id,
|
||||
person_keyword=person_keyword,
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException
|
||||
import json
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.services.git_mirror_service import get_git_mirror_service
|
||||
@@ -12,6 +13,7 @@ from .schemas import InstallPluginRequest, UninstallPluginRequest, UpdatePluginR
|
||||
from .support import (
|
||||
find_plugin_path_by_id,
|
||||
get_plugin_candidate_paths,
|
||||
get_plugin_config_path,
|
||||
iter_plugin_directories,
|
||||
load_manifest_json,
|
||||
parse_repository_url,
|
||||
@@ -64,6 +66,39 @@ def _infer_plugin_id(folder_name: str, manifest: Dict[str, Any], manifest_path:
|
||||
return plugin_id
|
||||
|
||||
|
||||
def _coerce_enabled_value(value: Any) -> bool:
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() not in {"false", "0", "no", "off", "disabled"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _read_plugin_enabled(plugin_id: str, plugin_path: Path) -> bool:
|
||||
try:
|
||||
config_path = get_plugin_config_path(plugin_id, plugin_path)
|
||||
if not config_path.exists():
|
||||
return True
|
||||
with open(config_path, "r", encoding="utf-8") as file_obj:
|
||||
config = tomlkit.load(file_obj).unwrap()
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取插件 {plugin_id} 启用状态失败,将按启用处理: {exc}")
|
||||
return True
|
||||
|
||||
plugin_config = config.get("plugin") if isinstance(config, dict) else None
|
||||
if not isinstance(plugin_config, dict):
|
||||
return True
|
||||
return _coerce_enabled_value(plugin_config.get("enabled", True))
|
||||
|
||||
|
||||
def _get_runtime_plugin_load_statuses() -> Dict[str, str]:
|
||||
try:
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager().get_plugin_load_statuses()
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取插件运行时加载状态失败: {exc}")
|
||||
return {}
|
||||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(request: InstallPluginRequest, maibot_session: Optional[str] = Cookie(None)) -> Dict[str, Any]:
|
||||
require_plugin_token(maibot_session)
|
||||
@@ -401,6 +436,7 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
||||
|
||||
try:
|
||||
installed_plugins: List[Dict[str, Any]] = []
|
||||
runtime_statuses = _get_runtime_plugin_load_statuses()
|
||||
for plugin_path in iter_plugin_directories():
|
||||
folder_name = plugin_path.name
|
||||
if folder_name.startswith(".") or folder_name.startswith("__"):
|
||||
@@ -420,7 +456,19 @@ async def get_installed_plugins(maibot_session: Optional[str] = Cookie(None)) ->
|
||||
logger.warning(f"插件文件夹 {folder_name} 的 _manifest.json 格式无效,跳过")
|
||||
continue
|
||||
plugin_id = _infer_plugin_id(folder_name, manifest, manifest_path)
|
||||
installed_plugins.append({"id": plugin_id, "manifest": manifest, "path": str(plugin_path.absolute())})
|
||||
enabled = _read_plugin_enabled(plugin_id, plugin_path)
|
||||
load_status = runtime_statuses.get(plugin_id, "unknown")
|
||||
installed_plugins.append(
|
||||
{
|
||||
"id": plugin_id,
|
||||
"manifest": manifest,
|
||||
"path": str(plugin_path.absolute()),
|
||||
"enabled": enabled,
|
||||
"disabled": not enabled,
|
||||
"loaded": load_status == "success",
|
||||
"load_status": "disabled" if not enabled else load_status,
|
||||
}
|
||||
)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"插件 {folder_name} 的 _manifest.json 解析失败: {e}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -106,7 +106,7 @@ def validate_plugin_id(plugin_id: str) -> str:
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> Tuple[int, int, int]:
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|pre|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
parts.extend(["0"] * (3 - len(parts)))
|
||||
|
||||
@@ -31,19 +31,19 @@ def _is_forbidden_ip_address(address: ipaddress.IPv4Address | ipaddress.IPv6Addr
|
||||
address.is_loopback,
|
||||
address.is_link_local,
|
||||
address.is_multicast,
|
||||
address.is_private,
|
||||
address.is_reserved,
|
||||
address.is_unspecified,
|
||||
getattr(address, "is_site_local", False),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("https",)) -> str:
|
||||
def validate_public_url(url: str, allowed_schemes: Iterable[str] = ("http", "https")) -> str:
|
||||
normalized_url = url.strip()
|
||||
if not normalized_url:
|
||||
raise ValueError("URL 不能为空")
|
||||
|
||||
if "://" not in normalized_url:
|
||||
normalized_url = "http://" + normalized_url
|
||||
parsed = urlparse(normalized_url)
|
||||
allowed_scheme_set = {scheme.lower() for scheme in allowed_schemes}
|
||||
if parsed.scheme.lower() not in allowed_scheme_set:
|
||||
|
||||
Reference in New Issue
Block a user