diff --git a/scripts/reproduce_maisaka_memory_growth.py b/scripts/reproduce_maisaka_memory_growth.py new file mode 100644 index 00000000..7ab4e413 --- /dev/null +++ b/scripts/reproduce_maisaka_memory_growth.py @@ -0,0 +1,202 @@ +""" +使用方法 +python .\scripts\reproduce_maisaka_memory_growth.py --messages 100 --batch-size 50 --sessions 100 --session-batch-size 50 --payload-size 1024 --session-payload-size 1024 + +""" + + +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import argparse +import asyncio +import gc +import inspect +import sys +import time + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +class PayloadMessage: + + __slots__ = ("message_id", "timestamp", "payload") + + def __init__(self, message_id: str, payload_size: int) -> None: + self.message_id = message_id + self.timestamp = SimpleNamespace(timestamp=lambda: time.time()) + self.payload = bytearray(payload_size) + + +@dataclass +class FakeRuntime: + payload: bytearray + stopped: bool = False + + async def stop(self) -> None: + self.stopped = True + + def prune_runtime_caches(self) -> None: + return None + + +def _bool_cn(value: bool) -> str: + return "是" if value else "否" + + +def _build_maisaka_runtime_stub(max_cache_size: int) -> Any: + from src.learners.expression_learner import ExpressionLearner + from src.maisaka.runtime import MaisakaHeartFlowChatting + + runtime = object.__new__(MaisakaHeartFlowChatting) + runtime._running = False + runtime._last_message_received_at = 0.0 + runtime._last_processed_index = 0 + runtime._message_cache_max_size = max_cache_size + runtime.message_cache = [] + runtime._message_received_at_by_id = {} + runtime._source_messages_by_id = {} + runtime.history_loop = [] + runtime.log_prefix = "[memory-repro]" + runtime._expression_learner = ExpressionLearner("memory-repro-session") + runtime._enable_expression_learning = False + runtime._enable_jargon_learning = False + runtime._agent_state = "idle" + runtime._STATE_RUNNING = "running" + runtime._reply_latency_measurement_started_at = None + runtime._message_debounce_required = False + runtime._update_message_trigger_state = lambda message: None + runtime._is_reply_effect_tracking_enabled = lambda: False + return runtime + + +def _mark_expression_learner_consumed(runtime: Any) -> None: + learner = runtime._expression_learner + if hasattr(learner, "set_processed_message_cache_index"): + learner.set_processed_message_cache_index(len(runtime.message_cache)) + return + if hasattr(learner, "_last_processed_index"): + learner._last_processed_index = len(runtime.message_cache) + + +async def _maybe_call_runtime_prune(runtime: Any) -> bool: + prune_runtime_caches = getattr(runtime, "prune_runtime_caches", None) + if not callable(prune_runtime_caches): + return False + + result = prune_runtime_caches() + if inspect.isawaitable(result): + await result + return True + + +async def probe_maisaka_message_cache(args: argparse.Namespace) -> bool: + from src.maisaka.runtime import MaisakaHeartFlowChatting + + runtime = _build_maisaka_runtime_stub(args.max_cache_size) + print("[Maisaka 消息缓存]") + print("批次,累计注册消息数,缓存消息数,原始消息映射数,已处理下标,MB") + + for index in range(args.messages): + message = PayloadMessage(f"m{index}", args.payload_size) + await MaisakaHeartFlowChatting.register_message(runtime, message) + if (index + 1) % args.batch_size != 0: + continue + + MaisakaHeartFlowChatting._collect_pending_messages(runtime) + if args.call_prune: + _mark_expression_learner_consumed(runtime) + await _maybe_call_runtime_prune(runtime) + gc.collect() + + retained_payload = sum(len(message.payload) for message in runtime.message_cache) + print( + f"{(index + 1) // args.batch_size}," + f"{index + 1}," + f"{len(runtime.message_cache)}," + f"{len(runtime._source_messages_by_id)}," + f"{runtime._last_processed_index}," + f"{retained_payload / 1024 / 1024:.2f}" + ) + + issue_observed = len(runtime.message_cache) > args.max_cache_size + print(f"是否观察到无界增长={_bool_cn(issue_observed)}") + return issue_observed + + +async def probe_heartflow_session_registry(args: argparse.Namespace) -> bool: + from src.chat.heart_flow.heartflow_manager import HeartflowManager + + manager = HeartflowManager() + print("\n[Heartflow 会话注册表]") + print("批次,累计会话数,注册表长度,锁数量,MB") + + for index in range(args.sessions): + session_id = f"session-{index}" + runtime = FakeRuntime(bytearray(args.session_payload_size)) + manager.heartflow_chat_list[session_id] = runtime + manager._chat_create_locks[session_id] = None + if hasattr(manager, "_last_access_at"): + manager._last_access_at[session_id] = 100.0 + + if (index + 1) % args.session_batch_size != 0: + continue + + retained_payload = sum(len(runtime.payload) for runtime in manager.heartflow_chat_list.values()) + print( + f"{(index + 1) // args.session_batch_size}," + f"{index + 1}," + f"{len(manager.heartflow_chat_list)}," + f"{len(manager._chat_create_locks)}," + f"{retained_payload / 1024 / 1024:.2f}" + ) + + if args.call_cleanup: + cleanup_idle_chats = getattr(manager, "cleanup_idle_chats", None) + if callable(cleanup_idle_chats): + cleanup_now = 100.0 + (6 * 60 * 60) + 1.0 + await cleanup_idle_chats(now=cleanup_now) + + issue_observed = len(manager.heartflow_chat_list) == args.sessions + print(f"剩余会话数={len(manager.heartflow_chat_list)}") + print(f"是否观察到会话未释放={_bool_cn(issue_observed)}") + return issue_observed + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="复现") + parser.add_argument("--messages", type=int, default=6000) + parser.add_argument("--batch-size", type=int, default=1000) + parser.add_argument("--payload-size", type=int, default=16 * 1024) + parser.add_argument("--max-cache-size", type=int, default=200) + parser.add_argument("--sessions", type=int, default=3000) + parser.add_argument("--session-batch-size", type=int, default=500) + parser.add_argument("--session-payload-size", type=int, default=32 * 1024) + parser.add_argument( + "--call-prune", + action="store_true", + help="每个消息批次结束后,如运行时提供裁剪hook则主动调用", + ) + parser.add_argument( + "--call-cleanup", + action="store_true", + help="填充会话注册表后,如 HeartflowManager 提供清理方法则主动调用", + ) + return parser.parse_args() + + +async def main() -> None: + args = parse_args() + message_issue = await probe_maisaka_message_cache(args) + session_issue = await probe_heartflow_session_registry(args) + print("\n[汇总]") + print(f"消息缓存问题是否复现={_bool_cn(message_issue)}") + print(f"会话注册表问题是否复现={_bool_cn(session_issue)}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/chat/heart_flow/heartflow_manager.py b/src/chat/heart_flow/heartflow_manager.py index 20ea48a2..3bbc6ec3 100644 --- a/src/chat/heart_flow/heartflow_manager.py +++ b/src/chat/heart_flow/heartflow_manager.py @@ -1,5 +1,4 @@ import asyncio -import time import traceback from typing import Dict @@ -10,10 +9,6 @@ from src.maisaka.runtime import MaisakaHeartFlowChatting logger = get_logger("heartflow") -HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS = 6 * 60 * 60 -HEARTFLOW_RUNTIME_MAX_SESSIONS = 512 -HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS = 5 * 60 - class HeartflowManager: """管理 session 级别的 Maisaka 心流实例。""" @@ -21,91 +16,25 @@ class HeartflowManager: def __init__(self) -> None: self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {} self._chat_create_locks: Dict[str, asyncio.Lock] = {} - self._last_access_at: Dict[str, float] = {} - self._last_cleanup_at = 0.0 - - async def _stop_runtime(self, session_id: str, chat: MaisakaHeartFlowChatting) -> None: - try: - await chat.stop() - except Exception as exc: - logger.warning(f"清理心流聊天 {session_id} 时停止 runtime 失败: {exc}", exc_info=True) - - def _prune_runtime_caches(self, chat: MaisakaHeartFlowChatting) -> None: - prune_runtime_caches = getattr(chat, "prune_runtime_caches", None) - if callable(prune_runtime_caches): - prune_runtime_caches() - - async def cleanup_idle_chats(self, *, now: float | None = None, exclude_session_ids: set[str] | None = None) -> None: - """清理长期空闲或超过容量的 Maisaka runtime。""" - current_time = time.time() if now is None else now - excluded_session_ids = exclude_session_ids or set() - expire_before = current_time - HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS - session_ids_to_remove = [ - session_id - for session_id, accessed_at in self._last_access_at.items() - if accessed_at < expire_before and session_id not in excluded_session_ids - ] - - active_count_after_idle = len(self.heartflow_chat_list) - len(set(session_ids_to_remove)) - if active_count_after_idle > HEARTFLOW_RUNTIME_MAX_SESSIONS: - overflow_count = active_count_after_idle - HEARTFLOW_RUNTIME_MAX_SESSIONS - active_session_ids = [ - session_id - for session_id in self.heartflow_chat_list - if session_id not in session_ids_to_remove - and session_id not in excluded_session_ids - ] - active_session_ids.sort(key=lambda session_id: self._last_access_at.get(session_id, 0.0)) - session_ids_to_remove.extend(active_session_ids[:overflow_count]) - - self._last_cleanup_at = current_time - removed_count = 0 - for session_id in dict.fromkeys(session_ids_to_remove): - chat = self.heartflow_chat_list.pop(session_id, None) - self._chat_create_locks.pop(session_id, None) - self._last_access_at.pop(session_id, None) - if chat is None: - continue - await self._stop_runtime(session_id, chat) - removed_count += 1 - if removed_count > 0: - logger.info(f"已清理空闲心流聊天: 数量={removed_count} 剩余={len(self.heartflow_chat_list)}") - - async def _cleanup_idle_chats_if_due(self, *, now: float, exclude_session_id: str) -> None: - cleanup_due = now - self._last_cleanup_at >= HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS - capacity_exceeded = len(self.heartflow_chat_list) >= HEARTFLOW_RUNTIME_MAX_SESSIONS - if not cleanup_due and not capacity_exceeded: - return - await self.cleanup_idle_chats(now=now, exclude_session_ids={exclude_session_id}) async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting: """获取或创建指定会话对应的 Maisaka runtime。""" try: - current_time = time.time() if chat := self.heartflow_chat_list.get(session_id): - self._last_access_at[session_id] = current_time - self._prune_runtime_caches(chat) - await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) return chat create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock()) async with create_lock: - current_time = time.time() if chat := self.heartflow_chat_list.get(session_id): - self._last_access_at[session_id] = current_time - self._prune_runtime_caches(chat) - await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) return chat chat_session = chat_manager.get_session_by_session_id(session_id) if not chat_session: raise ValueError(f"未找到 session_id={session_id} 对应的聊天流") - await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id) new_chat = MaisakaHeartFlowChatting(session_id=session_id) await new_chat.start() self.heartflow_chat_list[session_id] = new_chat - self._last_access_at[session_id] = current_time return new_chat except Exception as exc: logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 244ceb4d..afe4b063 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -941,7 +941,6 @@ class MaisakaReasoningEngine: """结束并记录一轮 Maisaka 思考循环。""" cycle_detail.end_time = time.time() self._runtime.history_loop.append(cycle_detail) - self._runtime.prune_runtime_caches() self._post_process_chat_history_after_cycle() timer_strings = [ diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index 84834370..bdb39e1f 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -56,9 +56,6 @@ from .tool_provider import MaisakaBuiltinToolProvider logger = get_logger("maisaka_runtime") MAX_INTERNAL_ROUNDS = 10 -MESSAGE_CACHE_MIN_RETAINED = 200 -MESSAGE_CACHE_CONTEXT_MULTIPLIER = 4 -HISTORY_LOOP_MAX_RETAINED = 256 class MaisakaHeartFlowChatting: @@ -85,7 +82,7 @@ class MaisakaHeartFlowChatting: self._chat_history: list[LLMContextMessage] = [] self.history_loop: list[CycleDetail] = [] - # Keep recent original messages for batching, tools, and later learning. + # Keep all original messages for batching and later learning. self.message_cache: list[SessionMessage] = [] self._last_processed_index = 0 self._internal_turn_queue: asyncio.Queue[Literal["message", "timeout"]] = asyncio.Queue() @@ -114,10 +111,6 @@ class MaisakaHeartFlowChatting: else global_config.chat.max_private_context_size ) self._max_context_size = max(1, int(configured_context_size)) - self._message_cache_max_size = max( - MESSAGE_CACHE_MIN_RETAINED, - self._max_context_size * MESSAGE_CACHE_CONTEXT_MULTIPLIER, - ) self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP self._pending_wait_tool_call_id: Optional[str] = None self._force_next_timing_continue = False @@ -947,54 +940,6 @@ class MaisakaHeartFlowChatting: def _has_pending_messages(self) -> bool: return self._last_processed_index < len(self.message_cache) - def _get_expression_learner_processed_index(self) -> int: - learner_index = getattr(self._expression_learner, "_last_processed_index", self._last_processed_index) - try: - return max(0, int(learner_index)) - except (TypeError, ValueError): - return self._last_processed_index - - def _adjust_expression_learner_processed_index(self, removed_count: int) -> None: - if not hasattr(self._expression_learner, "_last_processed_index"): - return - learner_index = self._get_expression_learner_processed_index() - setattr(self._expression_learner, "_last_processed_index", max(0, learner_index - removed_count)) - - def _prune_processed_message_cache(self) -> None: - """Trim old processed messages while preserving pending and learning windows.""" - max_size = max(1, int(getattr(self, "_message_cache_max_size", MESSAGE_CACHE_MIN_RETAINED))) - overflow_count = len(self.message_cache) - max_size - if overflow_count <= 0: - return - - processed_boundary = self._last_processed_index - removable_count = min(overflow_count, processed_boundary) - if removable_count <= 0: - return - - removed_messages = self.message_cache[:removable_count] - removed_message_ids = {message.message_id for message in removed_messages} - del self.message_cache[:removable_count] - self._last_processed_index = max(0, self._last_processed_index - removable_count) - self._adjust_expression_learner_processed_index(removable_count) - retained_message_ids = {message.message_id for message in self.message_cache} - - for message_id in removed_message_ids: - if message_id not in retained_message_ids: - self._message_received_at_by_id.pop(message_id, None) - self._source_messages_by_id.pop(message_id, None) - - logger.debug( - f"{self.log_prefix} 已裁剪 Maisaka 消息缓存: " - f"移除={removable_count} 剩余={len(self.message_cache)} 上限={max_size}" - ) - - def prune_runtime_caches(self) -> None: - """Apply bounded retention to runtime-only in-memory histories.""" - self._prune_processed_message_cache() - if len(self.history_loop) > HISTORY_LOOP_MAX_RETAINED: - del self.history_loop[: len(self.history_loop) - HISTORY_LOOP_MAX_RETAINED] - def _schedule_message_turn(self) -> None: """为当前待处理消息安排一次内部 turn。""" if self._agent_state == self._STATE_WAIT: