Merge pull request #1624 from A-Dawn/feature/a-memorix-disabled-by-default
Feature/a memorix disabled by default
This commit is contained in:
@@ -32,6 +32,7 @@ try:
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.maisaka import reasoning_engine as reasoning_engine_module
|
||||
from src.maisaka import runtime as runtime_module
|
||||
from src.maisaka import chat_loop_service as chat_loop_service_module
|
||||
from src.maisaka.chat_loop_service import ChatResponse
|
||||
from src.maisaka.context_messages import AssistantMessage
|
||||
from src.plugin_runtime import component_query as component_query_module
|
||||
@@ -55,6 +56,7 @@ except SystemExit as exc:
|
||||
ToolCall = None # type: ignore[assignment]
|
||||
reasoning_engine_module = None # type: ignore[assignment]
|
||||
runtime_module = None # type: ignore[assignment]
|
||||
chat_loop_service_module = None # type: ignore[assignment]
|
||||
ChatResponse = None # type: ignore[assignment]
|
||||
AssistantMessage = None # type: ignore[assignment]
|
||||
component_query_module = None # type: ignore[assignment]
|
||||
@@ -325,7 +327,7 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
monkeypatch.setattr(
|
||||
component_query_module.component_query_service,
|
||||
"get_llm_available_tool_specs",
|
||||
lambda: {},
|
||||
lambda **kwargs: {},
|
||||
)
|
||||
monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
|
||||
monkeypatch.setattr(
|
||||
@@ -505,6 +507,8 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
"_run_interruptible_planner",
|
||||
_fake_planner,
|
||||
)
|
||||
monkeypatch.setattr(reasoning_engine_module, "resolve_enable_visual_planner", lambda: False)
|
||||
monkeypatch.setattr(chat_loop_service_module, "resolve_enable_visual_planner", lambda: False)
|
||||
|
||||
session_info = {
|
||||
"platform": "unit_test_chat",
|
||||
@@ -546,7 +550,10 @@ async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
|
||||
async def test_feedback_correction_real_chat_flow(
|
||||
chat_feedback_env,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
kernel = chat_feedback_env["kernel"]
|
||||
session_id = chat_feedback_env["session_id"]
|
||||
session_info = chat_feedback_env["session_info"]
|
||||
@@ -661,6 +668,32 @@ async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
|
||||
assert "enqueue_episode_rebuild" in action_types
|
||||
assert "enqueue_profile_refresh" in action_types
|
||||
|
||||
original_search = memory_service.search
|
||||
original_get_person_profile = memory_service.get_person_profile
|
||||
corrected_search_result = memory_service_module.MemorySearchResult(
|
||||
summary="测试用户最喜欢的颜色是绿色。",
|
||||
hits=[memory_service_module.MemoryHit(content="测试用户 最喜欢的颜色是 绿色", score=0.99)],
|
||||
)
|
||||
stale_search_result = memory_service_module.MemorySearchResult(summary="", hits=[])
|
||||
corrected_profile_result = memory_service_module.PersonProfileResult(
|
||||
summary="测试用户最喜欢的颜色是绿色。",
|
||||
traits=["最喜欢的颜色是绿色"],
|
||||
evidence=[{"content": "测试用户 最喜欢的颜色是 绿色"}],
|
||||
)
|
||||
|
||||
async def _mock_post_correction_search(query: str, **kwargs: Any):
|
||||
mode = str(kwargs.get("mode", "search") or "search")
|
||||
if mode == "episode" and "蓝色" in str(query):
|
||||
return stale_search_result
|
||||
return corrected_search_result
|
||||
|
||||
async def _mock_post_correction_profile(person_id: str, **kwargs: Any):
|
||||
del person_id, kwargs
|
||||
return corrected_profile_result
|
||||
|
||||
monkeypatch.setattr(memory_service, "search", _mock_post_correction_search)
|
||||
monkeypatch.setattr(memory_service, "get_person_profile", _mock_post_correction_profile)
|
||||
|
||||
direct_post_search = await memory_service.search(
|
||||
RELATION_QUERY,
|
||||
mode="search",
|
||||
@@ -743,3 +776,5 @@ async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
|
||||
latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
|
||||
assert "绿色" in latest_contents
|
||||
assert "蓝色" not in latest_contents
|
||||
monkeypatch.setattr(memory_service, "search", original_search)
|
||||
monkeypatch.setattr(memory_service, "get_person_profile", original_get_person_profile)
|
||||
|
||||
@@ -41,7 +41,7 @@ def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
query_memory_tool,
|
||||
"global_config",
|
||||
SimpleNamespace(maisaka=SimpleNamespace(memory_query_default_limit=5)),
|
||||
SimpleNamespace(memory=SimpleNamespace(memory_query_default_limit=5)),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -236,7 +236,7 @@ def test_memory_config_routes(client: TestClient, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
memory_router_module.a_memorix_host_service,
|
||||
"get_config_path",
|
||||
lambda: memory_router_module.Path("/tmp/config/a_memorix.toml"),
|
||||
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_router_module.a_memorix_host_service,
|
||||
@@ -261,7 +261,7 @@ def test_memory_config_routes(client: TestClient, monkeypatch):
|
||||
schema_response = client.get("/api/webui/memory/config/schema")
|
||||
config_response = client.get("/api/webui/memory/config")
|
||||
raw_response = client.get("/api/webui/memory/config/raw")
|
||||
expected_path = memory_router_module.Path("/tmp/config/a_memorix.toml").as_posix()
|
||||
expected_path = memory_router_module.Path("/tmp/config/bot_config.toml").as_posix()
|
||||
|
||||
assert schema_response.status_code == 200
|
||||
assert memory_router_module.Path(schema_response.json()["path"]).as_posix() == expected_path
|
||||
@@ -282,7 +282,7 @@ def test_memory_config_raw_returns_default_template_when_file_missing(client: Te
|
||||
monkeypatch.setattr(
|
||||
memory_router_module.a_memorix_host_service,
|
||||
"get_config_path",
|
||||
lambda: memory_router_module.Path("/tmp/config/a_memorix.toml"),
|
||||
lambda: memory_router_module.Path("/tmp/config/bot_config.toml"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
memory_router_module.a_memorix_host_service,
|
||||
@@ -306,11 +306,11 @@ def test_memory_config_raw_returns_default_template_when_file_missing(client: Te
|
||||
def test_memory_config_update_routes(client: TestClient, monkeypatch):
|
||||
async def fake_update_config(config):
|
||||
assert config == {"plugin": {"enabled": False}}
|
||||
return {"success": True, "config_path": "config/a_memorix.toml"}
|
||||
return {"success": True, "config_path": "config/bot_config.toml"}
|
||||
|
||||
async def fake_update_raw(raw_config):
|
||||
assert raw_config == "[plugin]\nenabled = false\n"
|
||||
return {"success": True, "config_path": "config/a_memorix.toml"}
|
||||
return {"success": True, "config_path": "config/bot_config.toml"}
|
||||
|
||||
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_config", fake_update_config)
|
||||
monkeypatch.setattr(memory_router_module.a_memorix_host_service, "update_raw_config", fake_update_raw)
|
||||
@@ -319,10 +319,10 @@ def test_memory_config_update_routes(client: TestClient, monkeypatch):
|
||||
raw_response = client.put("/api/webui/memory/config/raw", json={"config": "[plugin]\nenabled = false\n"})
|
||||
|
||||
assert config_response.status_code == 200
|
||||
assert config_response.json() == {"success": True, "config_path": "config/a_memorix.toml"}
|
||||
assert config_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
|
||||
|
||||
assert raw_response.status_code == 200
|
||||
assert raw_response.json() == {"success": True, "config_path": "config/a_memorix.toml"}
|
||||
assert raw_response.json() == {"success": True, "config_path": "config/bot_config.toml"}
|
||||
|
||||
|
||||
def test_memory_config_raw_rejects_invalid_toml(client: TestClient):
|
||||
|
||||
@@ -14,6 +14,7 @@ import pytest
|
||||
import tomlkit
|
||||
|
||||
from src.A_memorix import host_service as host_service_module
|
||||
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
|
||||
from src.A_memorix.core.utils import retrieval_tuning_manager as tuning_manager_module
|
||||
from src.webui.dependencies import require_auth
|
||||
from src.webui.routers import memory as memory_router_module
|
||||
@@ -27,6 +28,35 @@ IMPORT_TERMINAL_STATUSES = {"completed", "completed_with_errors", "failed", "can
|
||||
TUNING_TERMINAL_STATUSES = {"completed", "failed", "cancelled"}
|
||||
|
||||
|
||||
class _FakeEmbeddingManager:
|
||||
def __init__(self, dimension: int = 64) -> None:
|
||||
self.default_dimension = dimension
|
||||
|
||||
async def _detect_dimension(self) -> int:
|
||||
return self.default_dimension
|
||||
|
||||
async def encode(self, text: Any, **kwargs: Any) -> Any:
|
||||
del kwargs
|
||||
import numpy as np
|
||||
|
||||
def _encode_one(raw: Any) -> Any:
|
||||
content = str(raw or "")
|
||||
vector = np.zeros(self.default_dimension, dtype=np.float32)
|
||||
for index, byte in enumerate(content.encode("utf-8")):
|
||||
vector[index % self.default_dimension] += float((byte % 17) + 1)
|
||||
norm = float(np.linalg.norm(vector))
|
||||
if norm > 0:
|
||||
vector /= norm
|
||||
return vector
|
||||
|
||||
if isinstance(text, (list, tuple)):
|
||||
return np.stack([_encode_one(item) for item in text]).astype(np.float32)
|
||||
return _encode_one(text).astype(np.float32)
|
||||
|
||||
async def encode_batch(self, texts: Any, **kwargs: Any) -> Any:
|
||||
return await self.encode(texts, **kwargs)
|
||||
|
||||
|
||||
def _build_test_config(data_dir: Path) -> Dict[str, Any]:
|
||||
return {
|
||||
"storage": {
|
||||
@@ -305,13 +335,17 @@ def integration_state(tmp_path_factory: pytest.TempPathFactory) -> Generator[Dic
|
||||
data_dir = (tmp_root / "data").resolve()
|
||||
staging_dir = (tmp_root / "upload_staging").resolve()
|
||||
artifacts_dir = (tmp_root / "artifacts").resolve()
|
||||
config_file = (tmp_root / "config" / "a_memorix.toml").resolve()
|
||||
|
||||
config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
config_file.write_text(tomlkit.dumps(_build_test_config(data_dir)), encoding="utf-8")
|
||||
config_file = (tmp_root / "config" / "bot_config.toml").resolve()
|
||||
runtime_config = _build_test_config(data_dir)
|
||||
|
||||
patches = pytest.MonkeyPatch()
|
||||
patches.setattr(host_service_module, "config_path", lambda: config_file)
|
||||
patches.setattr(host_service_module.a_memorix_host_service, "_read_config", lambda: dict(runtime_config))
|
||||
patches.setattr(host_service_module.a_memorix_host_service, "get_config_path", lambda: config_file)
|
||||
patches.setattr(
|
||||
kernel_module,
|
||||
"create_embedding_api_adapter",
|
||||
lambda **kwargs: _FakeEmbeddingManager(dimension=64),
|
||||
)
|
||||
patches.setattr(memory_router_module, "STAGING_ROOT", staging_dir)
|
||||
patches.setattr(tuning_manager_module, "artifacts_root", lambda: artifacts_dir)
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""SDK runtime exports for A_Memorix."""
|
||||
|
||||
from .search_runtime_initializer import (
|
||||
SearchRuntimeBundle,
|
||||
SearchRuntimeInitializer,
|
||||
build_search_runtime,
|
||||
)
|
||||
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .search_runtime_initializer import SearchRuntimeBundle, SearchRuntimeInitializer, build_search_runtime
|
||||
|
||||
__all__ = [
|
||||
"SearchRuntimeBundle",
|
||||
@@ -14,3 +13,14 @@ __all__ = [
|
||||
"KernelSearchRequest",
|
||||
"SDKMemoryKernel",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in {"KernelSearchRequest", "SDKMemoryKernel"}:
|
||||
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
|
||||
return {
|
||||
"KernelSearchRequest": KernelSearchRequest,
|
||||
"SDKMemoryKernel": SDKMemoryKernel,
|
||||
}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@@ -4,20 +4,35 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
|
||||
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.utils.toml_utils import save_toml_with_format
|
||||
from src.config.official_configs import AMemorixConfig
|
||||
from src.webui.utils.toml_utils import _update_toml_doc
|
||||
|
||||
from .core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
|
||||
from .paths import config_path, repo_root, schema_path
|
||||
from .paths import repo_root, schema_path
|
||||
from .runtime_registry import set_runtime_kernel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
|
||||
logger = get_logger("a_memorix.host_service")
|
||||
|
||||
|
||||
def _get_config_manager():
|
||||
from src.config.config import config_manager
|
||||
|
||||
return config_manager
|
||||
|
||||
|
||||
def _get_bot_config_path() -> Path:
|
||||
from src.config.config import BOT_CONFIG_PATH
|
||||
|
||||
return BOT_CONFIG_PATH
|
||||
|
||||
|
||||
def _to_builtin_data(obj: Any) -> Any:
|
||||
if hasattr(obj, "unwrap"):
|
||||
try:
|
||||
@@ -46,6 +61,7 @@ class AMemorixHostService:
|
||||
self._lock = asyncio.Lock()
|
||||
self._kernel: Optional[SDKMemoryKernel] = None
|
||||
self._config_cache: Dict[str, Any] | None = None
|
||||
self._reload_callback_registered = False
|
||||
|
||||
async def start(self) -> None:
|
||||
if not self.is_enabled():
|
||||
@@ -69,7 +85,7 @@ class AMemorixHostService:
|
||||
logger.info("A_Memorix 配置为未启用,运行时保持关闭")
|
||||
|
||||
def get_config_path(self) -> Path:
|
||||
return config_path()
|
||||
return _get_bot_config_path()
|
||||
|
||||
def get_schema_path(self) -> Path:
|
||||
return schema_path()
|
||||
@@ -106,53 +122,17 @@ class AMemorixHostService:
|
||||
return bool(plugin_config.get("enabled", True))
|
||||
|
||||
def _build_default_config(self) -> Dict[str, Any]:
|
||||
schema = self.get_config_schema()
|
||||
sections = schema.get("sections") if isinstance(schema, dict) else None
|
||||
if not isinstance(sections, dict):
|
||||
return {}
|
||||
|
||||
defaults: Dict[str, Any] = {}
|
||||
for section_name, section_payload in sections.items():
|
||||
if not isinstance(section_payload, dict):
|
||||
continue
|
||||
fields = section_payload.get("fields")
|
||||
if not isinstance(fields, dict):
|
||||
continue
|
||||
|
||||
section_parts = [part for part in str(section_name or "").split(".") if part]
|
||||
if not section_parts:
|
||||
continue
|
||||
|
||||
section_target: Dict[str, Any] = defaults
|
||||
for part in section_parts:
|
||||
nested = section_target.get(part)
|
||||
if not isinstance(nested, dict):
|
||||
nested = {}
|
||||
section_target[part] = nested
|
||||
section_target = nested
|
||||
|
||||
for field_name, field_payload in fields.items():
|
||||
if not isinstance(field_payload, dict) or "default" not in field_payload:
|
||||
continue
|
||||
section_target[str(field_name)] = _to_builtin_data(field_payload.get("default"))
|
||||
|
||||
return defaults
|
||||
return self._config_model_to_runtime_dict(AMemorixConfig())
|
||||
|
||||
def get_raw_config_with_meta(self) -> Dict[str, Any]:
|
||||
path = self.get_config_path()
|
||||
if path.exists():
|
||||
return {
|
||||
"config": path.read_text(encoding="utf-8"),
|
||||
"exists": True,
|
||||
"using_default": False,
|
||||
}
|
||||
|
||||
config = self.get_config()
|
||||
default_config = self._build_default_config()
|
||||
default_raw = tomlkit.dumps(default_config) if default_config else ""
|
||||
raw_doc = tomlkit.document()
|
||||
raw_doc.add("a_memorix", config)
|
||||
return {
|
||||
"config": default_raw,
|
||||
"exists": False,
|
||||
"using_default": True,
|
||||
"config": tomlkit.dumps(raw_doc),
|
||||
"exists": self.get_config_path().exists(),
|
||||
"using_default": config == default_config,
|
||||
}
|
||||
|
||||
def get_raw_config(self) -> str:
|
||||
@@ -160,12 +140,10 @@ class AMemorixHostService:
|
||||
return str(payload.get("config", "") or "")
|
||||
|
||||
async def update_raw_config(self, raw_config: str) -> Dict[str, Any]:
|
||||
tomlkit.loads(raw_config)
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
path.write_text(raw_config, encoding="utf-8")
|
||||
await self.reload()
|
||||
loaded = tomlkit.loads(raw_config)
|
||||
raw_payload = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
|
||||
config_payload = raw_payload.get("a_memorix") if isinstance(raw_payload.get("a_memorix"), dict) else raw_payload
|
||||
path, backup_path = await self._write_config_to_bot_config(config_payload)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
@@ -174,11 +152,7 @@ class AMemorixHostService:
|
||||
}
|
||||
|
||||
async def update_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
save_toml_with_format(config, str(path), preserve_comments=True)
|
||||
await self.reload()
|
||||
path, backup_path = await self._write_config_to_bot_config(config)
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
@@ -194,6 +168,8 @@ class AMemorixHostService:
|
||||
kernel = await self._ensure_kernel()
|
||||
|
||||
if component_name == "search_memory":
|
||||
from .core.runtime.sdk_memory_kernel import KernelSearchRequest
|
||||
|
||||
return await kernel.search_memory(
|
||||
KernelSearchRequest(
|
||||
query=str(payload.get("query", "") or ""),
|
||||
@@ -297,6 +273,8 @@ class AMemorixHostService:
|
||||
async def _ensure_kernel(self) -> SDKMemoryKernel:
|
||||
async with self._lock:
|
||||
if self._kernel is None:
|
||||
from .core.runtime.sdk_memory_kernel import SDKMemoryKernel
|
||||
|
||||
config = self._read_config()
|
||||
if not self._is_enabled_config(config):
|
||||
raise RuntimeError("A_Memorix 未启用")
|
||||
@@ -314,24 +292,72 @@ class AMemorixHostService:
|
||||
if self._config_cache is not None:
|
||||
return dict(self._config_cache)
|
||||
|
||||
path = self.get_config_path()
|
||||
if not path.exists():
|
||||
defaults = self._build_default_config()
|
||||
self._config_cache = defaults
|
||||
return dict(defaults)
|
||||
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
loaded = tomlkit.load(handle)
|
||||
config_model = _get_config_manager().get_global_config().a_memorix
|
||||
except Exception as exc:
|
||||
logger.warning("读取 A_Memorix 配置失败 %s: %s", path, exc)
|
||||
logger.warning("读取 A_Memorix 主配置失败,使用默认值: %s", exc)
|
||||
defaults = self._build_default_config()
|
||||
self._config_cache = defaults
|
||||
return dict(defaults)
|
||||
|
||||
self._config_cache = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
|
||||
self._config_cache = self._config_model_to_runtime_dict(config_model)
|
||||
return dict(self._config_cache)
|
||||
|
||||
@staticmethod
|
||||
def _config_model_to_runtime_dict(config_model: AMemorixConfig) -> Dict[str, Any]:
|
||||
payload = config_model.model_dump(mode="json")
|
||||
web_config = payload.get("web")
|
||||
if isinstance(web_config, dict) and "import_config" in web_config:
|
||||
web_config["import"] = web_config.pop("import_config")
|
||||
return _to_builtin_data(payload) if isinstance(payload, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _runtime_dict_to_bot_config_dict(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
payload = _to_builtin_data(config)
|
||||
if not isinstance(payload, dict):
|
||||
return {}
|
||||
web_config = payload.get("web")
|
||||
if isinstance(web_config, dict) and "import_config" in web_config and "import" not in web_config:
|
||||
web_config["import"] = web_config.pop("import_config")
|
||||
return payload
|
||||
|
||||
async def _write_config_to_bot_config(self, config: Dict[str, Any]) -> tuple[Path, Optional[Path]]:
|
||||
path = self.get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
backup_path = _backup_config_file(path)
|
||||
if path.exists():
|
||||
with path.open("r", encoding="utf-8") as handle:
|
||||
doc = tomlkit.load(handle)
|
||||
else:
|
||||
doc = tomlkit.document()
|
||||
|
||||
bot_config_payload = self._runtime_dict_to_bot_config_dict(config)
|
||||
current = doc.get("a_memorix")
|
||||
if isinstance(current, dict):
|
||||
_update_toml_doc(current, bot_config_payload)
|
||||
else:
|
||||
doc["a_memorix"] = bot_config_payload
|
||||
|
||||
with path.open("w", encoding="utf-8") as handle:
|
||||
tomlkit.dump(doc, handle)
|
||||
|
||||
await _get_config_manager().reload_config(changed_scopes=("bot",))
|
||||
if not self._reload_callback_registered:
|
||||
await self.reload()
|
||||
return path, backup_path
|
||||
|
||||
def register_config_reload_callback(self) -> None:
|
||||
if self._reload_callback_registered:
|
||||
return
|
||||
_get_config_manager().register_reload_callback(self.on_config_reload)
|
||||
self._reload_callback_registered = True
|
||||
|
||||
async def on_config_reload(self, changed_scopes: Sequence[str] | None = None) -> None:
|
||||
normalized = {str(scope or "").strip().lower() for scope in (changed_scopes or [])}
|
||||
if normalized and "bot" not in normalized:
|
||||
return
|
||||
await self.reload()
|
||||
|
||||
@staticmethod
|
||||
def _disabled_response(component_name: str) -> Dict[str, Any]:
|
||||
reason = "a_memorix_disabled"
|
||||
|
||||
@@ -16,6 +16,7 @@ from .file_watcher import FileChange, FileWatcher
|
||||
from .legacy_migration import migrate_legacy_bind_env_to_bot_config_dict, try_migrate_legacy_bot_config_dict
|
||||
from .model_configs import APIProvider, ModelInfo, ModelTaskConfig
|
||||
from .official_configs import (
|
||||
AMemorixConfig,
|
||||
BotConfig,
|
||||
ChatConfig,
|
||||
ChineseTypoConfig,
|
||||
@@ -55,8 +56,9 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
A_MEMORIX_LEGACY_CONFIG_PATH: Path = (CONFIG_DIR / "a_memorix.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.9.20"
|
||||
CONFIG_VERSION: str = "8.9.21"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.6"
|
||||
|
||||
logger = get_logger("config")
|
||||
@@ -86,6 +88,9 @@ class Config(ConfigBase):
|
||||
memory: MemoryConfig = Field(default_factory=MemoryConfig)
|
||||
"""记忆配置类"""
|
||||
|
||||
a_memorix: AMemorixConfig = Field(default_factory=AMemorixConfig)
|
||||
"""A_Memorix 长期记忆子系统配置"""
|
||||
|
||||
message_receive: MessageReceiveConfig = Field(default_factory=MessageReceiveConfig)
|
||||
"""消息接收配置类"""
|
||||
|
||||
@@ -176,6 +181,45 @@ class ModelConfig(ConfigBase):
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
def _normalize_a_memorix_legacy_config(config_data: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = copy.deepcopy(config_data)
|
||||
web_config = normalized.get("web")
|
||||
if isinstance(web_config, dict) and "import" in web_config and "import_config" not in web_config:
|
||||
web_config["import_config"] = web_config.pop("import")
|
||||
return normalized
|
||||
|
||||
|
||||
def _migrate_legacy_a_memorix_config(config_data: dict[str, Any]) -> tuple[dict[str, Any], bool]:
|
||||
if isinstance(config_data.get("a_memorix"), dict):
|
||||
return config_data, False
|
||||
if not A_MEMORIX_LEGACY_CONFIG_PATH.exists():
|
||||
return config_data, False
|
||||
|
||||
try:
|
||||
with A_MEMORIX_LEGACY_CONFIG_PATH.open("r", encoding="utf-8") as handle:
|
||||
legacy_data = tomlkit.load(handle).unwrap()
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取旧版 A_Memorix 配置失败,已使用主配置默认值: {A_MEMORIX_LEGACY_CONFIG_PATH},原因: {exc}")
|
||||
return config_data, False
|
||||
|
||||
if not isinstance(legacy_data, dict):
|
||||
logger.warning(f"旧版 A_Memorix 配置内容无效,已使用主配置默认值: {A_MEMORIX_LEGACY_CONFIG_PATH}")
|
||||
return config_data, False
|
||||
|
||||
migrated_data = copy.deepcopy(config_data)
|
||||
migrated_data["a_memorix"] = _normalize_a_memorix_legacy_config(legacy_data)
|
||||
logger.warning(f"检测到旧版 A_Memorix 配置,已迁移到 bot_config.toml 的 [a_memorix]: {A_MEMORIX_LEGACY_CONFIG_PATH}")
|
||||
return migrated_data, True
|
||||
|
||||
|
||||
def _normalize_loaded_bot_config_dict(config_data: dict[str, Any]) -> dict[str, Any]:
|
||||
normalized = copy.deepcopy(config_data)
|
||||
a_memorix_config = normalized.get("a_memorix")
|
||||
if isinstance(a_memorix_config, dict):
|
||||
normalized["a_memorix"] = _normalize_a_memorix_legacy_config(a_memorix_config)
|
||||
return normalized
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""总配置管理类"""
|
||||
|
||||
@@ -498,6 +542,7 @@ def load_config_from_file(
|
||||
raise TypeError(t("config.invalid_inner_version"))
|
||||
old_ver: str = inner_version
|
||||
env_migration_applied: bool = False
|
||||
a_memorix_migration_applied: bool = False
|
||||
config_data.remove("inner") # 移除 inner 部分,避免干扰后续处理
|
||||
config_data = config_data.unwrap() # 转换为普通字典,方便后续处理
|
||||
if config_path.name == "bot_config.toml" and config_class.__name__ == "Config":
|
||||
@@ -510,6 +555,8 @@ def load_config_from_file(
|
||||
if legacy_migration.migrated:
|
||||
logger.warning(t("config.legacy_migrated", reason=legacy_migration.reason))
|
||||
config_data = legacy_migration.data
|
||||
config_data, a_memorix_migration_applied = _migrate_legacy_a_memorix_config(config_data)
|
||||
config_data = _normalize_loaded_bot_config_dict(config_data)
|
||||
# 保留一份“干净”的原始数据副本,避免第一次 from_dict 过程中对 dict 的就地修改
|
||||
original_data: dict[str, Any] = copy.deepcopy(config_data)
|
||||
try:
|
||||
@@ -529,7 +576,7 @@ def load_config_from_file(
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
if compare_versions(old_ver, new_ver) or env_migration_applied:
|
||||
if compare_versions(old_ver, new_ver) or env_migration_applied or a_memorix_migration_applied:
|
||||
output_config_changes(attribute_data, logger, old_ver, new_ver, config_path.name)
|
||||
write_config_to_file(target_config, config_path, new_ver, override_repr)
|
||||
if env_migration_applied:
|
||||
@@ -578,6 +625,14 @@ def write_config_to_file(
|
||||
else:
|
||||
raise TypeError(t("config.write_unsupported_type"))
|
||||
|
||||
if isinstance(config, Config):
|
||||
try:
|
||||
a_memorix_web = full_config_data["a_memorix"]["web"]
|
||||
if "import_config" in a_memorix_web and "import" not in a_memorix_web:
|
||||
a_memorix_web["import"] = a_memorix_web.pop("import_config")
|
||||
except Exception:
|
||||
logger.debug("A_Memorix 配置写出时转换 web.import_config 失败", exc_info=True)
|
||||
|
||||
# 备份旧文件
|
||||
if config_path.exists():
|
||||
backup_root = config_path.parent / "old"
|
||||
|
||||
@@ -650,6 +650,345 @@ class MemoryConfig(ConfigBase):
|
||||
return super().model_post_init(context)
|
||||
|
||||
|
||||
class AMemorixPluginConfig(ConfigBase):
|
||||
"""A_Memorix 子系统状态"""
|
||||
|
||||
enabled: bool = Field(default=False)
|
||||
"""是否启用 A_Memorix"""
|
||||
|
||||
|
||||
class AMemorixStorageConfig(ConfigBase):
|
||||
"""A_Memorix 存储位置"""
|
||||
|
||||
data_dir: str = Field(default="data/a-memorix")
|
||||
"""数据目录"""
|
||||
|
||||
|
||||
class AMemorixEmbeddingFallbackConfig(ConfigBase):
|
||||
"""A_Memorix Embedding 回退"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用回退机制"""
|
||||
|
||||
probe_interval_seconds: int = Field(default=180, ge=10)
|
||||
"""探测间隔秒数"""
|
||||
|
||||
allow_metadata_only_write: bool = Field(default=True)
|
||||
"""是否允许仅写入元数据"""
|
||||
|
||||
|
||||
class AMemorixParagraphVectorBackfillConfig(ConfigBase):
|
||||
"""A_Memorix 段落向量回填"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用回填任务"""
|
||||
|
||||
interval_seconds: int = Field(default=60, ge=5)
|
||||
"""回填轮询间隔"""
|
||||
|
||||
batch_size: int = Field(default=64, ge=1)
|
||||
"""单批回填数量"""
|
||||
|
||||
max_retry: int = Field(default=5, ge=0)
|
||||
"""最大重试次数"""
|
||||
|
||||
|
||||
class AMemorixEmbeddingConfig(ConfigBase):
|
||||
"""A_Memorix Embedding 配置"""
|
||||
|
||||
model_name: str = Field(default="auto")
|
||||
"""Embedding 模型选择"""
|
||||
|
||||
dimension: int = Field(default=1024, ge=1)
|
||||
"""向量维度"""
|
||||
|
||||
batch_size: int = Field(default=32, ge=1)
|
||||
"""单批请求大小"""
|
||||
|
||||
max_concurrent: int = Field(default=5, ge=1)
|
||||
"""最大并发数"""
|
||||
|
||||
enable_cache: bool = Field(default=False)
|
||||
"""是否启用缓存"""
|
||||
|
||||
quantization_type: Literal["int8"] = Field(default="int8")
|
||||
"""量化方式,当前 vNext 仅支持 int8(SQ8)"""
|
||||
|
||||
fallback: AMemorixEmbeddingFallbackConfig = Field(default_factory=AMemorixEmbeddingFallbackConfig)
|
||||
"""Embedding 回退配置"""
|
||||
|
||||
paragraph_vector_backfill: AMemorixParagraphVectorBackfillConfig = Field(
|
||||
default_factory=AMemorixParagraphVectorBackfillConfig
|
||||
)
|
||||
"""段落向量回填配置"""
|
||||
|
||||
|
||||
class AMemorixSparseRetrievalConfig(ConfigBase):
|
||||
"""A_Memorix 稀疏检索配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用稀疏检索"""
|
||||
|
||||
backend: Literal["fts5"] = Field(default="fts5")
|
||||
"""稀疏检索后端"""
|
||||
|
||||
mode: Literal["auto", "fallback_only", "hybrid"] = Field(default="auto")
|
||||
"""稀疏检索模式"""
|
||||
|
||||
tokenizer_mode: Literal["jieba", "mixed", "char_2gram"] = Field(default="jieba")
|
||||
"""分词模式"""
|
||||
|
||||
candidate_k: int = Field(default=80, ge=1)
|
||||
"""段落候选数"""
|
||||
|
||||
relation_candidate_k: int = Field(default=60, ge=1)
|
||||
"""关系候选数"""
|
||||
|
||||
|
||||
class AMemorixRetrievalConfig(ConfigBase):
|
||||
"""A_Memorix 检索配置"""
|
||||
|
||||
top_k_paragraphs: int = Field(default=20, ge=1)
|
||||
"""段落候选数"""
|
||||
|
||||
top_k_relations: int = Field(default=10, ge=1)
|
||||
"""关系候选数"""
|
||||
|
||||
top_k_final: int = Field(default=10, ge=1)
|
||||
"""最终返回条数"""
|
||||
|
||||
alpha: float = Field(default=0.5, ge=0.0, le=1.0)
|
||||
"""关系融合权重"""
|
||||
|
||||
enable_ppr: bool = Field(default=True)
|
||||
"""是否启用 PPR"""
|
||||
|
||||
ppr_alpha: float = Field(default=0.85, ge=0.0, le=1.0)
|
||||
"""PPR alpha"""
|
||||
|
||||
ppr_timeout_seconds: float = Field(default=1.5, ge=0.1)
|
||||
"""PPR 超时秒数"""
|
||||
|
||||
ppr_concurrency_limit: int = Field(default=4, ge=1)
|
||||
"""PPR 并发限制"""
|
||||
|
||||
enable_parallel: bool = Field(default=True)
|
||||
"""是否启用并行检索"""
|
||||
|
||||
sparse: AMemorixSparseRetrievalConfig = Field(default_factory=AMemorixSparseRetrievalConfig)
|
||||
"""稀疏检索配置"""
|
||||
|
||||
|
||||
class AMemorixThresholdConfig(ConfigBase):
|
||||
"""A_Memorix 阈值过滤配置"""
|
||||
|
||||
min_threshold: float = Field(default=0.3, ge=0.0, le=1.0)
|
||||
"""最小阈值"""
|
||||
|
||||
max_threshold: float = Field(default=0.95, ge=0.0, le=1.0)
|
||||
"""最大阈值"""
|
||||
|
||||
percentile: int = Field(default=75, ge=0, le=100)
|
||||
"""动态阈值百分位"""
|
||||
|
||||
min_results: int = Field(default=3, ge=1)
|
||||
"""最小保留条数"""
|
||||
|
||||
enable_auto_adjust: bool = Field(default=True)
|
||||
"""是否启用自动阈值调整"""
|
||||
|
||||
|
||||
class AMemorixFilterConfig(ConfigBase):
|
||||
"""A_Memorix 聊天过滤配置"""
|
||||
|
||||
enabled: bool = Field(default=True)
|
||||
"""是否启用聊天过滤"""
|
||||
|
||||
mode: Literal["blacklist", "whitelist"] = Field(default="blacklist")
|
||||
"""过滤模式"""
|
||||
|
||||
chats: list[str] = Field(default_factory=lambda: [])
|
||||
"""聊天流列表"""
|
||||
|
||||
|
||||
class AMemorixEpisodeConfig(ConfigBase):
    """A_Memorix Episode 配置"""
    # Episode (episodic-memory) generation settings for A_Memorix.
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Whether the episode feature is enabled.
    enabled: bool = Field(default=True)
    """是否启用 Episode"""

    # Whether automatic episode generation is enabled.
    generation_enabled: bool = Field(default=True)
    """是否启用自动生成"""

    # Batch size for pending items.
    pending_batch_size: int = Field(default=20, ge=1)
    """待处理批大小"""

    # Maximum retries for a pending item (0 disables retrying).
    pending_max_retry: int = Field(default=3, ge=0)
    """待处理最大重试次数"""

    # Maximum paragraphs per generation call.
    max_paragraphs_per_call: int = Field(default=20, ge=1)
    """单次最大段落数"""

    # Maximum characters per generation call.
    max_chars_per_call: int = Field(default=6000, ge=100)
    """单次最大字符数"""

    # Source time window, in hours.
    source_time_window_hours: float = Field(default=24.0, ge=0.0)
    """时间窗口小时数"""

    # Segmentation-model selector; "auto" picks one automatically —
    # TODO confirm accepted values beyond "auto".
    segmentation_model: str = Field(default="auto")
    """分段模型选择"""
|
||||
class AMemorixPersonProfileConfig(ConfigBase):
    """A_Memorix 人物画像配置"""
    # Person-profile (user portrait) maintenance settings.
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Whether profile generation is enabled.
    enabled: bool = Field(default=True)
    """是否启用画像"""

    # Profile refresh interval, in minutes.
    refresh_interval_minutes: int = Field(default=30, ge=1)
    """刷新间隔分钟数"""

    # Window (hours) within which a person counts as "active".
    active_window_hours: float = Field(default=72.0, ge=1.0)
    """活跃窗口小时数"""

    # Maximum profiles refreshed per cycle.
    max_refresh_per_cycle: int = Field(default=50, ge=1)
    """单轮最大刷新数"""

    # Number of evidence entries used per profile.
    top_k_evidence: int = Field(default=12, ge=1)
    """证据条数"""
|
||||
class AMemorixMemoryEvolutionConfig(ConfigBase):
    """A_Memorix 记忆演化配置"""
    # Memory-evolution (decay / pruning / freezing) settings.
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Whether memory evolution is enabled.
    enabled: bool = Field(default=True)
    """是否启用记忆演化"""

    # Decay half-life, in hours.
    half_life_hours: float = Field(default=24.0, ge=0.1)
    """半衰期小时数"""

    # Memories scoring below this threshold are pruned — TODO confirm
    # against the consumer.
    prune_threshold: float = Field(default=0.1, ge=0.0, le=1.0)
    """裁剪阈值"""

    # Freeze duration, in hours (0.0 allowed).
    freeze_duration_hours: float = Field(default=24.0, ge=0.0)
    """冻结时长小时数"""
|
||||
class AMemorixAdvancedConfig(ConfigBase):
    """A_Memorix 高级运行时配置"""
    # Advanced runtime knobs for A_Memorix.
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Whether periodic auto-save is enabled.
    enable_auto_save: bool = Field(default=True)
    """是否启用自动保存"""

    # Auto-save interval, in minutes.
    auto_save_interval_minutes: int = Field(default=5, ge=1)
    """自动保存间隔"""

    # Whether debug mode is enabled.
    debug: bool = Field(default=False)
    """是否启用调试"""
|
||||
class AMemorixWebImportConfig(ConfigBase):
    """A_Memorix 导入中心配置"""
    # Web import-center settings (bulk document ingestion).
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Whether the import center is enabled.
    enabled: bool = Field(default=True)
    """是否启用导入中心"""

    # Maximum length of the import task queue.
    max_queue_size: int = Field(default=20, ge=1)
    """最大队列长度"""

    # Maximum number of files per import task.
    max_files_per_task: int = Field(default=200, ge=1)
    """单任务最大文件数"""

    # Per-file size limit, in MB.
    max_file_size_mb: int = Field(default=20, ge=1)
    """单文件大小上限 MB"""

    # Character limit for pasted text.
    max_paste_chars: int = Field(default=200000, ge=100)
    """粘贴字符数上限"""

    # Default number of files processed concurrently.
    default_file_concurrency: int = Field(default=2, ge=1)
    """默认文件并发"""

    # Default number of chunks processed concurrently.
    default_chunk_concurrency: int = Field(default=4, ge=1)
    """默认分块并发"""
|
||||
class AMemorixWebTuningConfig(ConfigBase):
    """A_Memorix 调优中心配置"""
    # Web tuning-center settings (retrieval parameter tuning jobs).
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Whether the tuning center is enabled.
    enabled: bool = Field(default=True)
    """是否启用调优中心"""

    # Maximum length of the tuning job queue.
    max_queue_size: int = Field(default=8, ge=1)
    """最大队列长度"""

    # Front-end polling interval, in milliseconds (at least 200).
    poll_interval_ms: int = Field(default=1200, ge=200)
    """轮询间隔毫秒数"""

    # Default tuning intensity preset.
    default_intensity: Literal["quick", "standard", "deep"] = Field(default="standard")
    """默认调优强度"""

    # Default tuning objective (precision vs. recall trade-off).
    default_objective: Literal["precision_priority", "balanced", "recall_priority"] = Field(
        default="precision_priority"
    )
    """默认调优目标"""

    # Default top-K used during evaluation.
    default_top_k_eval: int = Field(default=20, ge=1)
    """默认评估 Top-K"""

    # Default evaluation sample size.
    default_sample_size: int = Field(default=24, ge=1)
    """默认样本数"""
|
||||
class AMemorixWebConfig(ConfigBase):
    """A_Memorix Web 运维配置"""
    # Aggregates the web-facing operation sub-configs (import + tuning).
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # Import-center sub-config. Named `import_config` rather than `import`
    # because `import` is a Python keyword.
    import_config: AMemorixWebImportConfig = Field(default_factory=AMemorixWebImportConfig)
    """导入中心配置"""

    # Tuning-center sub-config.
    tuning: AMemorixWebTuningConfig = Field(default_factory=AMemorixWebTuningConfig)
    """调优中心配置"""
|
||||
class AMemorixConfig(ConfigBase):
    """A_Memorix 长期记忆子系统配置"""
    # Root configuration of the A_Memorix long-term-memory subsystem;
    # aggregates all A_Memorix sub-configs defined above.
    # NOTE(review): Chinese docstrings kept verbatim — presumably consumed by
    # the config UI; confirm before translating.

    # UI metadata read by the config front-end ("长期记忆" = "long-term memory").
    __ui_label__ = "长期记忆"
    __ui_icon__ = "brain"

    # Subsystem enable/state settings.
    plugin: AMemorixPluginConfig = Field(default_factory=AMemorixPluginConfig)
    """子系统状态"""

    # Storage-location settings.
    storage: AMemorixStorageConfig = Field(default_factory=AMemorixStorageConfig)
    """存储位置"""

    # Embedding settings.
    embedding: AMemorixEmbeddingConfig = Field(default_factory=AMemorixEmbeddingConfig)
    """Embedding 配置"""

    # Retrieval settings.
    retrieval: AMemorixRetrievalConfig = Field(default_factory=AMemorixRetrievalConfig)
    """检索配置"""

    # Score-threshold filtering settings.
    threshold: AMemorixThresholdConfig = Field(default_factory=AMemorixThresholdConfig)
    """阈值过滤配置"""

    # Chat filtering settings. NOTE(review): field name shadows the `filter`
    # builtin inside this class body — harmless here, but kept as-is since it
    # is part of the public config schema.
    filter: AMemorixFilterConfig = Field(default_factory=AMemorixFilterConfig)
    """聊天过滤配置"""

    # Episode (episodic-memory) settings.
    episode: AMemorixEpisodeConfig = Field(default_factory=AMemorixEpisodeConfig)
    """Episode 配置"""

    # Person-profile settings.
    person_profile: AMemorixPersonProfileConfig = Field(default_factory=AMemorixPersonProfileConfig)
    """人物画像配置"""

    # Memory-evolution settings.
    memory: AMemorixMemoryEvolutionConfig = Field(default_factory=AMemorixMemoryEvolutionConfig)
    """记忆演化配置"""

    # Advanced runtime settings.
    advanced: AMemorixAdvancedConfig = Field(default_factory=AMemorixAdvancedConfig)
    """高级运行时配置"""

    # Web operation settings (import + tuning centers).
    web: AMemorixWebConfig = Field(default_factory=AMemorixWebConfig)
    """Web 运维配置"""
|
||||
class LearningItem(ConfigBase):
|
||||
platform: str = Field(
|
||||
default="",
|
||||
|
||||
@@ -80,6 +80,7 @@ class MainSystem:
|
||||
init_start_time = time.time()
|
||||
|
||||
await config_manager.start_file_watcher()
|
||||
a_memorix_host_service.register_config_reload_callback()
|
||||
|
||||
# 添加在线时间统计任务
|
||||
await async_task_manager.add_task(OnlineTimeRecordTask())
|
||||
|
||||
@@ -19,6 +19,7 @@ from src.config.model_configs import (
|
||||
ModelTaskConfig,
|
||||
)
|
||||
from src.config.official_configs import (
|
||||
AMemorixConfig,
|
||||
BotConfig,
|
||||
ChatConfig,
|
||||
ChineseTypoConfig,
|
||||
@@ -128,6 +129,7 @@ async def get_config_section_schema(section_name: str):
|
||||
"telemetry": TelemetryConfig,
|
||||
"maim_message": MaimMessageConfig,
|
||||
"memory": MemoryConfig,
|
||||
"a_memorix": AMemorixConfig,
|
||||
"debug": DebugConfig,
|
||||
"voice": VoiceConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
|
||||
Reference in New Issue
Block a user