From 697147844a67fc9a4a32b62d0e6c57fa8fbd99fb Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 9 May 2026 17:26:17 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E8=A3=81=E5=88=87=E7=BC=93?= =?UTF-8?q?=E5=AD=98=EF=BC=8C=E5=88=A0=E9=99=A4=E8=B6=85=E6=97=B6=E8=81=8A?= =?UTF-8?q?=E5=A4=A9=E6=B5=81=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=86=85=E5=AD=98?= =?UTF-8?q?=E6=BA=A2=E5=87=BA=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/test_maisaka_memory_retention.py | 105 +++++++++++++++++++++++ src/chat/heart_flow/heartflow_manager.py | 59 ++++++++++++- src/learners/expression_learner.py | 29 +++++-- src/maisaka/runtime.py | 46 ++++++++-- 4 files changed, 223 insertions(+), 16 deletions(-) create mode 100644 pytests/test_maisaka_memory_retention.py diff --git a/pytests/test_maisaka_memory_retention.py b/pytests/test_maisaka_memory_retention.py new file mode 100644 index 00000000..921302a7 --- /dev/null +++ b/pytests/test_maisaka_memory_retention.py @@ -0,0 +1,105 @@ +from types import SimpleNamespace + +import pytest +import time + +from src.chat.heart_flow import heartflow_manager as heartflow_manager_module +from src.chat.heart_flow.heartflow_manager import HEARTFLOW_ACTIVE_RETENTION_SECONDS, HeartflowManager +from src.learners.expression_learner import ExpressionLearner +from src.maisaka.runtime import MAX_RETAINED_MESSAGE_CACHE_SIZE, MaisakaHeartFlowChatting + + +def _build_runtime_with_messages(message_count: int) -> MaisakaHeartFlowChatting: + runtime = object.__new__(MaisakaHeartFlowChatting) + runtime.log_prefix = "[test]" + runtime.message_cache = [SimpleNamespace(message_id=f"msg-{index}") for index in range(message_count)] + runtime._last_processed_index = message_count + runtime._expression_learner = ExpressionLearner("session-1") + runtime._expression_learner.mark_all_processed(runtime.message_cache) + return runtime + + +def test_prune_processed_message_cache_keeps_bounded_recent_window() -> None: + runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25) + + runtime._prune_processed_message_cache() + + assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + assert runtime.message_cache[0].message_id == "msg-25" + assert runtime._last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE + assert runtime._expression_learner.last_processed_index == MAX_RETAINED_MESSAGE_CACHE_SIZE + + +def test_prune_processed_message_cache_keeps_unlearned_messages() -> None: + runtime = _build_runtime_with_messages(MAX_RETAINED_MESSAGE_CACHE_SIZE + 25) + runtime._expression_learner.discard_processed_prefix(MAX_RETAINED_MESSAGE_CACHE_SIZE + 5) + + runtime._prune_processed_message_cache() + + assert len(runtime.message_cache) == MAX_RETAINED_MESSAGE_CACHE_SIZE + 5 + assert runtime.message_cache[0].message_id == "msg-20" + assert runtime._expression_learner.last_processed_index == 0 + + +def test_collect_pending_messages_uses_single_pending_received_time() -> None: + runtime = _build_runtime_with_messages(2) + runtime._last_processed_index = 0 + runtime._oldest_pending_message_received_at = 123.0 + runtime._last_message_received_at = 456.0 + runtime._reply_latency_measurement_started_at = None + + pending_messages = runtime._collect_pending_messages() + + assert [message.message_id for message in pending_messages] == ["msg-0", "msg-1"] + assert runtime._reply_latency_measurement_started_at == 123.0 + assert runtime._oldest_pending_message_received_at is None + + +@pytest.mark.asyncio +async def test_heartflow_manager_evicts_lru_chat_over_limit(monkeypatch: pytest.MonkeyPatch) -> None: + manager = HeartflowManager() + stopped_session_ids: list[str] = [] + old_active_at = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS - 1 + + class FakeChat: + def __init__(self, session_id: str) -> None: + self.session_id = session_id + + async def stop(self) -> None: + stopped_session_ids.append(self.session_id) + + monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2) + manager.heartflow_chat_list["session-1"] = FakeChat("session-1") + manager.heartflow_chat_list["session-2"] = FakeChat("session-2") + manager.heartflow_chat_list["session-3"] = FakeChat("session-3") + manager._chat_last_active_at["session-1"] = old_active_at + manager._chat_last_active_at["session-2"] = old_active_at + manager._chat_last_active_at["session-3"] = time.time() + + await manager._evict_over_limit_chats(protected_session_id="session-3") + + assert stopped_session_ids == ["session-1"] + assert list(manager.heartflow_chat_list) == ["session-2", "session-3"] + + +@pytest.mark.asyncio +async def test_heartflow_manager_keeps_recent_chats_even_over_limit(monkeypatch: pytest.MonkeyPatch) -> None: + manager = HeartflowManager() + stopped_session_ids: list[str] = [] + + class FakeChat: + def __init__(self, session_id: str) -> None: + self.session_id = session_id + + async def stop(self) -> None: + stopped_session_ids.append(self.session_id) + + monkeypatch.setattr(heartflow_manager_module, "HEARTFLOW_MAX_ACTIVE_CHATS", 2) + for session_id in ("session-1", "session-2", "session-3"): + manager.heartflow_chat_list[session_id] = FakeChat(session_id) + manager._chat_last_active_at[session_id] = time.time() + + await manager._evict_over_limit_chats(protected_session_id="session-3") + + assert stopped_session_ids == [] + assert list(manager.heartflow_chat_list) == ["session-1", "session-2", "session-3"] diff --git a/src/chat/heart_flow/heartflow_manager.py b/src/chat/heart_flow/heartflow_manager.py index 3bbc6ec3..a31a1215 100644 --- a/src/chat/heart_flow/heartflow_manager.py +++ b/src/chat/heart_flow/heartflow_manager.py @@ -1,31 +1,39 @@ -import asyncio -import traceback - +from collections import OrderedDict from typing import Dict +import asyncio +import time +import traceback + from src.chat.message_receive.chat_manager import chat_manager from src.common.logger import get_logger from src.maisaka.runtime import MaisakaHeartFlowChatting logger = get_logger("heartflow") +HEARTFLOW_ACTIVE_RETENTION_SECONDS = 24 * 60 * 60 +HEARTFLOW_MAX_ACTIVE_CHATS = 100 + class HeartflowManager: """管理 session 级别的 Maisaka 心流实例。""" def __init__(self) -> None: - self.heartflow_chat_list: Dict[str, MaisakaHeartFlowChatting] = {} + self.heartflow_chat_list: OrderedDict[str, MaisakaHeartFlowChatting] = OrderedDict() self._chat_create_locks: Dict[str, asyncio.Lock] = {} + self._chat_last_active_at: Dict[str, float] = {} async def get_or_create_heartflow_chat(self, session_id: str) -> MaisakaHeartFlowChatting: """获取或创建指定会话对应的 Maisaka runtime。""" try: if chat := self.heartflow_chat_list.get(session_id): + self._touch_chat(session_id) return chat create_lock = self._chat_create_locks.setdefault(session_id, asyncio.Lock()) async with create_lock: if chat := self.heartflow_chat_list.get(session_id): + self._touch_chat(session_id) return chat chat_session = chat_manager.get_session_by_session_id(session_id) @@ -35,16 +43,59 @@ class HeartflowManager: new_chat = MaisakaHeartFlowChatting(session_id=session_id) await new_chat.start() self.heartflow_chat_list[session_id] = new_chat + self._touch_chat(session_id) + await self._evict_over_limit_chats(protected_session_id=session_id) return new_chat except Exception as exc: logger.error(f"创建心流聊天 {session_id} 失败: {exc}", exc_info=True) traceback.print_exc() raise + def _touch_chat(self, session_id: str) -> None: + """记录会话最近活跃时间,并维护心流实例的 LRU 顺序。""" + self._chat_last_active_at[session_id] = time.time() + self.heartflow_chat_list.move_to_end(session_id) + + async def _evict_over_limit_chats(self, *, protected_session_id: str) -> None: + """当实例数量超过上限时,仅淘汰 24 小时内无消息的旧会话。""" + while len(self.heartflow_chat_list) > HEARTFLOW_MAX_ACTIVE_CHATS: + session_id = self._find_evictable_session_id(protected_session_id=protected_session_id) + if session_id is None: + return + await self._evict_chat(session_id, reason="cache_limit") + + def _find_evictable_session_id(self, *, protected_session_id: str) -> str | None: + """按 LRU 查找超过活跃保护窗口的可淘汰会话。""" + expire_before = time.time() - HEARTFLOW_ACTIVE_RETENTION_SECONDS + for session_id in self.heartflow_chat_list: + if session_id == protected_session_id: + continue + last_active_at = self._chat_last_active_at.get(session_id, 0.0) + if last_active_at <= expire_before: + return session_id + return None + + async def _evict_chat(self, session_id: str, *, reason: str) -> None: + """停止并移除指定会话的心流实例。""" + chat = self.heartflow_chat_list.pop(session_id, None) + self._chat_last_active_at.pop(session_id, None) + lock = self._chat_create_locks.get(session_id) + if lock is not None and not lock.locked(): + self._chat_create_locks.pop(session_id, None) + if chat is None: + return + + try: + await chat.stop() + logger.info(f"已淘汰心流聊天 {session_id}: reason={reason}") + except Exception as exc: + logger.warning(f"淘汰心流聊天 {session_id} 失败: {exc}", exc_info=True) + def adjust_talk_frequency(self, session_id: str, frequency: float) -> None: """调整指定聊天流的说话频率。""" chat = self.heartflow_chat_list.get(session_id) if chat: + self._touch_chat(session_id) chat.adjust_talk_frequency(frequency) logger.info(f"已调整聊天 {session_id} 的说话频率为 {frequency}") else: diff --git a/src/learners/expression_learner.py b/src/learners/expression_learner.py index d96c451c..87aca700 100644 --- a/src/learners/expression_learner.py +++ b/src/learners/expression_learner.py @@ -160,6 +160,23 @@ class ExpressionLearner: self._last_processed_index = 0 self.min_messages_for_extraction = 10 + @property + def last_processed_index(self) -> int: + """返回表达学习已经消费到的消息缓存下标。""" + return self._last_processed_index + + def mark_all_processed(self, message_cache: List["SessionMessage"]) -> None: + """在跳过表达学习时,将现有消息标记为已处理,避免阻塞缓存裁剪。""" + self._last_processed_index = len(message_cache) + + def mark_processed_until(self, processed_end_index: int) -> None: + """将指定缓存下标之前的消息标记为已处理。""" + self._last_processed_index = max(self._last_processed_index, processed_end_index) + + def discard_processed_prefix(self, removed_count: int) -> None: + """同步 runtime 对消息缓存前缀的裁剪。""" + self._last_processed_index = max(0, self._last_processed_index - removed_count) + @staticmethod def _get_runtime_manager() -> Any: """获取插件运行时管理器。 @@ -274,7 +291,8 @@ class ExpressionLearner: jargon_miner: Optional["JargonMiner"] = None, ) -> bool: """学习表达方式""" - pending_messages = message_cache[self._last_processed_index :] + processed_end_index = len(message_cache) + pending_messages = message_cache[self._last_processed_index : processed_end_index] if not pending_messages: logger.debug("没有待处理消息") return False @@ -303,6 +321,7 @@ class ExpressionLearner: response = generation_result.response except Exception as e: logger.error(f"学习表达方式失败: {e}") + self._last_processed_index = processed_end_index return False expressions: List[Tuple[str, str, str]] @@ -336,7 +355,7 @@ class ExpressionLearner: ) if after_extract_result.aborted: logger.info(f"{self.session_id} 表达方式选择 Hook 中止") - self._last_processed_index = len(message_cache) + self._last_processed_index = processed_end_index return False after_extract_kwargs = after_extract_result.kwargs @@ -352,7 +371,7 @@ class ExpressionLearner: if not expressions: logger.info("没有可学习的表达方式") - self._last_processed_index = len(message_cache) + self._last_processed_index = processed_end_index return False logger.info(f"可学习的表达方式: {expressions}") @@ -361,7 +380,7 @@ class ExpressionLearner: learnt_expressions = self._filter_expressions(expressions, pending_messages) if not learnt_expressions: logger.info("没有可学习的表达方式通过过滤") - self._last_processed_index = len(message_cache) + self._last_processed_index = processed_end_index return False learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions) @@ -386,7 +405,7 @@ class ExpressionLearner: continue await self._upsert_expression_to_db(situation, style) - self._last_processed_index = len(message_cache) + self._last_processed_index = processed_end_index return True def _check_cached_jargons_in_messages( diff --git a/src/maisaka/runtime.py b/src/maisaka/runtime.py index eba82a58..55061917 100644 --- a/src/maisaka/runtime.py +++ b/src/maisaka/runtime.py @@ -56,6 +56,7 @@ from .tool_provider import MaisakaBuiltinToolProvider logger = get_logger("maisaka_runtime") MAX_INTERNAL_ROUNDS = 10 +MAX_RETAINED_MESSAGE_CACHE_SIZE = 200 class MaisakaHeartFlowChatting: @@ -97,7 +98,7 @@ class MaisakaHeartFlowChatting: self._deferred_message_turn_task: Optional[asyncio.Task[None]] = None self._message_debounce_seconds = 1.0 self._message_debounce_required = False - self._message_received_at_by_id: dict[str, float] = {} + self._oldest_pending_message_received_at: Optional[float] = None self._last_message_received_at = 0.0 self._talk_frequency_adjust = 1.0 self._reply_latency_measurement_started_at: Optional[float] = None @@ -311,9 +312,11 @@ class MaisakaHeartFlowChatting: self._ensure_background_tasks_running() received_at = time.time() self._last_message_received_at = received_at + if self._oldest_pending_message_received_at is None: + self._oldest_pending_message_received_at = received_at self._update_message_trigger_state(message) self.message_cache.append(message) - self._message_received_at_by_id[message.message_id] = received_at + self._prune_processed_message_cache() if self._is_reply_effect_tracking_enabled(): asyncio.create_task(self._reply_effect_tracker.observe_user_message(message)) if self._agent_state == self._STATE_RUNNING: @@ -502,6 +505,28 @@ class MaisakaHeartFlowChatting: return None + def _prune_processed_message_cache(self) -> None: + """裁剪 runtime 与表达学习器都已经消费过的旧消息。""" + excess_count = len(self.message_cache) - MAX_RETAINED_MESSAGE_CACHE_SIZE + if excess_count <= 0: + return + + removable_count = min( + excess_count, + self._last_processed_index, + self._expression_learner.last_processed_index, + ) + if removable_count <= 0: + return + + del self.message_cache[:removable_count] + self._last_processed_index = max(0, self._last_processed_index - removable_count) + self._expression_learner.discard_processed_prefix(removable_count) + logger.debug( + f"{self.log_prefix} 已清理 Maisaka 旧消息缓存: " + f"清理数量={removable_count} 保留数量={len(self.message_cache)}" + ) + def _should_trigger_message_turn_by_idle_compensation( self, *, @@ -1016,12 +1041,10 @@ class MaisakaHeartFlowChatting: # f"收集 {len(unique_messages)} 条新消息" # ) if unique_messages and self._reply_latency_measurement_started_at is None: - self._reply_latency_measurement_started_at = min( - self._message_received_at_by_id.get(message.message_id, self._last_message_received_at) - for message in unique_messages + self._reply_latency_measurement_started_at = ( + self._oldest_pending_message_received_at or self._last_message_received_at ) - for message in unique_messages: - self._message_received_at_by_id.pop(message.message_id, None) + self._oldest_pending_message_received_at = None return unique_messages async def _wait_for_message_quiet_period(self) -> None: @@ -1090,10 +1113,19 @@ class MaisakaHeartFlowChatting: async def _trigger_batch_learning(self, messages: list[SessionMessage]) -> None: """按同一批消息触发表达方式和黑话学习。""" + processed_end_index = len(self.message_cache) + if not self._enable_expression_learning: + self._expression_learner.mark_all_processed(self.message_cache) + self._prune_processed_message_cache() + return + try: await self._trigger_expression_learning(messages) except Exception as exc: logger.error(f"{self.log_prefix} 表达学习任务异常退出: {exc}") + self._expression_learner.mark_processed_until(processed_end_index) + finally: + self._prune_processed_message_cache() def _should_trigger_learning( self,