fix(A_memorix):使用 coerce_metadata_dict 处理 metadata

引入 coerce_metadata_dict 工具函数,统一 metadata 到字典的转换:对于 Mapping 类型输入返回其字典副本,否则返回空字典,并替换掉代码中分散的 dict(...) 转换。
更新了 dual_path.py、sdk_memory_kernel.py、person_profile_service.py 和 search_execution_service.py 中的调用点和导入,以规范化metadata,避免 metadata 为 None 或非字典类型时出现错误。
This commit is contained in:
DawnARC
2026-05-09 00:40:56 +08:00
parent 531e1428de
commit 5fc6551f57
5 changed files with 43 additions and 32 deletions

View File

@@ -16,6 +16,7 @@ from src.common.logger import get_logger
from ..storage import VectorStore, GraphStore, MetadataStore from ..storage import VectorStore, GraphStore, MetadataStore
from ..embedding import EmbeddingAPIAdapter from ..embedding import EmbeddingAPIAdapter
from ..utils.matcher import AhoCorasick from ..utils.matcher import AhoCorasick
from ..utils.metadata import coerce_metadata_dict
from ..utils.time_parser import format_timestamp from ..utils.time_parser import format_timestamp
from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService
from .pagerank import PersonalizedPageRank, PageRankConfig from .pagerank import PersonalizedPageRank, PageRankConfig
@@ -482,7 +483,7 @@ class DualPathRetriever:
score=float(item.score), score=float(item.score),
result_type=item.result_type, result_type=item.result_type,
source=item.source, source=item.source,
metadata=dict(item.metadata or {}), metadata=coerce_metadata_dict(item.metadata),
) )
def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]:
@@ -762,7 +763,7 @@ class DualPathRetriever:
existing = self._clone_retrieval_result(item) existing = self._clone_retrieval_result(item)
merged[item.hash_value] = existing merged[item.hash_value] = existing
else: else:
for key, value in dict(item.metadata or {}).items(): for key, value in coerce_metadata_dict(item.metadata).items():
if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): if key not in existing.metadata or existing.metadata.get(key) in (None, "", []):
existing.metadata[key] = value existing.metadata[key] = value
source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search")

View File

@@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService
from ..utils.episode_segmentation_service import EpisodeSegmentationService from ..utils.episode_segmentation_service import EpisodeSegmentationService
from ..utils.episode_service import EpisodeService from ..utils.episode_service import EpisodeService
from ..utils.hash import compute_hash, normalize_text from ..utils.hash import compute_hash, normalize_text
from ..utils.metadata import coerce_metadata_dict
from ..utils.person_profile_service import PersonProfileService from ..utils.person_profile_service import PersonProfileService
from ..utils.relation_write_service import RelationWriteService from ..utils.relation_write_service import RelationWriteService
from ..utils.retrieval_tuning_manager import RetrievalTuningManager from ..utils.retrieval_tuning_manager import RetrievalTuningManager
@@ -871,7 +872,7 @@ class SDKMemoryKernel:
"detail": "chat_filtered", "detail": "chat_filtered",
} }
summary_meta = dict(metadata or {}) summary_meta = coerce_metadata_dict(metadata)
summary_meta.setdefault("kind", "chat_summary") summary_meta.setdefault("kind", "chat_summary")
if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)):
result = await self.summarize_chat_stream( result = await self.summarize_chat_stream(
@@ -961,7 +962,7 @@ class SDKMemoryKernel:
participant_tokens = self._tokens(participants) participant_tokens = self._tokens(participants)
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
source = self._build_source(source_type, chat_id, person_tokens) source = self._build_source(source_type, chat_id, person_tokens)
paragraph_meta = dict(metadata or {}) paragraph_meta = coerce_metadata_dict(metadata)
paragraph_meta.update( paragraph_meta.update(
{ {
"external_id": external_token, "external_id": external_token,

View File

@@ -0,0 +1,11 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Dict
def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
    """Normalize an arbitrary metadata value into a plain dict.

    Returns a shallow ``dict`` copy when *value* is any ``Mapping``
    (so callers may mutate the result without touching the source);
    for anything else (``None``, lists, scalars, ...) returns a fresh
    empty dict.
    """
    return dict(value) if isinstance(value, Mapping) else {}

View File

@@ -27,6 +27,7 @@ from ..retrieval import (
GraphRelationRecallConfig, GraphRelationRecallConfig,
) )
from ..storage import MetadataStore, GraphStore, VectorStore from ..storage import MetadataStore, GraphStore, VectorStore
from .metadata import coerce_metadata_dict
logger = get_logger("A_Memorix.PersonProfileService") logger = get_logger("A_Memorix.PersonProfileService")
@@ -334,7 +335,7 @@ class PersonProfileService:
if not pid: if not pid:
return False return False
metadata = self._metadata_dict(relation.get("metadata")) metadata = coerce_metadata_dict(relation.get("metadata"))
if str(metadata.get("person_id", "") or "").strip() == pid: if str(metadata.get("person_id", "") or "").strip() == pid:
return True return True
if pid in self._list_tokens(metadata.get("person_ids")): if pid in self._list_tokens(metadata.get("person_ids")):
@@ -350,7 +351,7 @@ class PersonProfileService:
payload = { payload = {
"hash": source_paragraph, "hash": source_paragraph,
"source": str(paragraph.get("source", "") or ""), "source": str(paragraph.get("source", "") or ""),
"metadata": self._metadata_dict(paragraph.get("metadata")), "metadata": coerce_metadata_dict(paragraph.get("metadata")),
} }
return self._is_evidence_bound_to_person(payload, person_id=pid) return self._is_evidence_bound_to_person(payload, person_id=pid)
@@ -385,15 +386,11 @@ class PersonProfileService:
"score": 1.1, "score": 1.1,
"content": content[:220], "content": content[:220],
"source": str(row.get("source", "") or source), "source": str(row.get("source", "") or source),
"metadata": dict(row.get("metadata", {}) or {}), "metadata": coerce_metadata_dict(row.get("metadata")),
} }
) )
return self._filter_stale_paragraph_evidence(evidence) return self._filter_stale_paragraph_evidence(evidence)
@staticmethod
def _metadata_dict(value: Any) -> Dict[str, Any]:
return dict(value) if isinstance(value, dict) else {}
@staticmethod @staticmethod
def _list_tokens(value: Any) -> List[str]: def _list_tokens(value: Any) -> List[str]:
if value is None: if value is None:
@@ -414,7 +411,7 @@ class PersonProfileService:
if not pid: if not pid:
return False return False
metadata = self._metadata_dict(item.get("metadata")) metadata = coerce_metadata_dict(item.get("metadata"))
source = str(item.get("source", "") or metadata.get("source", "") or "").strip() source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
if source == f"person_fact:{pid}": if source == f"person_fact:{pid}":
return True return True
@@ -440,15 +437,15 @@ class PersonProfileService:
paragraph_hash: str, paragraph_hash: str,
metadata: Dict[str, Any], metadata: Dict[str, Any],
) -> Tuple[Dict[str, Any], str]: ) -> Tuple[Dict[str, Any], str]:
merged = self._metadata_dict(metadata) merged = coerce_metadata_dict(metadata)
source = str(merged.get("source", "") or "").strip() source = str(merged.get("source", "") or "").strip()
try: try:
paragraph = self.metadata_store.get_paragraph(paragraph_hash) paragraph = self.metadata_store.get_paragraph(paragraph_hash)
except Exception: except Exception:
paragraph = None paragraph = None
if isinstance(paragraph, dict): if isinstance(paragraph, dict):
paragraph_metadata = paragraph.get("metadata", {}) or {} paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata"))
if isinstance(paragraph_metadata, dict): if paragraph_metadata:
merged = {**paragraph_metadata, **merged} merged = {**paragraph_metadata, **merged}
source = source or str(paragraph.get("source", "") or "").strip() source = source or str(paragraph.get("source", "") or "").strip()
source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source) source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source)
@@ -538,7 +535,7 @@ class PersonProfileService:
"score": 0.0, "score": 0.0,
"content": str(para.get("content", ""))[:180], "content": str(para.get("content", ""))[:180],
"source": str(para.get("source", "") or ""), "source": str(para.get("source", "") or ""),
"metadata": self._metadata_dict(para.get("metadata")), "metadata": coerce_metadata_dict(para.get("metadata")),
} }
) )
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id): if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
@@ -562,18 +559,18 @@ class PersonProfileService:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}") logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue continue
for item in results: for item in results:
h = str(getattr(item, "hash_value", "") or "") h = str(item.hash_value or "")
if not h or h in seen_hash: if not h or h in seen_hash:
continue continue
metadata, source = self._enrich_paragraph_evidence_metadata( metadata, source = self._enrich_paragraph_evidence_metadata(
h, h,
self._metadata_dict(getattr(item, "metadata", {})), coerce_metadata_dict(item.metadata),
) )
payload = { payload = {
"hash": h, "hash": h,
"type": str(getattr(item, "result_type", "")), "type": str(item.result_type),
"score": float(getattr(item, "score", 0.0) or 0.0), "score": float(item.score or 0.0),
"content": str(getattr(item, "content", "") or "")[:220], "content": str(item.content or "")[:220],
"source": source, "source": source,
"metadata": metadata, "metadata": metadata,
} }

View File

@@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from ..retrieval import TemporalQueryOptions from ..retrieval import RetrievalResult, TemporalQueryOptions
from .metadata import coerce_metadata_dict
from .search_postprocess import ( from .search_postprocess import (
apply_safe_content_dedup, apply_safe_content_dedup,
maybe_apply_smart_path_fallback, maybe_apply_smart_path_fallback,
@@ -286,8 +287,8 @@ class SearchExecutionService:
) )
async def _executor() -> Dict[str, Any]: async def _executor() -> Dict[str, Any]:
original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) original_ppr = bool(retriever.config.enable_ppr)
setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) retriever.config.enable_ppr = bool(request.enable_ppr)
started_at = time.time() started_at = time.time()
try: try:
retrieved = await retriever.retrieve( retrieved = await retriever.retrieve(
@@ -321,7 +322,7 @@ class SearchExecutionService:
relation_hashes = [ relation_hashes = [
item.hash_value item.hash_value
for item in retrieved for item in retrieved
if getattr(item, "result_type", "") == "relation" if item.result_type == "relation"
] ]
if relation_hashes: if relation_hashes:
await plugin_instance.reinforce_access(relation_hashes) await plugin_instance.reinforce_access(relation_hashes)
@@ -380,7 +381,7 @@ class SearchExecutionService:
elapsed_ms = (time.time() - started_at) * 1000.0 elapsed_ms = (time.time() - started_at) * 1000.0
return {"results": retrieved, "elapsed_ms": elapsed_ms} return {"results": retrieved, "elapsed_ms": elapsed_ms}
finally: finally:
setattr(retriever.config, "enable_ppr", original_ppr) retriever.config.enable_ppr = original_ppr
dedup_hit = False dedup_hit = False
try: try:
@@ -421,18 +422,18 @@ class SearchExecutionService:
) )
@staticmethod @staticmethod
def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]:
serialized: List[Dict[str, Any]] = [] serialized: List[Dict[str, Any]] = []
for item in results: for item in results:
metadata = dict(getattr(item, "metadata", {}) or {}) metadata = coerce_metadata_dict(item.metadata)
if "time_meta" not in metadata: if "time_meta" not in metadata:
metadata["time_meta"] = {} metadata["time_meta"] = {}
serialized.append( serialized.append(
{ {
"hash": getattr(item, "hash_value", ""), "hash": item.hash_value,
"type": getattr(item, "result_type", ""), "type": item.result_type,
"score": float(getattr(item, "score", 0.0)), "score": float(item.score),
"content": getattr(item, "content", ""), "content": item.content,
"metadata": metadata, "metadata": metadata,
} }
) )