revert:回退修改
This commit is contained in:
202
scripts/reproduce_maisaka_memory_growth.py
Normal file
202
scripts/reproduce_maisaka_memory_growth.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
使用方法
|
||||
python .\scripts\reproduce_maisaka_memory_growth.py --messages 100 --batch-size 50 --sessions 100 --session-batch-size 50 --payload-size 1024 --session-payload-size 1024
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import sys
|
||||
import time
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
|
||||
class PayloadMessage:
|
||||
|
||||
__slots__ = ("message_id", "timestamp", "payload")
|
||||
|
||||
def __init__(self, message_id: str, payload_size: int) -> None:
|
||||
self.message_id = message_id
|
||||
self.timestamp = SimpleNamespace(timestamp=lambda: time.time())
|
||||
self.payload = bytearray(payload_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FakeRuntime:
|
||||
payload: bytearray
|
||||
stopped: bool = False
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.stopped = True
|
||||
|
||||
def prune_runtime_caches(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _bool_cn(value: bool) -> str:
|
||||
return "是" if value else "否"
|
||||
|
||||
|
||||
def _build_maisaka_runtime_stub(max_cache_size: int) -> Any:
|
||||
from src.learners.expression_learner import ExpressionLearner
|
||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime._running = False
|
||||
runtime._last_message_received_at = 0.0
|
||||
runtime._last_processed_index = 0
|
||||
runtime._message_cache_max_size = max_cache_size
|
||||
runtime.message_cache = []
|
||||
runtime._message_received_at_by_id = {}
|
||||
runtime._source_messages_by_id = {}
|
||||
runtime.history_loop = []
|
||||
runtime.log_prefix = "[memory-repro]"
|
||||
runtime._expression_learner = ExpressionLearner("memory-repro-session")
|
||||
runtime._enable_expression_learning = False
|
||||
runtime._enable_jargon_learning = False
|
||||
runtime._agent_state = "idle"
|
||||
runtime._STATE_RUNNING = "running"
|
||||
runtime._reply_latency_measurement_started_at = None
|
||||
runtime._message_debounce_required = False
|
||||
runtime._update_message_trigger_state = lambda message: None
|
||||
runtime._is_reply_effect_tracking_enabled = lambda: False
|
||||
return runtime
|
||||
|
||||
|
||||
def _mark_expression_learner_consumed(runtime: Any) -> None:
|
||||
learner = runtime._expression_learner
|
||||
if hasattr(learner, "set_processed_message_cache_index"):
|
||||
learner.set_processed_message_cache_index(len(runtime.message_cache))
|
||||
return
|
||||
if hasattr(learner, "_last_processed_index"):
|
||||
learner._last_processed_index = len(runtime.message_cache)
|
||||
|
||||
|
||||
async def _maybe_call_runtime_prune(runtime: Any) -> bool:
|
||||
prune_runtime_caches = getattr(runtime, "prune_runtime_caches", None)
|
||||
if not callable(prune_runtime_caches):
|
||||
return False
|
||||
|
||||
result = prune_runtime_caches()
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
return True
|
||||
|
||||
|
||||
async def probe_maisaka_message_cache(args: argparse.Namespace) -> bool:
|
||||
from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
runtime = _build_maisaka_runtime_stub(args.max_cache_size)
|
||||
print("[Maisaka 消息缓存]")
|
||||
print("批次,累计注册消息数,缓存消息数,原始消息映射数,已处理下标,MB")
|
||||
|
||||
for index in range(args.messages):
|
||||
message = PayloadMessage(f"m{index}", args.payload_size)
|
||||
await MaisakaHeartFlowChatting.register_message(runtime, message)
|
||||
if (index + 1) % args.batch_size != 0:
|
||||
continue
|
||||
|
||||
MaisakaHeartFlowChatting._collect_pending_messages(runtime)
|
||||
if args.call_prune:
|
||||
_mark_expression_learner_consumed(runtime)
|
||||
await _maybe_call_runtime_prune(runtime)
|
||||
gc.collect()
|
||||
|
||||
retained_payload = sum(len(message.payload) for message in runtime.message_cache)
|
||||
print(
|
||||
f"{(index + 1) // args.batch_size},"
|
||||
f"{index + 1},"
|
||||
f"{len(runtime.message_cache)},"
|
||||
f"{len(runtime._source_messages_by_id)},"
|
||||
f"{runtime._last_processed_index},"
|
||||
f"{retained_payload / 1024 / 1024:.2f}"
|
||||
)
|
||||
|
||||
issue_observed = len(runtime.message_cache) > args.max_cache_size
|
||||
print(f"是否观察到无界增长={_bool_cn(issue_observed)}")
|
||||
return issue_observed
|
||||
|
||||
|
||||
async def probe_heartflow_session_registry(args: argparse.Namespace) -> bool:
|
||||
from src.chat.heart_flow.heartflow_manager import HeartflowManager
|
||||
|
||||
manager = HeartflowManager()
|
||||
print("\n[Heartflow 会话注册表]")
|
||||
print("批次,累计会话数,注册表长度,锁数量,MB")
|
||||
|
||||
for index in range(args.sessions):
|
||||
session_id = f"session-{index}"
|
||||
runtime = FakeRuntime(bytearray(args.session_payload_size))
|
||||
manager.heartflow_chat_list[session_id] = runtime
|
||||
manager._chat_create_locks[session_id] = None
|
||||
if hasattr(manager, "_last_access_at"):
|
||||
manager._last_access_at[session_id] = 100.0
|
||||
|
||||
if (index + 1) % args.session_batch_size != 0:
|
||||
continue
|
||||
|
||||
retained_payload = sum(len(runtime.payload) for runtime in manager.heartflow_chat_list.values())
|
||||
print(
|
||||
f"{(index + 1) // args.session_batch_size},"
|
||||
f"{index + 1},"
|
||||
f"{len(manager.heartflow_chat_list)},"
|
||||
f"{len(manager._chat_create_locks)},"
|
||||
f"{retained_payload / 1024 / 1024:.2f}"
|
||||
)
|
||||
|
||||
if args.call_cleanup:
|
||||
cleanup_idle_chats = getattr(manager, "cleanup_idle_chats", None)
|
||||
if callable(cleanup_idle_chats):
|
||||
cleanup_now = 100.0 + (6 * 60 * 60) + 1.0
|
||||
await cleanup_idle_chats(now=cleanup_now)
|
||||
|
||||
issue_observed = len(manager.heartflow_chat_list) == args.sessions
|
||||
print(f"剩余会话数={len(manager.heartflow_chat_list)}")
|
||||
print(f"是否观察到会话未释放={_bool_cn(issue_observed)}")
|
||||
return issue_observed
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="复现")
|
||||
parser.add_argument("--messages", type=int, default=6000)
|
||||
parser.add_argument("--batch-size", type=int, default=1000)
|
||||
parser.add_argument("--payload-size", type=int, default=16 * 1024)
|
||||
parser.add_argument("--max-cache-size", type=int, default=200)
|
||||
parser.add_argument("--sessions", type=int, default=3000)
|
||||
parser.add_argument("--session-batch-size", type=int, default=500)
|
||||
parser.add_argument("--session-payload-size", type=int, default=32 * 1024)
|
||||
parser.add_argument(
|
||||
"--call-prune",
|
||||
action="store_true",
|
||||
help="每个消息批次结束后,如运行时提供裁剪hook则主动调用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--call-cleanup",
|
||||
action="store_true",
|
||||
help="填充会话注册表后,如 HeartflowManager 提供清理方法则主动调用",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
args = parse_args()
|
||||
message_issue = await probe_maisaka_message_cache(args)
|
||||
session_issue = await probe_heartflow_session_registry(args)
|
||||
print("\n[汇总]")
|
||||
print(f"消息缓存问题是否复现={_bool_cn(message_issue)}")
|
||||
print(f"会话注册表问题是否复现={_bool_cn(session_issue)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from typing import Dict
|
||||
@@ -10,10 +9,6 @@ from src.maisaka.runtime import MaisakaHeartFlowChatting
|
||||
|
||||
logger = get_logger("heartflow")
|
||||
|
||||
HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS = 6 * 60 * 60
|
||||
HEARTFLOW_RUNTIME_MAX_SESSIONS = 512
|
||||
HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS = 5 * 60
|
||||
|
||||
|
||||
class HeartflowManager:
|
||||
"""管理 session 级别的 Maisaka 心流实例。"""
|
||||
@@ -21,91 +16,25 @@ class HeartflowManager:
|
||||
def __init__(self) -> None:
|
||||
self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {}
|
||||
self._chat_create_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._last_access_at: Dict[str, float] = {}
|
||||
self._last_cleanup_at = 0.0
|
||||
|
||||
async def _stop_runtime(self, session_id: str, chat: MaisakaHeartFlowChatting) -> None:
|
||||
try:
|
||||
await chat.stop()
|
||||
except Exception as exc:
|
||||
logger.warning(f"清理心流聊天 {session_id} 时停止 runtime 失败: {exc}", exc_info=True)
|
||||
|
||||
def _prune_runtime_caches(self, chat: MaisakaHeartFlowChatting) -> None:
|
||||
prune_runtime_caches = getattr(chat, "prune_runtime_caches", None)
|
||||
if callable(prune_runtime_caches):
|
||||
prune_runtime_caches()
|
||||
|
||||
async def cleanup_idle_chats(self, *, now: float | None = None, exclude_session_ids: set[str] | None = None) -> None:
|
||||
"""清理长期空闲或超过容量的 Maisaka runtime。"""
|
||||
current_time = time.time() if now is None else now
|
||||
excluded_session_ids = exclude_session_ids or set()
|
||||
expire_before = current_time - HEARTFLOW_RUNTIME_IDLE_TTL_SECONDS
|
||||
session_ids_to_remove = [
|
||||
session_id
|
||||
for session_id, accessed_at in self._last_access_at.items()
|
||||
if accessed_at < expire_before and session_id not in excluded_session_ids
|
||||
]
|
||||
|
||||
active_count_after_idle = len(self.heartflow_chat_list) - len(set(session_ids_to_remove))
|
||||
if active_count_after_idle > HEARTFLOW_RUNTIME_MAX_SESSIONS:
|
||||
overflow_count = active_count_after_idle - HEARTFLOW_RUNTIME_MAX_SESSIONS
|
||||
active_session_ids = [
|
||||
session_id
|
||||
for session_id in self.heartflow_chat_list
|
||||
if session_id not in session_ids_to_remove
|
||||
and session_id not in excluded_session_ids
|
||||
]
|
||||
active_session_ids.sort(key=lambda session_id: self._last_access_at.get(session_id, 0.0))
|
||||
session_ids_to_remove.extend(active_session_ids[:overflow_count])
|
||||
|
||||
self._last_cleanup_at = current_time
|
||||
removed_count = 0
|
||||
for session_id in dict.fromkeys(session_ids_to_remove):
|
||||
chat = self.heartflow_chat_list.pop(session_id, None)
|
||||
self._chat_create_locks.pop(session_id, None)
|
||||
self._last_access_at.pop(session_id, None)
|
||||
if chat is None:
|
||||
continue
|
||||
await self._stop_runtime(session_id, chat)
|
||||
removed_count += 1
|
||||
if removed_count > 0:
|
||||
logger.info(f"已清理空闲心流聊天: 数量={removed_count} 剩余={len(self.heartflow_chat_list)}")
|
||||
|
||||
async def _cleanup_idle_chats_if_due(self, *, now: float, exclude_session_id: str) -> None:
|
||||
cleanup_due = now - self._last_cleanup_at >= HEARTFLOW_RUNTIME_CLEANUP_INTERVAL_SECONDS
|
||||
capacity_exceeded = len(self.heartflow_chat_list) >= HEARTFLOW_RUNTIME_MAX_SESSIONS
|
||||
if not cleanup_due and not capacity_exceeded:
|
||||
return
|
||||
await self.cleanup_idle_chats(now=now, exclude_session_ids={exclude_session_id})
|
||||
|
||||
async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting:
|
||||
"""获取或创建指定会话对应的 Maisaka runtime。"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
self._last_access_at[session_id] = current_time
|
||||
self._prune_runtime_caches(chat)
|
||||
await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id)
|
||||
return chat
|
||||
|
||||
create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock())
|
||||
async with create_lock:
|
||||
current_time = time.time()
|
||||
if chat := self.heartflow_chat_list.get(session_id):
|
||||
self._last_access_at[session_id] = current_time
|
||||
self._prune_runtime_caches(chat)
|
||||
await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id)
|
||||
return chat
|
||||
|
||||
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||
if not chat_session:
|
||||
raise ValueError(f"未找到 session_id={session_id} 对应的聊天流")
|
||||
|
||||
await self._cleanup_idle_chats_if_due(now=current_time, exclude_session_id=session_id)
|
||||
new_chat = MaisakaHeartFlowChatting(session_id=session_id)
|
||||
await new_chat.start()
|
||||
self.heartflow_chat_list[session_id] = new_chat
|
||||
self._last_access_at[session_id] = current_time
|
||||
return new_chat
|
||||
except Exception as exc:
|
||||
logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True)
|
||||
|
||||
@@ -941,7 +941,6 @@ class MaisakaReasoningEngine:
|
||||
"""结束并记录一轮 Maisaka 思考循环。"""
|
||||
cycle_detail.end_time = time.time()
|
||||
self._runtime.history_loop.append(cycle_detail)
|
||||
self._runtime.prune_runtime_caches()
|
||||
self._post_process_chat_history_after_cycle()
|
||||
|
||||
timer_strings = [
|
||||
|
||||
@@ -56,9 +56,6 @@ from .tool_provider import MaisakaBuiltinToolProvider
|
||||
logger = get_logger("maisaka_runtime")
|
||||
|
||||
MAX_INTERNAL_ROUNDS = 10
|
||||
MESSAGE_CACHE_MIN_RETAINED = 200
|
||||
MESSAGE_CACHE_CONTEXT_MULTIPLIER = 4
|
||||
HISTORY_LOOP_MAX_RETAINED = 256
|
||||
|
||||
|
||||
class MaisakaHeartFlowChatting:
|
||||
@@ -85,7 +82,7 @@ class MaisakaHeartFlowChatting:
|
||||
self._chat_history: list[LLMContextMessage] = []
|
||||
self.history_loop: list[CycleDetail] = []
|
||||
|
||||
# Keep recent original messages for batching, tools, and later learning.
|
||||
# Keep all original messages for batching and later learning.
|
||||
self.message_cache: list[SessionMessage] = []
|
||||
self._last_processed_index = 0
|
||||
self._internal_turn_queue: asyncio.Queue[Literal["message", "timeout"]] = asyncio.Queue()
|
||||
@@ -114,10 +111,6 @@ class MaisakaHeartFlowChatting:
|
||||
else global_config.chat.max_private_context_size
|
||||
)
|
||||
self._max_context_size = max(1, int(configured_context_size))
|
||||
self._message_cache_max_size = max(
|
||||
MESSAGE_CACHE_MIN_RETAINED,
|
||||
self._max_context_size * MESSAGE_CACHE_CONTEXT_MULTIPLIER,
|
||||
)
|
||||
self._agent_state: Literal["running", "wait", "stop"] = self._STATE_STOP
|
||||
self._pending_wait_tool_call_id: Optional[str] = None
|
||||
self._force_next_timing_continue = False
|
||||
@@ -947,54 +940,6 @@ class MaisakaHeartFlowChatting:
|
||||
def _has_pending_messages(self) -> bool:
|
||||
return self._last_processed_index < len(self.message_cache)
|
||||
|
||||
def _get_expression_learner_processed_index(self) -> int:
|
||||
learner_index = getattr(self._expression_learner, "_last_processed_index", self._last_processed_index)
|
||||
try:
|
||||
return max(0, int(learner_index))
|
||||
except (TypeError, ValueError):
|
||||
return self._last_processed_index
|
||||
|
||||
def _adjust_expression_learner_processed_index(self, removed_count: int) -> None:
|
||||
if not hasattr(self._expression_learner, "_last_processed_index"):
|
||||
return
|
||||
learner_index = self._get_expression_learner_processed_index()
|
||||
setattr(self._expression_learner, "_last_processed_index", max(0, learner_index - removed_count))
|
||||
|
||||
def _prune_processed_message_cache(self) -> None:
|
||||
"""Trim old processed messages while preserving pending and learning windows."""
|
||||
max_size = max(1, int(getattr(self, "_message_cache_max_size", MESSAGE_CACHE_MIN_RETAINED)))
|
||||
overflow_count = len(self.message_cache) - max_size
|
||||
if overflow_count <= 0:
|
||||
return
|
||||
|
||||
processed_boundary = self._last_processed_index
|
||||
removable_count = min(overflow_count, processed_boundary)
|
||||
if removable_count <= 0:
|
||||
return
|
||||
|
||||
removed_messages = self.message_cache[:removable_count]
|
||||
removed_message_ids = {message.message_id for message in removed_messages}
|
||||
del self.message_cache[:removable_count]
|
||||
self._last_processed_index = max(0, self._last_processed_index - removable_count)
|
||||
self._adjust_expression_learner_processed_index(removable_count)
|
||||
retained_message_ids = {message.message_id for message in self.message_cache}
|
||||
|
||||
for message_id in removed_message_ids:
|
||||
if message_id not in retained_message_ids:
|
||||
self._message_received_at_by_id.pop(message_id, None)
|
||||
self._source_messages_by_id.pop(message_id, None)
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 已裁剪 Maisaka 消息缓存: "
|
||||
f"移除={removable_count} 剩余={len(self.message_cache)} 上限={max_size}"
|
||||
)
|
||||
|
||||
def prune_runtime_caches(self) -> None:
|
||||
"""Apply bounded retention to runtime-only in-memory histories."""
|
||||
self._prune_processed_message_cache()
|
||||
if len(self.history_loop) > HISTORY_LOOP_MAX_RETAINED:
|
||||
del self.history_loop[: len(self.history_loop) - HISTORY_LOOP_MAX_RETAINED]
|
||||
|
||||
def _schedule_message_turn(self) -> None:
|
||||
"""为当前待处理消息安排一次内部 turn。"""
|
||||
if self._agent_state == self._STATE_WAIT:
|
||||
|
||||
Reference in New Issue
Block a user