feat:完善长期记忆控制台导入链路与联调测试

summary:
- 扩展长期记忆控制台导入、调优与删除相关 UI/接口,补充中文化展示与任务细粒度状态管理
- 强化 memory API 与后端路由能力,补齐导入任务、图谱检索、配置与运行态相关字段
- 新增与增强前后端测试,覆盖导入多文件类型、检索、调优、删除及图谱查询关键路径

description:
- dashboard: 重构 knowledge-base 页面与 memory-api,统一任务队列、分块分页、来源删除恢复、调优闭环交互
- backend: 扩展 webui memory 路由与 A_Memorix 内核检索逻辑,完善服务侧能力与配置 schema
- tests: 增加 webui 集成测试和 kernel 单测,提升导入/检索/调优/删除全流程回归保障
This commit is contained in:
DawnARC
2026-04-03 19:50:08 +08:00
parent eac5495d00
commit da95b06f96
18 changed files with 4045 additions and 299 deletions

View File

@@ -98,7 +98,7 @@
"data_dir": {
"name": "data_dir",
"type": "string",
"default": "data/plugins/a-dawn.a-memorix",
"default": "data/a-memorix",
"description": "数据目录",
"label": "数据目录",
"ui_type": "text",
@@ -107,7 +107,7 @@
"disabled": false,
"order": 1,
"hint": "相对路径按 MaiBot 仓库根目录解析,建议保持默认外置目录。",
"placeholder": "data/plugins/a-dawn.a-memorix",
"placeholder": "data/a-memorix",
"choices": null
}
}

View File

@@ -1317,6 +1317,11 @@ class SDKMemoryKernel:
act = str(action or "").strip().lower()
if act == "get_graph":
return {"success": True, **self._serialize_graph(limit=max(1, int(kwargs.get("limit", 200) or 200)))}
if act == "search":
return self._search_graph(
query=str(kwargs.get("query", "") or "").strip(),
limit=max(1, min(200, int(kwargs.get("limit", 50) or 50))),
)
if act == "node_detail":
detail = self._build_graph_node_detail(
node_id=str(kwargs.get("node_id", "") or kwargs.get("node", "") or "").strip(),
@@ -2275,6 +2280,179 @@ class SDKMemoryKernel:
"total_edges": int(self.graph_store.num_edges),
}
@staticmethod
def _graph_search_match_rank(value: str, keyword: str) -> Optional[int]:
token = str(value or "").strip().lower()
if not token or not keyword:
return None
if token == keyword:
return 0
if token.startswith(keyword):
return 1
if keyword in token:
return 2
return None
@classmethod
def _pick_graph_search_match(
    cls,
    fields: Sequence[tuple[str, str]],
    keyword: str,
) -> Optional[tuple[str, str, int]]:
    """Pick the best-ranked ``(field, value, rank)`` match among *fields*.

    Blank values and non-matching values are skipped. On equal ranks the
    earliest field wins, mirroring a first-wins linear scan.
    """
    candidates: List[tuple[str, str, int]] = []
    for field_name, raw_value in fields:
        cleaned = str(raw_value or "").strip()
        if not cleaned:
            continue
        rank = cls._graph_search_match_rank(cleaned, keyword)
        if rank is not None:
            candidates.append((field_name, cleaned, rank))
    if not candidates:
        return None
    # min() keeps the first minimal element, preserving field order on ties.
    return min(candidates, key=lambda entry: entry[2])
def _search_graph(self, *, query: str, limit: int) -> Dict[str, Any]:
    """Keyword-search entities and relations in the memory graph.

    Performs a case-insensitive substring search (SQL ``LIKE``) over
    non-deleted entities and active relations, ranks each hit by match
    quality, deduplicates, sorts, and returns at most *limit* items.

    Args:
        query: raw search text; leading/trailing whitespace is stripped.
        limit: maximum number of items to return (clamped to >= 1).

    Returns:
        dict with ``success``, ``query``, ``limit``, ``count`` and
        ``items``; on an empty query ``success`` is False and an
        ``error`` message is included instead of results.
    """
    # The caller must have initialized the metadata store before searching.
    assert self.metadata_store is not None
    token = str(query or "").strip()
    normalized_query = token.lower()
    safe_limit = max(1, int(limit or 50))
    if not token:
        # Empty query: report a validation error rather than matching everything.
        return {
            "success": False,
            "query": token,
            "limit": safe_limit,
            "count": 0,
            "items": [],
            "error": "query 不能为空",
        }
    like_keyword = f"%{normalized_query}%"
    # Entity candidates: match on name or hash, excluding soft-deleted rows.
    # NOTE(review): created_at is selected but not carried into entity items.
    entity_rows = self.metadata_store.query(
        """
        SELECT hash, name, appearance_count, created_at
        FROM entities
        WHERE (is_deleted IS NULL OR is_deleted = 0)
        AND (
            LOWER(COALESCE(name, '')) LIKE ?
            OR LOWER(COALESCE(hash, '')) LIKE ?
        )
        """,
        (like_keyword, like_keyword),
    )
    # Relation candidates: match on any endpoint, the predicate, or the hash.
    relation_rows = self.metadata_store.query(
        """
        SELECT hash, subject, predicate, object, confidence, created_at
        FROM relations
        WHERE (is_inactive IS NULL OR is_inactive = 0)
        AND (
            LOWER(COALESCE(subject, '')) LIKE ?
            OR LOWER(COALESCE(object, '')) LIKE ?
            OR LOWER(COALESCE(predicate, '')) LIKE ?
            OR LOWER(COALESCE(hash, '')) LIKE ?
        )
        """,
        (like_keyword, like_keyword, like_keyword, like_keyword),
    )
    entity_items: List[Dict[str, Any]] = []
    seen_entity_keys: set[str] = set()
    for row in entity_rows:
        name = str(row.get("name", "") or "").strip()
        hash_value = str(row.get("hash", "") or "").strip()
        # Re-rank in Python: SQL LIKE only pre-filters; rank decides ordering.
        match = self._pick_graph_search_match(
            [("name", name), ("hash", hash_value)],
            normalized_query,
        )
        if match is None:
            continue
        # Dedupe by hash when present, otherwise by lowercased name.
        dedupe_key = hash_value or f"name:{name.lower()}"
        if dedupe_key in seen_entity_keys:
            continue
        seen_entity_keys.add(dedupe_key)
        matched_field, matched_value, rank = match
        entity_items.append(
            {
                "type": "entity",
                "title": name or hash_value,
                "matched_field": matched_field,
                "matched_value": matched_value,
                "entity_name": name or hash_value,
                "entity_hash": hash_value,
                "appearance_count": int(row.get("appearance_count", 0) or 0),
                # _rank is internal; stripped before returning to callers.
                "_rank": rank,
            }
        )
    relation_items: List[Dict[str, Any]] = []
    seen_relation_keys: set[str] = set()
    for row in relation_rows:
        subject = str(row.get("subject", "") or "").strip()
        predicate = str(row.get("predicate", "") or "").strip()
        obj = str(row.get("object", "") or "").strip()
        relation_hash = str(row.get("hash", "") or "").strip()
        match = self._pick_graph_search_match(
            [
                ("subject", subject),
                ("object", obj),
                ("predicate", predicate),
                ("hash", relation_hash),
            ],
            normalized_query,
        )
        if match is None:
            continue
        # Fall back to the (subject, predicate, object) triple when no hash.
        dedupe_key = relation_hash or f"{subject.lower()}|{predicate.lower()}|{obj.lower()}"
        if dedupe_key in seen_relation_keys:
            continue
        seen_relation_keys.add(dedupe_key)
        matched_field, matched_value, rank = match
        relation_items.append(
            {
                "type": "relation",
                "title": self._format_relation_text(subject, predicate, obj),
                "matched_field": matched_field,
                "matched_value": matched_value,
                "subject": subject,
                "predicate": predicate,
                "object": obj,
                "relation_hash": relation_hash,
                "confidence": float(row.get("confidence", 0.0) or 0.0),
                "created_at": float(row.get("created_at", 0.0) or 0.0),
                "_rank": rank,
            }
        )
    items = entity_items + relation_items
    # Sort order: best rank first, entities before relations, then
    # popularity (appearance_count) / confidence descending, relation
    # recency descending, and finally stable lexical tie-breakers.
    items.sort(
        key=lambda item: (
            int(item["_rank"]) if item.get("_rank") is not None else 99,
            0 if str(item.get("type", "") or "") == "entity" else 1,
            -int(item.get("appearance_count", 0) or 0)
            if str(item.get("type", "") or "") == "entity"
            else -float(item.get("confidence", 0.0) or 0.0),
            0.0 if str(item.get("type", "") or "") == "entity" else -float(item.get("created_at", 0.0) or 0.0),
            str(item.get("entity_name", item.get("subject", "")) or "").lower(),
            str(item.get("predicate", "") or "").lower(),
            str(item.get("object", "") or "").lower(),
            str(item.get("entity_hash", item.get("relation_hash", "")) or "").lower(),
        )
    )
    # Truncate to the limit and drop the internal _rank field from the payload.
    normalized_items: List[Dict[str, Any]] = []
    for item in items[:safe_limit]:
        normalized = dict(item)
        normalized.pop("_rank", None)
        normalized_items.append(normalized)
    return {
        "success": True,
        "query": token,
        "limit": safe_limit,
        "count": len(normalized_items),
        "items": normalized_items,
    }
@staticmethod
def _dedupe_strings(values: Iterable[Any]) -> List[str]:
deduped: List[str] = []

View File

@@ -277,6 +277,7 @@ class EpisodeSegmentationService:
model_config, model_label = self._resolve_model_config()
if model_config is None:
raise RuntimeError("episode segmentation model unavailable")
task_name = llm_api.resolve_task_name_from_model_config(model_config, preferred_task_name=model_label)
prompt = self._build_prompt(
source=source,
@@ -284,11 +285,17 @@ class EpisodeSegmentationService:
window_end=window_end,
paragraphs=paragraphs,
)
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="A_Memorix.EpisodeSegmentation",
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type="A_Memorix.EpisodeSegmentation",
prompt=prompt,
temperature=getattr(model_config, "temperature", None),
max_tokens=getattr(model_config, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success or not response:
raise RuntimeError("llm_generate_failed")

View File

@@ -1306,6 +1306,7 @@ class RetrievalTuningManager:
model_cfg = await self._select_llm_model()
if model_cfg is None:
raise RuntimeError("no_llm_model")
task_name = llm_api.resolve_task_name_from_model_config(model_cfg)
retry = self._llm_retry_cfg()
max_attempts = int(retry["max_attempts"])
@@ -1316,11 +1317,17 @@ class RetrievalTuningManager:
last_error: Optional[Exception] = None
for idx in range(max_attempts):
try:
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_cfg,
request_type=request_type,
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type=request_type,
prompt=prompt,
temperature=getattr(model_cfg, "temperature", None),
max_tokens=getattr(model_cfg, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success:
raise RuntimeError("llm_generation_failed")
text = str(response or "").strip()

View File

@@ -280,15 +280,22 @@ class SummaryImporter:
model_config_to_use = self._resolve_summary_model_config()
if model_config_to_use is None:
return False, "未找到可用的总结模型配置"
task_name_to_use = llm_api.resolve_task_name_from_model_config(model_config_to_use)
logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config_to_use,
request_type="A_Memorix.ChatSummarization"
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name_to_use,
request_type="A_Memorix.ChatSummarization",
prompt=prompt,
temperature=getattr(model_config_to_use, "temperature", None),
max_tokens=getattr(model_config_to_use, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success or not response:
return False, "LLM 生成总结失败"

View File

@@ -3165,14 +3165,21 @@ class ImportTaskManager:
async def _llm_call(self, prompt: str, model_config: Any) -> Dict[str, Any]:
cfg = self._llm_retry_config()
retries = int(cfg["retries"])
task_name = llm_api.resolve_task_name_from_model_config(model_config)
last_error: Optional[Exception] = None
for attempt in range(retries + 1):
try:
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="A_Memorix.WebImport",
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type="A_Memorix.WebImport",
prompt=prompt,
temperature=getattr(model_config, "temperature", None),
max_tokens=getattr(model_config, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if not success or not response:
raise RuntimeError("LLM 生成失败")

View File

@@ -88,11 +88,59 @@ class AMemorixHostService:
def get_config(self) -> Dict[str, Any]:
return dict(self._read_config())
def get_raw_config(self) -> str:
def _build_default_config(self) -> Dict[str, Any]:
    """Assemble a nested default-config dict from the config schema.

    Walks every schema section (dotted section names become nested
    dicts) and collects each field's declared ``default`` value,
    converted to builtin types. Sections or fields with unexpected
    shapes are skipped silently.
    """
    schema = self.get_config_schema()
    sections = schema.get("sections") if isinstance(schema, dict) else None
    if not isinstance(sections, dict):
        return {}
    defaults: Dict[str, Any] = {}
    for section_name, section_payload in sections.items():
        if not isinstance(section_payload, dict):
            continue
        fields = section_payload.get("fields")
        if not isinstance(fields, dict):
            continue
        # "a.b.c" → nested dicts defaults["a"]["b"]["c"].
        path = [segment for segment in str(section_name or "").split(".") if segment]
        if not path:
            continue
        cursor: Dict[str, Any] = defaults
        for segment in path:
            if not isinstance(cursor.get(segment), dict):
                cursor[segment] = {}
            cursor = cursor[segment]
        for field_name, field_payload in fields.items():
            # Only fields that explicitly declare a default contribute.
            if not isinstance(field_payload, dict) or "default" not in field_payload:
                continue
            cursor[str(field_name)] = _to_builtin_data(field_payload.get("default"))
    return defaults
def get_raw_config_with_meta(self) -> Dict[str, Any]:
    """Return the raw config text together with existence metadata.

    When the config file exists its literal text is returned with
    ``exists``/``using_default`` flags; otherwise schema defaults are
    rendered to TOML and flagged as ``using_default``.
    """
    path = self.get_config_path()
    if path.exists():
        return {
            "config": path.read_text(encoding="utf-8"),
            "exists": True,
            "using_default": False,
        }
    # File missing: synthesize TOML from the schema defaults (may be empty).
    default_config = self._build_default_config()
    default_raw = tomlkit.dumps(default_config) if default_config else ""
    return {
        "config": default_raw,
        "exists": False,
        "using_default": True,
    }
def get_raw_config(self) -> str:
    """Return only the raw config text, discarding the metadata flags."""
    meta = self.get_raw_config_with_meta()
    raw_text = meta.get("config", "")
    # Normalize falsy payloads (None, "") to an empty string.
    return str(raw_text or "")
async def update_raw_config(self, raw_config: str) -> Dict[str, Any]:
tomlkit.loads(raw_config)
@@ -231,16 +279,18 @@ class AMemorixHostService:
path = self.get_config_path()
if not path.exists():
self._config_cache = {}
return {}
defaults = self._build_default_config()
self._config_cache = defaults
return dict(defaults)
try:
with path.open("r", encoding="utf-8") as handle:
loaded = tomlkit.load(handle)
except Exception as exc:
logger.warning("读取 A_Memorix 配置失败 %s: %s", path, exc)
self._config_cache = {}
return {}
defaults = self._build_default_config()
self._config_cache = defaults
return dict(defaults)
self._config_cache = _to_builtin_data(loaded) if isinstance(loaded, dict) else {}
return dict(self._config_cache)

View File

@@ -560,11 +560,18 @@ Chat paragraph:
)
async def _llm_call(self, prompt: str, model_config: Any) -> Dict:
"""Generic LLM Caller"""
success, response, _, _ = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="Script.ProcessKnowledge"
task_name = llm_api.resolve_task_name_from_model_config(model_config)
result = await llm_api.generate(
llm_api.LLMServiceRequest(
task_name=task_name,
request_type="Script.ProcessKnowledge",
prompt=prompt,
temperature=getattr(model_config, "temperature", None),
max_tokens=getattr(model_config, "max_tokens", None),
)
)
success = bool(result.success)
response = str(result.completion.response or "")
if success:
txt = response.strip()
if "```" in txt:

View File

@@ -230,6 +230,61 @@ def resolve_task_name(task_name: str = "") -> str:
return normalized_task_name
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """根据旧版 `TaskConfig` 风格参数解析可用任务名。

    该方法用于兼容仍以 `model_config` 传参的调用方:
    1. 优先使用显式给出的 `preferred_task_name`;
    2. 其次匹配对象同一性;
    3. 再尝试按 `model_list` 精确匹配;
    4. 最后按 `model_list` 中首个命中的模型进行近似映射。

    Args:
        model_config: 旧调用方持有的任务配置对象。
        preferred_task_name: 候选任务名(可选)。

    Returns:
        str: 可用于 `LLMServiceRequest.task_name` 的任务名。

    Raises:
        RuntimeError: 当前没有可用模型配置。
        ValueError: 无法解析任何可用任务名时抛出。
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    normalized_preferred = str(preferred_task_name or "").strip()
    if normalized_preferred and normalized_preferred in models:
        return normalized_preferred

    # Identity match: the caller may hold the exact config object we know.
    for task_name, task_cfg in models.items():
        if task_cfg is model_config:
            return task_name

    def _normalized_model_list(cfg: Any) -> List[str]:
        # `or []` guards against model_list being None as well as missing;
        # the original only guarded the requested side, not the candidates.
        raw_list = getattr(cfg, "model_list", None) or []
        return [str(item).strip() for item in raw_list if str(item).strip()]

    requested_model_list = _normalized_model_list(model_config)
    if requested_model_list:
        # Normalize each task's model list once instead of inside the
        # nested loops below (was recomputed per requested model).
        candidate_lists = {
            task_name: _normalized_model_list(task_cfg)
            for task_name, task_cfg in models.items()
        }
        # Exact model-list match first.
        for task_name, candidate_list in candidate_lists.items():
            if candidate_list == requested_model_list:
                return task_name
        # Approximate match: first task containing any requested model.
        for requested_model in requested_model_list:
            for task_name, candidate_list in candidate_lists.items():
                if requested_model in candidate_list:
                    logger.info(
                        "[LLMService] 旧版 model_config 未命中任务配置,"
                        f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
                    )
                    return task_name

    if normalized_preferred:
        logger.warning(f"[LLMService] 无法映射旧版 model_config回退默认任务: preferred={normalized_preferred}")
    # Fall back to the globally configured default task.
    return resolve_task_name("")
def _normalize_role(role_name: str) -> RoleType:
"""将原始角色字符串转换为内部角色枚举。

View File

@@ -168,6 +168,10 @@ async def _graph_get(limit: int) -> dict:
return await memory_service.graph_admin(action="get_graph", limit=limit)
async def _graph_search(query: str, limit: int) -> dict:
    """Delegate a graph keyword search to the memory service backend."""
    payload = await memory_service.graph_admin(
        action="search",
        query=query,
        limit=limit,
    )
    return payload
async def _graph_get_node_detail(
node_id: str,
*,
@@ -390,9 +394,20 @@ async def _memory_config_get() -> dict:
async def _memory_config_get_raw() -> dict:
    """Return the raw A_Memorix config text plus existence/default metadata."""
    meta_getter = getattr(a_memorix_host_service, "get_raw_config_with_meta", None)
    if callable(meta_getter):
        raw_payload = meta_getter()
    else:
        # Older host services only expose get_raw_config(); synthesize the
        # metadata fields so the response shape stays uniform.
        raw_payload = {
            "config": a_memorix_host_service.get_raw_config(),
            "exists": bool(a_memorix_host_service.get_config_path().exists()),
            "using_default": False,
        }
    return {
        "success": True,
        "config": str(raw_payload.get("config", "") or ""),
        "exists": bool(raw_payload.get("exists", False)),
        "using_default": bool(raw_payload.get("using_default", False)),
        "path": str(a_memorix_host_service.get_config_path()),
    }
@@ -649,6 +664,14 @@ async def get_memory_graph(limit: int = Query(200, ge=1, le=5000)):
return await _graph_get(limit)
@router.get("/graph/search")
async def search_memory_graph(
    query: str = Query(..., min_length=1),
    limit: int = Query(50, ge=1, le=200),
):
    """Search the memory graph by keyword.

    Thin route wrapper: input validation is handled by the Query
    constraints (non-empty query, limit in 1..200); the actual search
    is delegated to ``_graph_search``.
    """
    return await _graph_search(query, limit)
@router.get("/graph/node-detail")
async def get_memory_graph_node_detail(
node_id: str = Query(..., min_length=1),