From 43db0aa9abf228c7901861f4ea133033c3079a9a Mon Sep 17 00:00:00 2001 From: anderwer Date: Mon, 16 Mar 2026 17:59:08 +0800 Subject: [PATCH 1/8] fix: validate gemini provider tests with query api key --- src/webui/routers/model.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/webui/routers/model.py b/src/webui/routers/model.py index 2f67aca5..f8e0fcc0 100644 --- a/src/webui/routers/model.py +++ b/src/webui/routers/model.py @@ -252,6 +252,7 @@ async def get_models_by_url( async def test_provider_connection( base_url: str = Query(..., description="提供商的基础 URL"), api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"), + client_type: str = Query("openai", description="客户端类型 (openai | gemini)"), ): """ 测试提供商连接状态 @@ -315,13 +316,19 @@ async def test_provider_connection( try: start_time = time.time() async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - } + headers = {"Content-Type": "application/json"} + params = {} + + if client_type == "gemini": + # Gemini 使用 URL 参数传递 API Key + params["key"] = api_key + else: + # OpenAI 兼容格式使用 Authorization 头 + headers["Authorization"] = f"Bearer {api_key}" + # 尝试获取模型列表 models_url = f"{base_url}/models" - response = await client.get(models_url, headers=headers) + response = await client.get(models_url, headers=headers, params=params) if response.status_code == 200: result["api_key_valid"] = True @@ -364,9 +371,14 @@ async def test_provider_connection_by_name( base_url = provider.get("base_url", "") api_key = provider.get("api_key", "") + client_type = provider.get("client_type", "openai") if not base_url: raise HTTPException(status_code=400, detail="提供商配置缺少 base_url") # 调用测试接口 - return await test_provider_connection(base_url=base_url, api_key=api_key or None) + return await test_provider_connection( + base_url=base_url, + api_key=api_key if api_key else None, + client_type=client_type, + ) From 78415e89c1076a9319c12c1f7d3be22390fa7a74 Mon Sep 17 00:00:00 2001 From: anderwer Date: Thu, 26 Mar 2026 08:58:49 +0800 Subject: [PATCH 2/8] test(webui): cover gemini provider connection auth --- pytests/webui/test_model_routes.py | 187 +++++++++++++++++++++++++++++ 1 file changed, 187 insertions(+) create mode 100644 pytests/webui/test_model_routes.py diff --git a/pytests/webui/test_model_routes.py b/pytests/webui/test_model_routes.py new file mode 100644 index 00000000..0e05ad87 --- /dev/null +++ b/pytests/webui/test_model_routes.py @@ -0,0 +1,187 @@ +"""模型路由测试 + +验证 Gemini 提供商连接测试会使用查询参数传递 API Key, +并且不会回退到 OpenAI 兼容接口使用的 Bearer 认证方式。 +""" + +import importlib +import sys +from types import ModuleType +from typing import Any + +import pytest + + +def load_model_routes(monkeypatch: pytest.MonkeyPatch): + """在导入路由前 stub 配置与认证依赖模块,避免测试时触发真实初始化。""" + config_module = ModuleType("src.config.config") + config_module.__dict__["CONFIG_DIR"] = "." 
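+    # 注意:stub 模块必须在重新导入路由前写入 sys.modules,
+    # 否则 import src.webui.routers.model 时会加载真实的 src.config.config。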
+ monkeypatch.setitem(sys.modules, "src.config.config", config_module) + + dependencies_module = ModuleType("src.webui.dependencies") + + async def require_auth(): + return "test-token" + + dependencies_module.__dict__["require_auth"] = require_auth + monkeypatch.setitem(sys.modules, "src.webui.dependencies", dependencies_module) + + sys.modules.pop("src.webui.routers.model", None) + return importlib.import_module("src.webui.routers.model") + + +class FakeResponse: + """简化版 HTTP 响应对象。""" + + def __init__(self, status_code: int): + self.status_code = status_code + + +def build_async_client_factory( + responses: list[FakeResponse], + calls: list[dict[str, Any]], +): + """构造一个可记录请求参数的 AsyncClient 替身。""" + + response_iter = iter(responses) + + class FakeAsyncClient: + def __init__(self, *args: Any, **kwargs: Any): + self.args = args + self.kwargs = kwargs + + async def __aenter__(self) -> "FakeAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> bool: + return False + + async def get( + self, + url: str, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> FakeResponse: + calls.append( + { + "url": url, + "headers": headers or {}, + "params": params or {}, + } + ) + return next(response_iter) + + return FakeAsyncClient + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_query_api_key_for_gemini( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Gemini 连接测试应通过查询参数传递 API Key。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://generativelanguage.googleapis.com/v1beta", + api_key="valid-gemini-key", + client_type="gemini", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + network_call = calls[0] + validation_call = calls[1] + + assert network_call["url"] == "https://generativelanguage.googleapis.com/v1beta" + assert network_call["headers"] == {} + assert network_call["params"] == {} + + assert validation_call["url"] == "https://generativelanguage.googleapis.com/v1beta/models" + assert validation_call["params"] == {"key": "valid-gemini-key"} + assert validation_call["headers"] == {"Content-Type": "application/json"} + assert "Authorization" not in validation_call["headers"] + + +@pytest.mark.asyncio +async def test_test_provider_connection_uses_bearer_auth_for_openai_compatible( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """非 Gemini 提供商连接测试应继续使用 Bearer 认证。""" + model_routes = load_model_routes(monkeypatch) + calls: list[dict[str, Any]] = [] + fake_client_class = build_async_client_factory( + responses=[FakeResponse(200), FakeResponse(200)], + calls=calls, + ) + monkeypatch.setattr(model_routes.httpx, "AsyncClient", fake_client_class) + + result = await model_routes.test_provider_connection( + base_url="https://example.com/v1", + api_key="valid-openai-key", + client_type="openai", + ) + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert len(calls) == 2 + + validation_call = calls[1] + + assert validation_call["url"] == "https://example.com/v1/models" + assert validation_call["params"] == {} + assert validation_call["headers"]["Content-Type"] == "application/json" + assert 
validation_call["headers"]["Authorization"] == "Bearer valid-openai-key" + + +@pytest.mark.asyncio +async def test_test_provider_connection_by_name_forwards_provider_client_type( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +) -> None: + """按提供商名称测试连接时,应透传配置中的 client_type。""" + model_routes = load_model_routes(monkeypatch) + config_path = tmp_path / "model_config.toml" + config_path.write_text( + """ +[[api_providers]] +name = "Gemini" +base_url = "https://generativelanguage.googleapis.com/v1beta" +api_key = "valid-gemini-key" +client_type = "gemini" +""".strip(), + encoding="utf-8", + ) + + monkeypatch.setattr(model_routes, "CONFIG_DIR", str(tmp_path)) + + captured_kwargs: dict[str, Any] = {} + + async def fake_test_provider_connection(**kwargs: Any) -> dict[str, Any]: + captured_kwargs.update(kwargs) + return { + "network_ok": True, + "api_key_valid": True, + "latency_ms": 12.34, + "error": None, + "http_status": 200, + } + + monkeypatch.setattr(model_routes, "test_provider_connection", fake_test_provider_connection) + + result = await model_routes.test_provider_connection_by_name(provider_name="Gemini") + + assert result["network_ok"] is True + assert result["api_key_valid"] is True + assert captured_kwargs == { + "base_url": "https://generativelanguage.googleapis.com/v1beta", + "api_key": "valid-gemini-key", + "client_type": "gemini", + } \ No newline at end of file From c78125e6d44e77ec6d1686d19743dd12097fdaf8 Mon Sep 17 00:00:00 2001 From: DawnARC Date: Fri, 8 May 2026 17:09:34 +0800 Subject: [PATCH 3/8] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8D=E6=97=A0?= =?UTF-8?q?=E8=BE=B9=E7=95=8C=E5=86=85=E5=AD=98=E5=A2=9E=E9=95=BF=EF=BC=8C?= =?UTF-8?q?=E5=AF=B9=E7=BC=93=E5=AD=98=E5=86=85=E5=AE=B9=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E8=A3=81=E5=88=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/heart_flow/heartflow_manager.py | 71 ++++++++++++++++++++++++ src/maisaka/reasoning_engine.py | 1 + src/maisaka/runtime.py | 57 ++++++++++++++++++- 3 files changed, 128 insertions(+), 1 deletion(-) diff --git a/src/chat/heart_flow/heartflow_manager.py b/src/chat/heart_flow/heartflow_manager.py index 3bbc6ec3..20ea48a2 100644 --- a/src/chat/heart_flow/heartflow_manager.py +++ b/src/chat/heart_flow/heartflow_manager.py @@ -1,4 +1,5 @@ import asyncio +import time import traceback from typing import Dict @@ -9,6 +10,10 @@ from src.maisaka.runtime import MaisakaHeartFlowChatting logger = get_logger("heartflow") +HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS = 6 * 60 * 60 +HEARTFLOW_RUNTIME_MAX_SESSIONS = 512 +HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS = 5 * 60 + class HeartflowManager: """管理 session 级别的 Maisaka 心流实例。""" @@ -16,25 +21,91 @@ class HeartflowManager: def __init__(self) -> None: self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {} self._chat_create_locks: Dict[str, asyncio.Lock] = {} + self._last_access_at: Dict[str, float] = {} + self._last_cleanup_at = 0.0 + + async def _stop_runtime(self, session_id: str, chat: MaisakaHeartFlowChatting) -> None: + try: + await chat.stop() + except Exception as exc: + logger.warning(f"清理心流聊天 {session_id} 时停止 runtime 失败: {exc}", exc_info=True) + + def _prune_runtime_caches(self, chat: MaisakaHeartFlowChatting) -> None: + prune_runtime_caches = getattr(chat, "prune_runtime_caches", None) + if callable(prune_runtime_caches): + prune_runtime_caches() + + async def cleanup_idle_chats(self, *, now: float | None = None, exclude_session_ids: set[str] | None = None) -> None: + """清理长期空闲或超过容量的 Maisaka runtime。""" + 
current_time = time.time() if now is None else now + excluded_session_ids = exclude_session_ids or set() + expire_before = current_time - HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS + session_ids_to_remove = [ + session_id + for session_id, accessed_at in self._last_access_at.items() + if accessed_at < expire_before and session_id not in excluded_session_ids + ] + + active_count_after_idle = len(self.heartflow_chat_list) - len(set(session_ids_to_remove)) + if active_count_after_idle > HEARTFLOW_RUNTIME_MAX_SESSIONS: + overflow_count = active_count_after_idle - HEARTFLOW_RUNTIME_MAX_SESSIONS + active_session_ids = [ + session_id + for session_id in self.heartflow_chat_list + if session_id not in session_ids_to_remove + and session_id not in excluded_session_ids + ] + active_session_ids.sort(key=lambda session_id: self._last_access_at.get(session_id, 0.0)) + session_ids_to_remove.extend(active_session_ids[:overflow_count]) + + self._last_cleanup_at = current_time + removed_count = 0 + for session_id in dict.fromkeys(session_ids_to_remove): + chat = self.heartflow_chat_list.pop(session_id, None) + self._chat_create_locks.pop(session_id, None) + self._last_access_at.pop(session_id, None) + if chat is None: + continue + await self._stop_runtime(session_id, chat) + removed_count += 1 + if removed_count > 0: + logger.info(f"已清理空闲心流聊天: 数量={removed_count} 剩余={len(self.heartflow_chat_list)}") + + async def _cleanup_idle_chats_if_due(self, *, now: float, exclude_session_id: str) -> None: + cleanup_due = now - self._last_cleanup_at >= HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS + capacity_exceeded = len(self.heartflow_chat_list) >= HEARTFLOW_RUNTIME_MAX_SESSIONS + if not cleanup_due and not capacity_exceeded: + return + await self.cleanup_idle_chats(now=now, exclude_session_ids={exclude_session_id}) async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting: """获取或创建指定会话对应的 Maisaka runtime。""" try: + current_time = time.time() if chat := self.heartflow_chat_list.get(session_id): + self._last_access_at[session_id] = current_time + self._prune_runtime_caches(chat) + await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) return chat create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock()) async with create_lock: + current_time = time.time() if chat := self.heartflow_chat_list.get(session_id): + self._last_access_at[session_id] = current_time + self._prune_runtime_caches(chat) + await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) return chat chat_session = chat_manager.get_session_by_session_id(session_id) if not chat_session: raise ValueError(f"未找到 session_id={session_id} 对应的聊天流") + await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) new_chat = MaisakaHeartFlowChatting(session_id=session_id) await new_chat.start() self.heartflow_chat_list[session_id] = new_chat + self._last_access_at[session_id] = current_time return new_chat except Exception as exc: logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index afe4b063..244ceb4d 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -941,6 +941,7 @@ class MaisakaReasoningEngine: """结束并记录一轮 Maisaka 思考循环。""" cycle_detail.end_time = time.time() self._runtime.history_loop.append(cycle_detail) + self._runtime.prune_runtime_caches() self._post_process_chat_history_after_cycle() timer_strings = [ diff 
--git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index bdb39e1f..84834370 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -56,6 +56,9 @@ from .tool_provider import MaisakaBuiltinToolProvider logger = get_logger("maisaka_runtime") MAX_INTERNAL_ROUNDS = 10 +MESSAGE_CACHE_MIN_RETAINED = 200 +MESSAGE_CACHE_CONTEXT_MULTIPLIER = 4 +HISTORY_LOOP_MAX_RETAINED = 256 class MaisakaHeartFlowChatting: @@ -82,7 +85,7 @@ class MaisakaHeartFlowChatting: self._chat_history: list[LLMContextMessage] = [] self.history_loop: list[CycleDetail] = [] - # Keep all original messages for batching and later learning. + # Keep recent original messages for batching, tools, and later learning. self.message_cache: list[SessionMessage] = [] self._last_processed_index = 0 self._internal_turn_queue: asyncio.Queue[Literal["message", "timeout"]] = asyncio.Queue() @@ -111,6 +114,10 @@ class MaisakaHeartFlowChatting: else global_config.chat.max_private_context_size ) self._max_context_size = max(1, int(configured_context_size)) + self._message_cache_max_size = max( + MESSAGE_CACHE_MIN_RETAINED, + self._max_context_size * MESSAGE_CACHE_CONTEXT_MULTIPLIER, + ) self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP self._pending_wait_tool_call_id: Optional[str] = None self._force_next_timing_continue = False @@ -940,6 +947,54 @@ class MaisakaHeartFlowChatting: def _has_pending_messages(self) -> bool: return self._last_processed_index < len(self.message_cache) + def _get_expression_learner_processed_index(self) -> int: + learner_index = getattr(self._expression_learner, "_last_processed_index", self._last_processed_index) + try: + return max(0, int(learner_index)) + except (TypeError, ValueError): + return self._last_processed_index + + def _adjust_expression_learner_processed_index(self, removed_count: int) -> None: + if not hasattr(self._expression_learner, "_last_processed_index"): + return + learner_index = self._get_expression_learner_processed_index() + setattr(self._expression_learner, "_last_processed_index", max(0, learner_index - removed_count)) + + def _prune_processed_message_cache(self) -> None: + """Trim old processed messages while preserving pending and learning windows.""" + max_size = max(1, int(getattr(self, "_message_cache_max_size", MESSAGE_CACHE_MIN_RETAINED))) + overflow_count = len(self.message_cache) - max_size + if overflow_count <= 0: + return + + processed_boundary = self._last_processed_index + removable_count = min(overflow_count, processed_boundary) + if removable_count <= 0: + return + + removed_messages = self.message_cache[:removable_count] + removed_message_ids = {message.message_id for message in removed_messages} + del self.message_cache[:removable_count] + self._last_processed_index = max(0, self._last_processed_index - removable_count) + self._adjust_expression_learner_processed_index(removable_count) + retained_message_ids = {message.message_id for message in self.message_cache} + + for message_id in removed_message_ids: + if message_id not in retained_message_ids: + self._message_received_at_by_id.pop(message_id, None) + self._source_messages_by_id.pop(message_id, None) + + logger.debug( + f"{self.log_prefix} 已裁剪 Maisaka 消息缓存: " + f"移除={removable_count} 剩余={len(self.message_cache)} 上限={max_size}" + ) + + def prune_runtime_caches(self) -> None: + """Apply bounded retention to runtime-only in-memory histories.""" + self._prune_processed_message_cache() + if len(self.history_loop) > HISTORY_LOOP_MAX_RETAINED: + del self.history_loop[: 
len(self.history_loop) - HISTORY_LOOP_MAX_RETAINED] + def _schedule_message_turn(self) -> None: """为当前待处理消息安排一次内部 turn。""" if self._agent_state == self._STATE_WAIT: From 81bf3fee1a45a38d45bb7d2863f8eda7411a70de Mon Sep 17 00:00:00 2001 From: DawnARC Date: Fri, 8 May 2026 21:32:16 +0800 Subject: [PATCH 4/8] =?UTF-8?q?revert:=E5=9B=9E=E9=80=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/reproduce_maisaka_memory_growth.py | 202 +++++++++++++++++++++ src/chat/heart_flow/heartflow_manager.py | 71 -------- src/maisaka/reasoning_engine.py | 1 - src/maisaka/runtime.py | 57 +----- 4 files changed, 203 insertions(+), 128 deletions(-) create mode 100644 scripts/reproduce_maisaka_memory_growth.py diff --git a/scripts/reproduce_maisaka_memory_growth.py b/scripts/reproduce_maisaka_memory_growth.py new file mode 100644 index 00000000..7ab4e413 --- /dev/null +++ b/scripts/reproduce_maisaka_memory_growth.py @@ -0,0 +1,202 @@ +""" +使用方法 +python .\scripts\reproduce_maisaka_memory_growth.py --messages 100 --batch-size 50 --sessions 100 --session-batch-size 50 --payload-size 1024 --session-payload-size 1024 + +""" + + +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import argparse +import asyncio +import gc +import inspect +import sys +import time + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +class PayloadMessage: + + __slots__ = ("message_id", "timestamp", "payload") + + def __init__(self, message_id: str, payload_size: int) -> None: + self.message_id = message_id + self.timestamp = SimpleNamespace(timestamp=lambda: time.time()) + self.payload = bytearray(payload_size) + + +@dataclass +class FakeRuntime: + payload: bytearray + stopped: bool = False + + async def stop(self) -> None: + self.stopped = True + + def prune_runtime_caches(self) -> None: + return None + + +def _bool_cn(value: bool) -> str: + return "是" if value else "否" + + +def _build_maisaka_runtime_stub(max_cache_size: int) -> Any: + from src.learners.expression_learner import ExpressionLearner + from src.maisaka.runtime import MaisakaHeartFlowChatting + + runtime = object.__new__(MaisakaHeartFlowChatting) + runtime._running = False + runtime._last_message_received_at = 0.0 + runtime._last_processed_index = 0 + runtime._message_cache_max_size = max_cache_size + runtime.message_cache = [] + runtime._message_received_at_by_id = {} + runtime._source_messages_by_id = {} + runtime.history_loop = [] + runtime.log_prefix = "[memory-repro]" + runtime._expression_learner = ExpressionLearner("memory-repro-session") + runtime._enable_expression_learning = False + runtime._enable_jargon_learning = False + runtime._agent_state = "idle" + runtime._STATE_RUNNING = "running" + runtime._reply_latency_measurement_started_at = None + runtime._message_debounce_required = False + runtime._update_message_trigger_state = lambda message: None + runtime._is_reply_effect_tracking_enabled = lambda: False + return runtime + + +def _mark_expression_learner_consumed(runtime: Any) -> None: + learner = runtime._expression_learner + if hasattr(learner, "set_processed_message_cache_index"): + learner.set_processed_message_cache_index(len(runtime.message_cache)) + return + if hasattr(learner, "_last_processed_index"): + learner._last_processed_index = len(runtime.message_cache) + + +async def _maybe_call_runtime_prune(runtime: Any) -> bool: + 
prune_runtime_caches = getattr(runtime, "prune_runtime_caches", None) + if not callable(prune_runtime_caches): + return False + + result = prune_runtime_caches() + if inspect.isawaitable(result): + await result + return True + + +async def probe_maisaka_message_cache(args: argparse.Namespace) -> bool: + from src.maisaka.runtime import MaisakaHeartFlowChatting + + runtime = _build_maisaka_runtime_stub(args.max_cache_size) + print("[Maisaka 消息缓存]") + print("批次,累计注册消息数,缓存消息数,原始消息映射数,已处理下标,MB") + + for index in range(args.messages): + message = PayloadMessage(f"m{index}", args.payload_size) + await MaisakaHeartFlowChatting.register_message(runtime, message) + if (index + 1) % args.batch_size != 0: + continue + + MaisakaHeartFlowChatting._collect_pending_messages(runtime) + if args.call_prune: + _mark_expression_learner_consumed(runtime) + await _maybe_call_runtime_prune(runtime) + gc.collect() + + retained_payload = sum(len(message.payload) for message in runtime.message_cache) + print( + f"{(index + 1) // args.batch_size}," + f"{index + 1}," + f"{len(runtime.message_cache)}," + f"{len(runtime._source_messages_by_id)}," + f"{runtime._last_processed_index}," + f"{retained_payload / 1024 / 1024:.2f}" + ) + + issue_observed = len(runtime.message_cache) > args.max_cache_size + print(f"是否观察到无界增长={_bool_cn(issue_observed)}") + return issue_observed + + +async def probe_heartflow_session_registry(args: argparse.Namespace) -> bool: + from src.chat.heart_flow.heartflow_manager import HeartflowManager + + manager = HeartflowManager() + print("\n[Heartflow 会话注册表]") + print("批次,累计会话数,注册表长度,锁数量,MB") + + for index in range(args.sessions): + session_id = f"session-{index}" + runtime = FakeRuntime(bytearray(args.session_payload_size)) + manager.heartflow_chat_list[session_id] = runtime + manager._chat_create_locks[session_id] = None + if hasattr(manager, "_last_access_at"): + manager._last_access_at[session_id] = 100.0 + + if (index + 1) % args.session_batch_size != 0: + continue + + retained_payload = sum(len(runtime.payload) for runtime in manager.heartflow_chat_list.values()) + print( + f"{(index + 1) // args.session_batch_size}," + f"{index + 1}," + f"{len(manager.heartflow_chat_list)}," + f"{len(manager._chat_create_locks)}," + f"{retained_payload / 1024 / 1024:.2f}" + ) + + if args.call_cleanup: + cleanup_idle_chats = getattr(manager, "cleanup_idle_chats", None) + if callable(cleanup_idle_chats): + cleanup_now = 100.0 + (6 * 60 * 60) + 1.0 + await cleanup_idle_chats(now=cleanup_now) + + issue_observed = len(manager.heartflow_chat_list) == args.sessions + print(f"剩余会话数={len(manager.heartflow_chat_list)}") + print(f"是否观察到会话未释放={_bool_cn(issue_observed)}") + return issue_observed + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="复现") + parser.add_argument("--messages", type=int, default=6000) + parser.add_argument("--batch-size", type=int, default=1000) + parser.add_argument("--payload-size", type=int, default=16 * 1024) + parser.add_argument("--max-cache-size", type=int, default=200) + parser.add_argument("--sessions", type=int, default=3000) + parser.add_argument("--session-batch-size", type=int, default=500) + parser.add_argument("--session-payload-size", type=int, default=32 * 1024) + parser.add_argument( + "--call-prune", + action="store_true", + help="每个消息批次结束后,如运行时提供裁剪hook则主动调用", + ) + parser.add_argument( + "--call-cleanup", + action="store_true", + help="填充会话注册表后,如 HeartflowManager 提供清理方法则主动调用", + ) + return parser.parse_args() + + +async def main() -> 
None: + args = parse_args() + message_issue = await probe_maisaka_message_cache(args) + session_issue = await probe_heartflow_session_registry(args) + print("\n[汇总]") + print(f"消息缓存问题是否复现={_bool_cn(message_issue)}") + print(f"会话注册表问题是否复现={_bool_cn(session_issue)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/chat/heart_flow/heartflow_manager.py b/src/chat/heart_flow/heartflow_manager.py index 20ea48a2..3bbc6ec3 100644 --- a/src/chat/heart_flow/heartflow_manager.py +++ b/src/chat/heart_flow/heartflow_manager.py @@ -1,5 +1,4 @@ import asyncio -import time import traceback from typing import Dict @@ -10,10 +9,6 @@ from src.maisaka.runtime import MaisakaHeartFlowChatting logger = get_logger("heartflow") -HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS = 6 * 60 * 60 -HEARTFLOW_RUNTIME_MAX_SESSIONS = 512 -HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS = 5 * 60 - class HeartflowManager: """管理 session 级别的 Maisaka 心流实例。""" @@ -21,91 +16,25 @@ class HeartflowManager: def __init__(self) -> None: self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {} self._chat_create_locks: Dict[str, asyncio.Lock] = {} - self._last_access_at: Dict[str, float] = {} - self._last_cleanup_at = 0.0 - - async def _stop_runtime(self, session_id: str, chat: MaisakaHeartFlowChatting) -> None: - try: - await chat.stop() - except Exception as exc: - logger.warning(f"清理心流聊天 {session_id} 时停止 runtime 失败: {exc}", exc_info=True) - - def _prune_runtime_caches(self, chat: MaisakaHeartFlowChatting) -> None: - prune_runtime_caches = getattr(chat, "prune_runtime_caches", None) - if callable(prune_runtime_caches): - prune_runtime_caches() - - async def cleanup_idle_chats(self, *, now: float | None = None, exclude_session_ids: set[str] | None = None) -> None: - """清理长期空闲或超过容量的 Maisaka runtime。""" - current_time = time.time() if now is None else now - excluded_session_ids = exclude_session_ids or set() - expire_before = current_time - HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS - session_ids_to_remove = [ - session_id - for session_id, accessed_at in self._last_access_at.items() - if accessed_at < expire_before and session_id not in excluded_session_ids - ] - - active_count_after_idle = len(self.heartflow_chat_list) - len(set(session_ids_to_remove)) - if active_count_after_idle > HEARTFLOW_RUNTIME_MAX_SESSIONS: - overflow_count = active_count_after_idle - HEARTFLOW_RUNTIME_MAX_SESSIONS - active_session_ids = [ - session_id - for session_id in self.heartflow_chat_list - if session_id not in session_ids_to_remove - and session_id not in excluded_session_ids - ] - active_session_ids.sort(key=lambda session_id: self._last_access_at.get(session_id, 0.0)) - session_ids_to_remove.extend(active_session_ids[:overflow_count]) - - self._last_cleanup_at = current_time - removed_count = 0 - for session_id in dict.fromkeys(session_ids_to_remove): - chat = self.heartflow_chat_list.pop(session_id, None) - self._chat_create_locks.pop(session_id, None) - self._last_access_at.pop(session_id, None) - if chat is None: - continue - await self._stop_runtime(session_id, chat) - removed_count += 1 - if removed_count > 0: - logger.info(f"已清理空闲心流聊天: 数量={removed_count} 剩余={len(self.heartflow_chat_list)}") - - async def _cleanup_idle_chats_if_due(self, *, now: float, exclude_session_id: str) -> None: - cleanup_due = now - self._last_cleanup_at >= HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS - capacity_exceeded = len(self.heartflow_chat_list) >= HEARTFLOW_RUNTIME_MAX_SESSIONS - if not cleanup_due and not capacity_exceeded: - return - await 
self.cleanup_idle_chats(now=now, exclude_session_ids={exclude_session_id}) async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting: """获取或创建指定会话对应的 Maisaka runtime。""" try: - current_time = time.time() if chat := self.heartflow_chat_list.get(session_id): - self._last_access_at[session_id] = current_time - self._prune_runtime_caches(chat) - await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) return chat create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock()) async with create_lock: - current_time = time.time() if chat := self.heartflow_chat_list.get(session_id): - self._last_access_at[session_id] = current_time - self._prune_runtime_caches(chat) - await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) return chat chat_session = chat_manager.get_session_by_session_id(session_id) if not chat_session: raise ValueError(f"未找到 session_id={session_id} 对应的聊天流") - await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) new_chat = MaisakaHeartFlowChatting(session_id=session_id) await new_chat.start() self.heartflow_chat_list[session_id] = new_chat - self._last_access_at[session_id] = current_time return new_chat except Exception as exc: logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 244ceb4d..afe4b063 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -941,7 +941,6 @@ class MaisakaReasoningEngine: """结束并记录一轮 Maisaka 思考循环。""" cycle_detail.end_time = time.time() self._runtime.history_loop.append(cycle_detail) - self._runtime.prune_runtime_caches() self._post_process_chat_history_after_cycle() timer_strings = [ diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 84834370..bdb39e1f 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -56,9 +56,6 @@ from .tool_provider import MaisakaBuiltinToolProvider logger = get_logger("maisaka_runtime") MAX_INTERNAL_ROUNDS = 10 -MESSAGE_CACHE_MIN_RETAINED = 200 -MESSAGE_CACHE_CONTEXT_MULTIPLIER = 4 -HISTORY_LOOP_MAX_RETAINED = 256 class MaisakaHeartFlowChatting: @@ -85,7 +82,7 @@ class MaisakaHeartFlowChatting: self._chat_history: list[LLMContextMessage] = [] self.history_loop: list[CycleDetail] = [] - # Keep recent original messages for batching, tools, and later learning. + # Keep all original messages for batching and later learning. 
self.message_cache: list[SessionMessage] = [] self._last_processed_index = 0 self._internal_turn_queue: asyncio.Queue[Literal["message", "timeout"]] = asyncio.Queue() @@ -114,10 +111,6 @@ class MaisakaHeartFlowChatting: else global_config.chat.max_private_context_size ) self._max_context_size = max(1, int(configured_context_size)) - self._message_cache_max_size = max( - MESSAGE_CACHE_MIN_RETAINED, - self._max_context_size * MESSAGE_CACHE_CONTEXT_MULTIPLIER, - ) self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP self._pending_wait_tool_call_id: Optional[str] = None self._force_next_timing_continue = False @@ -947,54 +940,6 @@ class MaisakaHeartFlowChatting: def _has_pending_messages(self) -> bool: return self._last_processed_index < len(self.message_cache) - def _get_expression_learner_processed_index(self) -> int: - learner_index = getattr(self._expression_learner, "_last_processed_index", self._last_processed_index) - try: - return max(0, int(learner_index)) - except (TypeError, ValueError): - return self._last_processed_index - - def _adjust_expression_learner_processed_index(self, removed_count: int) -> None: - if not hasattr(self._expression_learner, "_last_processed_index"): - return - learner_index = self._get_expression_learner_processed_index() - setattr(self._expression_learner, "_last_processed_index", max(0, learner_index - removed_count)) - - def _prune_processed_message_cache(self) -> None: - """Trim old processed messages while preserving pending and learning windows.""" - max_size = max(1, int(getattr(self, "_message_cache_max_size", MESSAGE_CACHE_MIN_RETAINED))) - overflow_count = len(self.message_cache) - max_size - if overflow_count <= 0: - return - - processed_boundary = self._last_processed_index - removable_count = min(overflow_count, processed_boundary) - if removable_count <= 0: - return - - removed_messages = self.message_cache[:removable_count] - removed_message_ids = {message.message_id for message in removed_messages} - del self.message_cache[:removable_count] - self._last_processed_index = max(0, self._last_processed_index - removable_count) - self._adjust_expression_learner_processed_index(removable_count) - retained_message_ids = {message.message_id for message in self.message_cache} - - for message_id in removed_message_ids: - if message_id not in retained_message_ids: - self._message_received_at_by_id.pop(message_id, None) - self._source_messages_by_id.pop(message_id, None) - - logger.debug( - f"{self.log_prefix} 已裁剪 Maisaka 消息缓存: " - f"移除={removable_count} 剩余={len(self.message_cache)} 上限={max_size}" - ) - - def prune_runtime_caches(self) -> None: - """Apply bounded retention to runtime-only in-memory histories.""" - self._prune_processed_message_cache() - if len(self.history_loop) > HISTORY_LOOP_MAX_RETAINED: - del self.history_loop[: len(self.history_loop) - HISTORY_LOOP_MAX_RETAINED] - def _schedule_message_turn(self) -> None: """为当前待处理消息安排一次内部 turn。""" if self._agent_state == self._STATE_WAIT: From 531e1428de4b7a6c5b6ee7f60506521b6f3a9b05 Mon Sep 17 00:00:00 2001 From: DawnARC Date: Sat, 9 May 2026 00:33:48 +0800 Subject: [PATCH 5/8] =?UTF-8?q?revert:=E5=9B=9E=E9=80=80=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/reproduce_maisaka_memory_growth.py | 202 --------------------- 1 file changed, 202 deletions(-) delete mode 100644 scripts/reproduce_maisaka_memory_growth.py diff --git a/scripts/reproduce_maisaka_memory_growth.py 
b/scripts/reproduce_maisaka_memory_growth.py deleted file mode 100644 index 7ab4e413..00000000 --- a/scripts/reproduce_maisaka_memory_growth.py +++ /dev/null @@ -1,202 +0,0 @@ -""" -使用方法 -python .\scripts\reproduce_maisaka_memory_growth.py --messages 100 --batch-size 50 --sessions 100 --session-batch-size 50 --payload-size 1024 --session-payload-size 1024 - -""" - - -from dataclasses import dataclass -from pathlib import Path -from types import SimpleNamespace -from typing import Any - -import argparse -import asyncio -import gc -import inspect -import sys -import time - -REPO_ROOT = Path(__file__).resolve().parents[1] -if str(REPO_ROOT) not in sys.path: - sys.path.insert(0, str(REPO_ROOT)) - - -class PayloadMessage: - - __slots__ = ("message_id", "timestamp", "payload") - - def __init__(self, message_id: str, payload_size: int) -> None: - self.message_id = message_id - self.timestamp = SimpleNamespace(timestamp=lambda: time.time()) - self.payload = bytearray(payload_size) - - -@dataclass -class FakeRuntime: - payload: bytearray - stopped: bool = False - - async def stop(self) -> None: - self.stopped = True - - def prune_runtime_caches(self) -> None: - return None - - -def _bool_cn(value: bool) -> str: - return "是" if value else "否" - - -def _build_maisaka_runtime_stub(max_cache_size: int) -> Any: - from src.learners.expression_learner import ExpressionLearner - from src.maisaka.runtime import MaisakaHeartFlowChatting - - runtime = object.__new__(MaisakaHeartFlowChatting) - runtime._running = False - runtime._last_message_received_at = 0.0 - runtime._last_processed_index = 0 - runtime._message_cache_max_size = max_cache_size - runtime.message_cache = [] - runtime._message_received_at_by_id = {} - runtime._source_messages_by_id = {} - runtime.history_loop = [] - runtime.log_prefix = "[memory-repro]" - runtime._expression_learner = ExpressionLearner("memory-repro-session") - runtime._enable_expression_learning = False - runtime._enable_jargon_learning = False - runtime._agent_state = "idle" - runtime._STATE_RUNNING = "running" - runtime._reply_latency_measurement_started_at = None - runtime._message_debounce_required = False - runtime._update_message_trigger_state = lambda message: None - runtime._is_reply_effect_tracking_enabled = lambda: False - return runtime - - -def _mark_expression_learner_consumed(runtime: Any) -> None: - learner = runtime._expression_learner - if hasattr(learner, "set_processed_message_cache_index"): - learner.set_processed_message_cache_index(len(runtime.message_cache)) - return - if hasattr(learner, "_last_processed_index"): - learner._last_processed_index = len(runtime.message_cache) - - -async def _maybe_call_runtime_prune(runtime: Any) -> bool: - prune_runtime_caches = getattr(runtime, "prune_runtime_caches", None) - if not callable(prune_runtime_caches): - return False - - result = prune_runtime_caches() - if inspect.isawaitable(result): - await result - return True - - -async def probe_maisaka_message_cache(args: argparse.Namespace) -> bool: - from src.maisaka.runtime import MaisakaHeartFlowChatting - - runtime = _build_maisaka_runtime_stub(args.max_cache_size) - print("[Maisaka 消息缓存]") - print("批次,累计注册消息数,缓存消息数,原始消息映射数,已处理下标,MB") - - for index in range(args.messages): - message = PayloadMessage(f"m{index}", args.payload_size) - await MaisakaHeartFlowChatting.register_message(runtime, message) - if (index + 1) % args.batch_size != 0: - continue - - MaisakaHeartFlowChatting._collect_pending_messages(runtime) - if args.call_prune: - 
_mark_expression_learner_consumed(runtime) - await _maybe_call_runtime_prune(runtime) - gc.collect() - - retained_payload = sum(len(message.payload) for message in runtime.message_cache) - print( - f"{(index + 1) // args.batch_size}," - f"{index + 1}," - f"{len(runtime.message_cache)}," - f"{len(runtime._source_messages_by_id)}," - f"{runtime._last_processed_index}," - f"{retained_payload / 1024 / 1024:.2f}" - ) - - issue_observed = len(runtime.message_cache) > args.max_cache_size - print(f"是否观察到无界增长={_bool_cn(issue_observed)}") - return issue_observed - - -async def probe_heartflow_session_registry(args: argparse.Namespace) -> bool: - from src.chat.heart_flow.heartflow_manager import HeartflowManager - - manager = HeartflowManager() - print("\n[Heartflow 会话注册表]") - print("批次,累计会话数,注册表长度,锁数量,MB") - - for index in range(args.sessions): - session_id = f"session-{index}" - runtime = FakeRuntime(bytearray(args.session_payload_size)) - manager.heartflow_chat_list[session_id] = runtime - manager._chat_create_locks[session_id] = None - if hasattr(manager, "_last_access_at"): - manager._last_access_at[session_id] = 100.0 - - if (index + 1) % args.session_batch_size != 0: - continue - - retained_payload = sum(len(runtime.payload) for runtime in manager.heartflow_chat_list.values()) - print( - f"{(index + 1) // args.session_batch_size}," - f"{index + 1}," - f"{len(manager.heartflow_chat_list)}," - f"{len(manager._chat_create_locks)}," - f"{retained_payload / 1024 / 1024:.2f}" - ) - - if args.call_cleanup: - cleanup_idle_chats = getattr(manager, "cleanup_idle_chats", None) - if callable(cleanup_idle_chats): - cleanup_now = 100.0 + (6 * 60 * 60) + 1.0 - await cleanup_idle_chats(now=cleanup_now) - - issue_observed = len(manager.heartflow_chat_list) == args.sessions - print(f"剩余会话数={len(manager.heartflow_chat_list)}") - print(f"是否观察到会话未释放={_bool_cn(issue_observed)}") - return issue_observed - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="复现") - parser.add_argument("--messages", type=int, default=6000) - parser.add_argument("--batch-size", type=int, default=1000) - parser.add_argument("--payload-size", type=int, default=16 * 1024) - parser.add_argument("--max-cache-size", type=int, default=200) - parser.add_argument("--sessions", type=int, default=3000) - parser.add_argument("--session-batch-size", type=int, default=500) - parser.add_argument("--session-payload-size", type=int, default=32 * 1024) - parser.add_argument( - "--call-prune", - action="store_true", - help="每个消息批次结束后,如运行时提供裁剪hook则主动调用", - ) - parser.add_argument( - "--call-cleanup", - action="store_true", - help="填充会话注册表后,如 HeartflowManager 提供清理方法则主动调用", - ) - return parser.parse_args() - - -async def main() -> None: - args = parse_args() - message_issue = await probe_maisaka_message_cache(args) - session_issue = await probe_heartflow_session_registry(args) - print("\n[汇总]") - print(f"消息缓存问题是否复现={_bool_cn(message_issue)}") - print(f"会话注册表问题是否复现={_bool_cn(session_issue)}") - - -if __name__ == "__main__": - asyncio.run(main()) From 5fc6551f57bef3a661a4268fd7dd2b537b6f87ec Mon Sep 17 00:00:00 2001 From: DawnARC Date: Sat, 9 May 2026 00:40:56 +0800 Subject: [PATCH 6/8] =?UTF-8?q?fix(A=5Fmemorix):=E4=BD=BF=E7=94=A8=20coerc?= =?UTF-8?q?e=5Fmetadata=5Fdict=20=E5=A4=84=E7=90=86=20metadata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入 coerce_metadata_dict 工具函数,统一化字典的转义:对于 Mapping 类型输入返回字典,否则返回空字典,并替换掉代码中的 dict(...) 
转换。 更新了 dual_path.py、sdk_memory_kernel.py、person_profile_service.py 和 search_execution_service.py 中的调用点和导入,以规范化metadata,避免 metadata 为 None 或非字典类型时出现错误。 --- src/A_memorix/core/retrieval/dual_path.py | 5 +-- .../core/runtime/sdk_memory_kernel.py | 5 +-- src/A_memorix/core/utils/metadata.py | 11 +++++++ .../core/utils/person_profile_service.py | 31 +++++++++---------- .../core/utils/search_execution_service.py | 23 +++++++------- 5 files changed, 43 insertions(+), 32 deletions(-) create mode 100644 src/A_memorix/core/utils/metadata.py diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index 245bafea..996c02f8 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -16,6 +16,7 @@ from src.common.logger import get_logger from ..storage import VectorStore, GraphStore, MetadataStore from ..embedding import EmbeddingAPIAdapter from ..utils.matcher import AhoCorasick +from ..utils.metadata import coerce_metadata_dict from ..utils.time_parser import format_timestamp from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService from .pagerank import PersonalizedPageRank, PageRankConfig @@ -482,7 +483,7 @@ class DualPathRetriever: score=float(item.score), result_type=item.result_type, source=item.source, - metadata=dict(item.metadata or {}), + metadata=coerce_metadata_dict(item.metadata), ) def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: @@ -762,7 +763,7 @@ class DualPathRetriever: existing = self._clone_retrieval_result(item) merged[item.hash_value] = existing else: - for key, value in dict(item.metadata or {}).items(): + for key, value in coerce_metadata_dict(item.metadata).items(): if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): existing.metadata[key] = value source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index dfbbbd77..26ff503a 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -25,6 +25,7 @@ from ..utils.episode_retrieval_service import EpisodeRetrievalService from ..utils.episode_segmentation_service import EpisodeSegmentationService from ..utils.episode_service import EpisodeService from ..utils.hash import compute_hash, normalize_text +from ..utils.metadata import coerce_metadata_dict from ..utils.person_profile_service import PersonProfileService from ..utils.relation_write_service import RelationWriteService from ..utils.retrieval_tuning_manager import RetrievalTuningManager @@ -871,7 +872,7 @@ class SDKMemoryKernel: "detail": "chat_filtered", } - summary_meta = dict(metadata or {}) + summary_meta = coerce_metadata_dict(metadata) summary_meta.setdefault("kind", "chat_summary") if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): result = await self.summarize_chat_stream( @@ -961,7 +962,7 @@ class SDKMemoryKernel: participant_tokens = self._tokens(participants) entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) source = self._build_source(source_type, chat_id, person_tokens) - paragraph_meta = dict(metadata or {}) + paragraph_meta = coerce_metadata_dict(metadata) paragraph_meta.update( { "external_id": external_token, diff --git a/src/A_memorix/core/utils/metadata.py b/src/A_memorix/core/utils/metadata.py new 
file mode 100644 index 00000000..5a1dafc1 --- /dev/null +++ b/src/A_memorix/core/utils/metadata.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Dict + + +def coerce_metadata_dict(value: Any) -> Dict[str, Any]: + """返回字典,如果输入值不是字典则返回空字典。""" + if isinstance(value, Mapping): + return dict(value) + return {} diff --git a/src/A_memorix/core/utils/person_profile_service.py b/src/A_memorix/core/utils/person_profile_service.py index 081eaa66..eec531e0 100644 --- a/src/A_memorix/core/utils/person_profile_service.py +++ b/src/A_memorix/core/utils/person_profile_service.py @@ -27,6 +27,7 @@ from ..retrieval import ( GraphRelationRecallConfig, ) from ..storage import MetadataStore, GraphStore, VectorStore +from .metadata import coerce_metadata_dict logger = get_logger("A_Memorix.PersonProfileService") @@ -334,7 +335,7 @@ class PersonProfileService: if not pid: return False - metadata = self._metadata_dict(relation.get("metadata")) + metadata = coerce_metadata_dict(relation.get("metadata")) if str(metadata.get("person_id", "") or "").strip() == pid: return True if pid in self._list_tokens(metadata.get("person_ids")): @@ -350,7 +351,7 @@ class PersonProfileService: payload = { "hash": source_paragraph, "source": str(paragraph.get("source", "") or ""), - "metadata": self._metadata_dict(paragraph.get("metadata")), + "metadata": coerce_metadata_dict(paragraph.get("metadata")), } return self._is_evidence_bound_to_person(payload, person_id=pid) @@ -385,15 +386,11 @@ class PersonProfileService: "score": 1.1, "content": content[:220], "source": str(row.get("source", "") or source), - "metadata": dict(row.get("metadata", {}) or {}), + "metadata": coerce_metadata_dict(row.get("metadata")), } ) return self._filter_stale_paragraph_evidence(evidence) - @staticmethod - def _metadata_dict(value: Any) -> Dict[str, Any]: - return dict(value) if isinstance(value, dict) else {} - @staticmethod def _list_tokens(value: Any) -> List[str]: if value is None: @@ -414,7 +411,7 @@ class PersonProfileService: if not pid: return False - metadata = self._metadata_dict(item.get("metadata")) + metadata = coerce_metadata_dict(item.get("metadata")) source = str(item.get("source", "") or metadata.get("source", "") or "").strip() if source == f"person_fact:{pid}": return True @@ -440,15 +437,15 @@ class PersonProfileService: paragraph_hash: str, metadata: Dict[str, Any], ) -> Tuple[Dict[str, Any], str]: - merged = self._metadata_dict(metadata) + merged = coerce_metadata_dict(metadata) source = str(merged.get("source", "") or "").strip() try: paragraph = self.metadata_store.get_paragraph(paragraph_hash) except Exception: paragraph = None if isinstance(paragraph, dict): - paragraph_metadata = paragraph.get("metadata", {}) or {} - if isinstance(paragraph_metadata, dict): + paragraph_metadata = coerce_metadata_dict(paragraph.get("metadata")) + if paragraph_metadata: merged = {**paragraph_metadata, **merged} source = source or str(paragraph.get("source", "") or "").strip() source_type = str(merged.get("source_type", "") or "").strip() or self._source_type_from_source(source) @@ -538,7 +535,7 @@ class PersonProfileService: "score": 0.0, "content": str(para.get("content", ""))[:180], "source": str(para.get("source", "") or ""), - "metadata": self._metadata_dict(para.get("metadata")), + "metadata": coerce_metadata_dict(para.get("metadata")), } ) if not self._is_evidence_bound_to_person(fallback[-1], person_id=person_id): @@ -562,18 +559,18 @@ class PersonProfileService: 
logger.warning(f"向量证据召回失败: alias={alias}, err={e}") continue for item in results: - h = str(getattr(item, "hash_value", "") or "") + h = str(item.hash_value or "") if not h or h in seen_hash: continue metadata, source = self._enrich_paragraph_evidence_metadata( h, - self._metadata_dict(getattr(item, "metadata", {})), + coerce_metadata_dict(item.metadata), ) payload = { "hash": h, - "type": str(getattr(item, "result_type", "")), - "score": float(getattr(item, "score", 0.0) or 0.0), - "content": str(getattr(item, "content", "") or "")[:220], + "type": str(item.result_type), + "score": float(item.score or 0.0), + "content": str(item.content or "")[:220], "source": source, "metadata": metadata, } diff --git a/src/A_memorix/core/utils/search_execution_service.py b/src/A_memorix/core/utils/search_execution_service.py index ace051e9..f6f6b145 100644 --- a/src/A_memorix/core/utils/search_execution_service.py +++ b/src/A_memorix/core/utils/search_execution_service.py @@ -14,7 +14,8 @@ from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger -from ..retrieval import TemporalQueryOptions +from ..retrieval import RetrievalResult, TemporalQueryOptions +from .metadata import coerce_metadata_dict from .search_postprocess import ( apply_safe_content_dedup, maybe_apply_smart_path_fallback, @@ -286,8 +287,8 @@ class SearchExecutionService: ) async def _executor() -> Dict[str, Any]: - original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) - setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) + original_ppr = bool(retriever.config.enable_ppr) + retriever.config.enable_ppr = bool(request.enable_ppr) started_at = time.time() try: retrieved = await retriever.retrieve( @@ -321,7 +322,7 @@ class SearchExecutionService: relation_hashes = [ item.hash_value for item in retrieved - if getattr(item, "result_type", "") == "relation" + if item.result_type == "relation" ] if relation_hashes: await plugin_instance.reinforce_access(relation_hashes) @@ -380,7 +381,7 @@ class SearchExecutionService: elapsed_ms = (time.time() - started_at) * 1000.0 return {"results": retrieved, "elapsed_ms": elapsed_ms} finally: - setattr(retriever.config, "enable_ppr", original_ppr) + retriever.config.enable_ppr = original_ppr dedup_hit = False try: @@ -421,18 +422,18 @@ class SearchExecutionService: ) @staticmethod - def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: + def to_serializable_results(results: List[RetrievalResult]) -> List[Dict[str, Any]]: serialized: List[Dict[str, Any]] = [] for item in results: - metadata = dict(getattr(item, "metadata", {}) or {}) + metadata = coerce_metadata_dict(item.metadata) if "time_meta" not in metadata: metadata["time_meta"] = {} serialized.append( { - "hash": getattr(item, "hash_value", ""), - "type": getattr(item, "result_type", ""), - "score": float(getattr(item, "score", 0.0)), - "content": getattr(item, "content", ""), + "hash": item.hash_value, + "type": item.result_type, + "score": float(item.score), + "content": item.content, "metadata": metadata, } ) From 5be22458301fa478f9f7aa538f2bfc316a9f7027 Mon Sep 17 00:00:00 2001 From: DawnARC Date: Sat, 9 May 2026 01:04:04 +0800 Subject: [PATCH 7/8] =?UTF-8?q?fix=EF=BC=9A=E7=BA=A0=E6=AD=A3=E8=AF=AD?= =?UTF-8?q?=E4=B9=89=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/A_memorix/core/utils/search_execution_service.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git 
a/src/A_memorix/core/utils/search_execution_service.py b/src/A_memorix/core/utils/search_execution_service.py index f6f6b145..651ff417 100644 --- a/src/A_memorix/core/utils/search_execution_service.py +++ b/src/A_memorix/core/utils/search_execution_service.py @@ -287,8 +287,11 @@ class SearchExecutionService: ) async def _executor() -> Dict[str, Any]: - original_ppr = bool(retriever.config.enable_ppr) - retriever.config.enable_ppr = bool(request.enable_ppr) + retriever_config = getattr(retriever, "config", None) + has_runtime_ppr_switch = retriever_config is not None and hasattr(retriever_config, "enable_ppr") + original_ppr = bool(retriever_config.enable_ppr) if has_runtime_ppr_switch else None + if has_runtime_ppr_switch: + retriever_config.enable_ppr = bool(request.enable_ppr) started_at = time.time() try: retrieved = await retriever.retrieve( @@ -381,7 +384,8 @@ class SearchExecutionService: elapsed_ms = (time.time() - started_at) * 1000.0 return {"results": retrieved, "elapsed_ms": elapsed_ms} finally: - retriever.config.enable_ppr = original_ppr + if has_runtime_ppr_switch: + retriever_config.enable_ppr = bool(original_ppr) dedup_hit = False try: From 74e686f5c92cc846edff754590fa5355ec151fd5 Mon Sep 17 00:00:00 2001 From: DawnARC Date: Sat, 9 May 2026 01:14:21 +0800 Subject: [PATCH 8/8] =?UTF-8?q?fix:=E7=BA=A0=E6=AD=A3=E8=AF=AD=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/A_memorix/core/utils/search_execution_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/A_memorix/core/utils/search_execution_service.py b/src/A_memorix/core/utils/search_execution_service.py index 651ff417..f05c4820 100644 --- a/src/A_memorix/core/utils/search_execution_service.py +++ b/src/A_memorix/core/utils/search_execution_service.py @@ -325,7 +325,7 @@ class SearchExecutionService: relation_hashes = [ item.hash_value for item in retrieved - if item.result_type == "relation" + if getattr(item, "result_type", "") == "relation" ] if relation_hashes: await plugin_instance.reinforce_access(relation_hashes)
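
A minimal, self-contained sketch of the coerce_metadata_dict behavior introduced
in PATCH 6/8; the sample inputs are illustrative only and are not taken from the
codebase:

    from collections.abc import Mapping
    from typing import Any, Dict

    def coerce_metadata_dict(value: Any) -> Dict[str, Any]:
        """Return a dict copy for Mapping inputs, else an empty dict."""
        if isinstance(value, Mapping):
            return dict(value)
        return {}

    # Mapping inputs (including dict subclasses) are copied into a plain dict.
    assert coerce_metadata_dict({"kind": "chat_summary"}) == {"kind": "chat_summary"}
    # None and other non-mapping values normalize to {} instead of raising,
    # unlike dict(value or {}), which raises for truthy non-mappings such as strings.
    assert coerce_metadata_dict(None) == {}
    assert coerce_metadata_dict("not-a-dict") == {}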