Merge pull request #1657 from A-Dawn/dev

fix:优化人物画像来源限制,并改进模型选择器
This commit is contained in:
Dawn ARC
2026-05-08 15:08:06 +08:00
committed by GitHub
3 changed files with 175 additions and 48 deletions

View File

@@ -2608,7 +2608,9 @@ class MetadataStore:
Returns:
段落列表
"""
return self.query("SELECT * FROM paragraphs WHERE source = ?", (source,))
cursor = self._conn.cursor()
cursor.execute("SELECT * FROM paragraphs WHERE source = ?", (source,))
return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()]
def get_all_sources(self) -> List[Dict[str, Any]]:
"""

View File

@@ -283,7 +283,13 @@ class PersonProfileService:
logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
return aliases, primary_name, memory_traits
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
def _collect_relation_evidence(
self,
aliases: List[str],
limit: int = 30,
*,
person_id: str = "",
) -> List[Dict[str, Any]]:
relation_by_hash: Dict[str, Dict[str, Any]] = {}
for alias in aliases:
for rel in self.metadata_store.get_relations(subject=alias, include_inactive=False):
@@ -296,6 +302,12 @@ class PersonProfileService:
relation_by_hash[h] = rel
relations = list(relation_by_hash.values())
if person_id:
relations = [
rel
for rel in relations
if self._is_relation_bound_to_person(rel, person_id=person_id)
]
relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True)
relations = relations[: max(1, int(limit))]
@@ -312,6 +324,38 @@ class PersonProfileService:
)
return edges
def _is_relation_bound_to_person(
self,
relation: Dict[str, Any],
*,
person_id: str,
) -> bool:
pid = str(person_id or "").strip()
if not pid:
return False
metadata = self._metadata_dict(relation.get("metadata"))
if str(metadata.get("person_id", "") or "").strip() == pid:
return True
if pid in self._list_tokens(metadata.get("person_ids")):
return True
source_paragraph = str(relation.get("source_paragraph", "") or "").strip()
if source_paragraph:
try:
paragraph = self.metadata_store.get_paragraph(source_paragraph)
except Exception:
paragraph = None
if isinstance(paragraph, dict):
payload = {
"hash": source_paragraph,
"source": str(paragraph.get("source", "") or ""),
"metadata": self._metadata_dict(paragraph.get("metadata")),
}
return self._is_evidence_bound_to_person(payload, person_id=pid)
return False
def _collect_person_fact_evidence(self, person_id: str, limit: int = 4) -> List[Dict[str, Any]]:
token = str(person_id or "").strip()
if not token:
@@ -346,6 +390,42 @@ class PersonProfileService:
)
return self._filter_stale_paragraph_evidence(evidence)
@staticmethod
def _metadata_dict(value: Any) -> Dict[str, Any]:
return dict(value) if isinstance(value, dict) else {}
@staticmethod
def _list_tokens(value: Any) -> List[str]:
if value is None:
return []
if isinstance(value, (list, tuple, set)):
return [str(item or "").strip() for item in value if str(item or "").strip()]
token = str(value or "").strip()
return [token] if token else []
def _is_evidence_bound_to_person(
self,
item: Dict[str, Any],
*,
person_id: str,
) -> bool:
"""画像证据必须显式绑定到 person_id避免别名全局召回串人。"""
pid = str(person_id or "").strip()
if not pid:
return False
metadata = self._metadata_dict(item.get("metadata"))
source = str(item.get("source", "") or metadata.get("source", "") or "").strip()
if source == f"person_fact:{pid}":
return True
if str(metadata.get("person_id", "") or "").strip() == pid:
return True
if pid in self._list_tokens(metadata.get("person_ids")):
return True
return False
@staticmethod
def _source_type_from_source(source: str) -> str:
token = str(source or "").strip()
@@ -360,7 +440,7 @@ class PersonProfileService:
paragraph_hash: str,
metadata: Dict[str, Any],
) -> Tuple[Dict[str, Any], str]:
merged = dict(metadata or {})
merged = self._metadata_dict(metadata)
source = str(merged.get("source", "") or "").strip()
try:
paragraph = self.metadata_store.get_paragraph(paragraph_hash)
@@ -458,9 +538,11 @@ class PersonProfileService:
"score": 0.0,
"content": str(para.get("content", ""))[:180],
"source": str(para.get("source", "") or ""),
"metadata": dict(para.get("metadata", {}) or {}),
"metadata": self._metadata_dict(para.get("metadata")),
}
)
if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id):
fallback.pop()
return self._filter_stale_paragraph_evidence(fallback[:top_k])
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
@@ -483,21 +565,22 @@ class PersonProfileService:
h = str(getattr(item, "hash_value", "") or "")
if not h or h in seen_hash:
continue
seen_hash.add(h)
metadata, source = self._enrich_paragraph_evidence_metadata(
h,
dict(getattr(item, "metadata", {}) or {}),
)
evidence.append(
{
"hash": h,
"type": str(getattr(item, "result_type", "")),
"score": float(getattr(item, "score", 0.0) or 0.0),
"content": str(getattr(item, "content", "") or "")[:220],
"source": source,
"metadata": metadata,
}
self._metadata_dict(getattr(item, "metadata", {})),
)
payload = {
"hash": h,
"type": str(getattr(item, "result_type", "")),
"score": float(getattr(item, "score", 0.0) or 0.0),
"content": str(getattr(item, "content", "") or "")[:220],
"source": source,
"metadata": metadata,
}
if not self._is_evidence_bound_to_person(payload, person_id=person_id):
continue
seen_hash.add(h)
evidence.append(payload)
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
return self._filter_stale_paragraph_evidence(evidence[:top_k])
@@ -640,7 +723,7 @@ class PersonProfileService:
if not aliases and person_keyword:
aliases = [person_keyword.strip()]
primary_name = person_keyword.strip()
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2), person_id=pid)
vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k), person_id=pid)
evidence_ids = [

View File

@@ -16,7 +16,7 @@ import traceback
from src.common.logger import get_logger
from src.services import llm_service as llm_api
from src.services import message_service as message_api
from src.config.config import global_config, model_config as host_model_config
from src.config.config import config_manager, global_config
from src.config.model_configs import TaskConfig
from ..storage import (
@@ -150,36 +150,57 @@ class SummaryImporter:
return True
def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]:
"""标准化 summarization.model_name 配置vNext 仅接受字符串数组)"""
"""标准化 summarization.model_name 配置。"""
if raw_value is None:
return ["auto"]
if isinstance(raw_value, list):
selectors = [str(x).strip() for x in raw_value if str(x).strip()]
return selectors or ["auto"]
if isinstance(raw_value, str):
selector = raw_value.strip()
if selector:
logger.warning("summarization.model_name 建议使用 List[str],当前字符串配置已兼容处理。")
return [selector]
return ["auto"]
raise ValueError(
"summarization.model_name 在 vNext 必须为 List[str]。"
"summarization.model_name 必须为 List[str] 或 str"
" 请执行 scripts/release_vnext_migrate.py migrate。"
)
def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]:
"""
选择总结默认任务,避免错误落到 embedding 任务。
优先级:memory > utils > planner;不再顺延到 replyer 或其他任务
优先级:replyer > utils > planner > tool_use > 其他非 embedding
"""
preferred = ("memory", "utils", "planner")
preferred = ("replyer", "utils", "planner", "tool_use")
for name in preferred:
cfg = available_tasks.get(name)
if cfg and cfg.model_list:
return name, cfg
for name, cfg in available_tasks.items():
if name != "embedding" and cfg.model_list:
return name, cfg
for name, cfg in available_tasks.items():
if cfg.model_list:
return name, cfg
return None, None
def _resolve_summary_model_config(self) -> Optional[TaskConfig]:
@staticmethod
def _current_model_dict() -> Dict[str, Any]:
    """Best-effort snapshot of the host's ``models_dict``; {} when unavailable."""
    try:
        host_cfg = config_manager.get_model_config()
        return getattr(host_cfg, "models_dict", {}) or {}
    except Exception as exc:
        # Configuration lookup is best-effort; degrade to an empty mapping.
        logger.warning(f"读取当前模型字典失败: {exc}")
        return {}
def _resolve_summary_model_config(self) -> Optional[Tuple[str, TaskConfig]]:
"""
解析 summarization.model_name 为 TaskConfig。
解析 summarization.model_name 为 (task_name, TaskConfig)
支持:
- "auto"
- "memory"(任务名)
- "replyer"(任务名)
- "some-model-name"(具体模型名)
- ["utils:model1", "utils:model2", "replyer"](数组混合语法)
@@ -192,16 +213,18 @@ class SummaryImporter:
# 避免默认值本身触发类型校验异常。
raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", ["auto"])
selectors = self._normalize_summary_model_selectors(raw_cfg)
_default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks)
default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks)
selected_models: List[str] = []
base_cfg: Optional[TaskConfig] = None
model_dict = getattr(host_model_config, "models_dict", {})
base_task_name: Optional[str] = None
model_dict = self._current_model_dict()
def _append_models(models: List[str]):
for model_name in models:
if model_name and model_name not in selected_models:
selected_models.append(model_name)
def _find_task_for_model(model_name: str) -> Tuple[Optional[str], Optional[TaskConfig]]:
for task_name, task_cfg in available_tasks.items():
task_models = [str(item).strip() for item in (getattr(task_cfg, "model_list", []) or []) if str(item).strip()]
if model_name in task_models:
return task_name, task_cfg
return None, None
for raw_selector in selectors:
selector = raw_selector.strip()
@@ -210,9 +233,9 @@ class SummaryImporter:
if selector.lower() == "auto":
if default_task_cfg:
_append_models(default_task_cfg.model_list)
if base_cfg is None:
base_cfg = default_task_cfg
base_task_name = default_task_name
continue
if ":" in selector:
@@ -226,42 +249,60 @@ class SummaryImporter:
if base_cfg is None:
base_cfg = task_cfg
base_task_name = task_name
if not model_name or model_name.lower() == "auto":
_append_models(task_cfg.model_list)
continue
if model_name in model_dict or model_name in task_cfg.model_list:
_append_models([model_name])
if model_name in task_cfg.model_list:
logger.info(
f"总结模型选择器 '{selector}' 已定位到任务 '{task_name}'"
"当前 LLM 服务按任务候选列表执行,不单独覆盖具体模型。"
)
else:
logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}'在,已跳过")
logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不在任务 '{task_name}',已跳过")
continue
task_cfg = available_tasks.get(selector)
if task_cfg:
_append_models(task_cfg.model_list)
if base_cfg is None:
base_cfg = task_cfg
base_task_name = selector
continue
if selector in model_dict:
_append_models([selector])
task_name, task_cfg = _find_task_for_model(selector)
if task_name and task_cfg:
if base_cfg is None:
base_cfg = task_cfg
base_task_name = task_name
logger.info(
f"总结模型选择器 '{selector}' 已映射到任务 '{task_name}'"
"当前 LLM 服务按任务候选列表执行,不单独覆盖具体模型。"
)
continue
logger.warning(f"总结模型选择器 '{selector}' 未归属于任何任务,已跳过")
continue
logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过")
if not selected_models:
if base_cfg is None or not base_task_name:
if default_task_cfg:
_append_models(default_task_cfg.model_list)
if base_cfg is None:
base_cfg = default_task_cfg
base_task_name = default_task_name
else:
base_task_name, first_cfg = next(iter(available_tasks.items()))
if base_cfg is None:
base_cfg = first_cfg
if not selected_models:
if base_cfg is None or not base_task_name:
return None
template_cfg = base_cfg or default_task_cfg or TaskConfig()
return TaskConfig(
model_list=selected_models,
template_cfg = base_cfg
task_name_to_use = base_task_name
return task_name_to_use, TaskConfig(
model_list=list(template_cfg.model_list),
max_tokens=template_cfg.max_tokens,
temperature=template_cfg.temperature,
slow_threshold=template_cfg.slow_threshold,
@@ -331,12 +372,13 @@ class SummaryImporter:
chat_history=chat_history_text
)
model_config_to_use = self._resolve_summary_model_config()
if model_config_to_use is None:
resolved_model = self._resolve_summary_model_config()
if resolved_model is None:
return False, "未找到可用的总结模型配置"
task_name_to_use = llm_api.resolve_task_name_from_model_config(model_config_to_use)
task_name_to_use, model_config_to_use = resolved_model
logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}")
logger.info(f"总结模型任务: {task_name_to_use}")
logger.info(f"总结模型候选列表: {model_config_to_use.model_list}")
result = await llm_api.generate(