From 5fc6551f57bef3a661a4268fd7dd2b537b6f87ec Mon Sep 17 00:00:00 2001 From: DawnARC Date: Sat, 9 May 2026 00:40:56 +0800 Subject: [PATCH] =?UTF-8?q?fix(A=5Fmemorix):=E4=BD=BF=E7=94=A8=20coerce=5F?= =?UTF-8?q?metadata=5Fdict=20=E5=A4=84=E7=90=86=20metadata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入 coerce_metadata_dict 工具函数,统一化字典的转义:对于 Mapping 类型输入返回字典,否则返回空字典,并替换掉代码中的 dict(...) 转换。 更新了 dual_path.py、sdk_memory_kernel.py、person_profile_service.py 和 search_execution_service.py 中的调用点和导入,以规范化metadata,避免 metadata 为 None 或非字典类型时出现错误。 --- src/A_memorix/core/retrieval/dual_path.py | 5 +-- .../core/runtime/sdk_memory_kernel.py | 5 +-- src/A_memorix/core/utils/metadata.py | 11 +++++++ .../core/utils/person_profile_service.py | 31 +++++++++---------- .../core/utils/search_execution_service.py | 23 +++++++------- 5 files changed, 43 insertions(+), 32 deletions(-) create mode 100644 src/A_memorix/core/utils/metadata.py diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index 245bafea..996c02f8 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -16,6 +16,7 @@ from src.common.logger import get_logger from ..storage import VectorStore, GraphStore, MetadataStore from ..embedding import EmbeddingAPIAdapter from ..utils.matcher import AhoCorasick +from ..utils.metadata import coerce_metadata_dict from ..utils.time_parser import format_timestamp from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService from .pagerank import PersonalizedPageRank, PageRankConfig @@ -482,7 +483,7 @@ class DualPathRetriever: score=float(item.score), result_type=item.result_type, source=item.source, - metadata=dict(item.metadata or {}), + metadata=coerce_metadata_dict(item.metadata), ) def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: @@ -762,7 +763,7 @@ class DualPathRetriever: existing = self._clone_retrieval_result(item) merged[item.hash_value] = existing else: - for key, value in dict(item.metadata or {}).items(): + for key, value in coerce_metadata_dict(item.metadata).items(): if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): existing.metadata[key] = value source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index dfbbbd77..26ff503a 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService from ..utils.episode_segmentation_service import EpisodeSegmentationService from ..utils.episode_service import EpisodeService from ..utils.hash import compute_hash, normalize_text +from ..utils.metadata import coerce_metadata_dict from ..utils.person_profile_service import PersonProfileService from ..utils.relation_write_service import RelationWriteService from ..utils.retrieval_tuning_manager import RetrievalTuningManager @@ -871,7 +872,7 @@ class SDKMemoryKernel: "detail": "chat_filtered", } - summary_meta = dict(metadata or {}) + summary_meta = coerce_metadata_dict(metadata) summary_meta.setdefault("kind", "chat_summary") if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): result = await self.summarize_chat_stream( @@ -961,7 +962,7 @@ class SDKMemoryKernel: participant_tokens = self._tokens(participants) entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) source = self._build_source(source_type, chat_id, person_tokens) - paragraph_meta = dict(metadata or {}) + paragraph_meta = coerce_metadata_dict(metadata) paragraph_meta.update( { "external_id": external_token, diff --git a/src/A_memorix/core/utils/metadata.py b/src/A_memorix/core/utils/metadata.py new file mode 100644 index 00000000..5a1dafc1 --- /dev/null +++ b/src/A_memorix/core/utils/metadata.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Dict + + +def coerce_metadata_dict(value: Any) -> Dict[str, Any]: + """返回字典,如果输入值不是字典则返回空字典。""" + if isinstance(value, Mapping): + return dict(value) + return {} diff --git a/src/A_memorix/core/utils/person_profile_service.py b/src/A_memorix/core/utils/person_profile_service.py index 081eaa66..eec531e0 100644 --- a/src/A_memorix/core/utils/person_profile_service.py +++ b/src/A_memorix/core/utils/person_profile_service.py @@ -27,6 +27,7 @@ from ..retrieval import ( GraphRelationRecallConfig, ) from ..storage import MetadataStore, GraphStore, VectorStore +from .metadata import coerce_metadata_dict logger = get_logger("A_Memorix.PersonProfileService") @@ -334,7 +335,7 @@ class PersonProfileService: if not pid: return False - metadata = self._metadata_dict(relation.get("metadata")) + metadata = coerce_metadata_dict(relation.get("metadata")) if str(metadata.get("person_id", "") or "").strip() == pid: return True if pid in self._list_tokens(metadata.get("person_ids")): @@ -350,7 +351,7 @@ class PersonProfileService: payload = { "hash": source_paragraph, "source": str(paragraph.get("source", "") or ""), - "metadata": self._metadata_dict(paragraph.get("metadata")), + "metadata": coerce_metadata_dict(paragraph.get("metadata")), } return self._is_evidence_bound_to_person(payload, person_id=pid) @@ -385,15 +386,11 @@ class PersonProfileService: "score": 1.1, "content": content[:220], "source": str(row.get("source", "") or source), - "metadata": dict(row.get("metadata", {}) or {}), + "metadata": coerce_metadata_dict(row.get("metadata")), } ) return self._filter_stale_paragraph_evidence(evidence) - @staticmethod - def _metadata_dict(value: Any) -> Dict[str, Any]: - return dict(value) if isinstance(value, dict) else {} - @staticmethod def _list_tokens(value: Any) -> List[str]: if value is None: @@ -414,7 +411,7 @@ class PersonProfileService: if not pid: return False - metadata = self._metadata_dict(item.get("metadata")) + metadata = coerce_metadata_dict(item.get("metadata")) source = str(item.get("source", "") or metadata.get("source", "") or "").strip() if source == f"person_fact:{pid}": return True @@ -440,15 +437,15 @@ class PersonProfileService: paragraph_hash: str, metadata: Dict[str, Any], ) -> Tuple[Dict[str, Any], str]: - merged = self._metadata_dict(metadata) + merged = coerce_metadata_dict(metadata) source = str(merged.get("source", "") or "").strip() try: paragraph = self.metadata_store.get_paragraph(paragraph_hash) except Exception: paragraph = None if isinstance(paragraph, dict): - paragraph_metadata = paragraph.get("metadata", {}) or {} - if isinstance(paragraph_metadata, dict): + paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata")) + if paragraph_metadata: merged = {**paragraph_metadata, **merged} source = source or str(paragraph.get("source", "") or "").strip() source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source) @@ -538,7 +535,7 @@ class PersonProfileService: "score": 0.0, "content": str(para.get("content", ""))[:180], "source": str(para.get("source", "") or ""), - "metadata": self._metadata_dict(para.get("metadata")), + "metadata": coerce_metadata_dict(para.get("metadata")), } ) if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id): @@ -562,18 +559,18 @@ class PersonProfileService: logger.warning(f"向量证据召回失败: alias={alias}, err={e}") continue for item in results: - h = str(getattr(item, "hash_value", "") or "") + h = str(item.hash_value or "") if not h or h in seen_hash: continue metadata, source = self._enrich_paragraph_evidence_metadata( h, - self._metadata_dict(getattr(item, "metadata", {})), + coerce_metadata_dict(item.metadata), ) payload = { "hash": h, - "type": str(getattr(item, "result_type", "")), - "score": float(getattr(item, "score", 0.0) or 0.0), - "content": str(getattr(item, "content", "") or "")[:220], + "type": str(item.result_type), + "score": float(item.score or 0.0), + "content": str(item.content or "")[:220], "source": source, "metadata": metadata, } diff --git a/src/A_memorix/core/utils/search_execution_service.py b/src/A_memorix/core/utils/search_execution_service.py index ace051e9..f6f6b145 100644 --- a/src/A_memorix/core/utils/search_execution_service.py +++ b/src/A_memorix/core/utils/search_execution_service.py @@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger -from ..retrieval import TemporalQueryOptions +from ..retrieval import RetrievalResult, TemporalQueryOptions +from .metadata import coerce_metadata_dict from .search_postprocess import ( apply_safe_content_dedup, maybe_apply_smart_path_fallback, @@ -286,8 +287,8 @@ class SearchExecutionService: ) async def _executor() -> Dict[str, Any]: - original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) - setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) + original_ppr = bool(retriever.config.enable_ppr) + retriever.config.enable_ppr = bool(request.enable_ppr) started_at = time.time() try: retrieved = await retriever.retrieve( @@ -321,7 +322,7 @@ class SearchExecutionService: relation_hashes = [ item.hash_value for item in retrieved - if getattr(item, "result_type", "") == "relation" + if item.result_type == "relation" ] if relation_hashes: await plugin_instance.reinforce_access(relation_hashes) @@ -380,7 +381,7 @@ class SearchExecutionService: elapsed_ms = (time.time() - started_at) * 1000.0 return {"results": retrieved, "elapsed_ms": elapsed_ms} finally: - setattr(retriever.config, "enable_ppr", original_ppr) + retriever.config.enable_ppr = original_ppr dedup_hit = False try: @@ -421,18 +422,18 @@ class SearchExecutionService: ) @staticmethod - def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: + def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]: serialized: List[Dict[str, Any]] = [] for item in results: - metadata = dict(getattr(item, "metadata", {}) or {}) + metadata = coerce_metadata_dict(item.metadata) if "time_meta" not in metadata: metadata["time_meta"] = {} serialized.append( { - "hash": getattr(item, "hash_value", ""), - "type": getattr(item, "result_type", ""), - "score": float(getattr(item, "score", 0.0)), - "content": getattr(item, "content", ""), + "hash": item.hash_value, + "type": item.result_type, + "score": float(item.score), + "content": item.content, "metadata": metadata, } )