diff --git a/src/common/data_models/llm_service_data_models.py b/src/common/data_models/llm_service_data_models.py index 15b530ca..cacd3e10 100644 --- a/src/common/data_models/llm_service_data_models.py +++ b/src/common/data_models/llm_service_data_models.py @@ -66,6 +66,9 @@ class LLMResponseResult(BaseDataModel): reasoning: str = field(default_factory=str) model_name: str = field(default_factory=str) tool_calls: List[ToolCall] | None = None + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 @dataclass(slots=True) @@ -120,6 +123,9 @@ class LLMServiceResult(BaseDataModel): "response": self.completion.response, "reasoning": self.completion.reasoning, "model_name": self.completion.model_name, + "prompt_tokens": self.completion.prompt_tokens, + "completion_tokens": self.completion.completion_tokens, + "total_tokens": self.completion.total_tokens, } if self.completion.tool_calls is not None: payload["tool_calls"] = [ diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 47a30ed6..accfe355 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -34,6 +34,7 @@ from src.llm_models.model_client.base_client import ( ClientRequest, EmbeddingRequest, ResponseRequest, + UsageRecord, client_registry, ) from src.llm_models.payload_content.message import Message, MessageBuilder @@ -137,6 +138,7 @@ class LLMOrchestrator: reasoning_content: str, model_name: str, tool_calls: List[ToolCall] | None, + usage: UsageRecord | None = None, ) -> LLMResponseResult: """构建统一的文本响应结果。 @@ -154,6 +156,9 @@ class LLMOrchestrator: reasoning=reasoning_content, model_name=model_name, tool_calls=tool_calls, + prompt_tokens=usage.prompt_tokens if usage is not None else 0, + completion_tokens=usage.completion_tokens if usage is not None else 0, + total_tokens=usage.total_tokens if usage is not None else 0, ) async def generate_response_for_image( @@ -215,7 +220,13 @@ class LLMOrchestrator: endpoint="/chat/completions", time_cost=time_cost, ) - return self._build_generation_result(content, reasoning_content, model_info.name, tool_calls) + return self._build_generation_result( + content, + reasoning_content, + model_info.name, + tool_calls, + response.usage, + ) async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult: """为语音生成转录响应。 @@ -298,7 +309,13 @@ class LLMOrchestrator: endpoint="/chat/completions", time_cost=time.time() - start_time, ) - return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) + return self._build_generation_result( + content or "", + reasoning_content, + model_info.name, + tool_calls, + response.usage, + ) async def generate_response_with_message_async( self, @@ -364,7 +381,13 @@ class LLMOrchestrator: endpoint="/chat/completions", time_cost=time_cost, ) - return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) + return self._build_generation_result( + content or "", + reasoning_content, + model_info.name, + tool_calls, + response.usage, + ) async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult: """获取嵌入向量。 diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py index 839ab4f4..58724a4f 100644 --- a/src/maisaka/chat_loop_service.py +++ b/src/maisaka/chat_loop_service.py @@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Sequence import asyncio import json import random -import re from PIL import Image as PILImage from pydantic import BaseModel, Field as PydanticField @@ -28,7 +27,7 @@ from src.config.config import global_config from src.core.tooling import ToolRegistry, ToolSpec from src.know_u.knowledge import extract_category_ids_from_result from src.llm_models.model_client.base_client import BaseClient -from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart +from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options from src.services.llm_service import LLMServiceClient @@ -697,58 +696,29 @@ class MaisakaChatLoopService: ) @staticmethod - def _estimate_text_tokens(text: str) -> int: - """估算单段文本的输入 token 数。""" - normalized_text = text.strip() - if not normalized_text: - return 0 - - cjk_char_count = sum(1 for char in normalized_text if "\u4e00" <= char <= "\u9fff") - latin_chunks = re.findall(r"[A-Za-z0-9_]+", normalized_text) - latin_token_count = sum(max(1, (len(chunk) + 3) // 4) for chunk in latin_chunks) - punctuation_count = len(re.findall(r"[^\w\s]", normalized_text)) - whitespace_bonus = max(1, normalized_text.count("\n")) - return cjk_char_count + latin_token_count + punctuation_count + whitespace_bonus + def _format_token_count(token_count: int) -> str: + """格式化 token 数量展示文本。""" + if token_count >= 10_000: + return f"{token_count / 1000:.1f}k" + return str(token_count) @classmethod - def _estimate_request_tokens(cls, messages: Sequence[Message]) -> int: - """估算本轮请求消息的总输入 token 数。""" - total_tokens = 0 - for message in messages: - total_tokens += 4 - total_tokens += cls._estimate_text_tokens(str(message.role.value)) - if message.tool_call_id: - total_tokens += cls._estimate_text_tokens(message.tool_call_id) - if message.tool_calls: - for tool_call in message.tool_calls: - total_tokens += cls._estimate_text_tokens(getattr(tool_call, "func_name", "") or "") - total_tokens += cls._estimate_text_tokens( - json.dumps(getattr(tool_call, "args", {}) or {}, ensure_ascii=False) - ) - for part in message.parts: - if isinstance(part, TextMessagePart): - total_tokens += cls._estimate_text_tokens(part.text) - continue - if isinstance(part, ImageMessagePart): - total_tokens += max(256, len(part.image_base64) // 12) - return total_tokens - - @staticmethod def _build_prompt_stats_text( + cls, *, selected_history_count: int, built_message_count: int, - input_token_count: int, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, ) -> str: """构造本轮 prompt 的统计信息文本。""" - if input_token_count >= 10_000: - input_token_text = f"{input_token_count / 1000:.1f}k" - else: - input_token_text = str(input_token_count) return ( f"已选上下文消息数={selected_history_count} " f"大模型消息数={built_message_count} " - f"估算输入Token={input_token_text}" + f"实际输入Token={cls._format_token_count(prompt_tokens)} " + f"输出Token={cls._format_token_count(completion_tokens)} " + f"总Token={cls._format_token_count(total_tokens)}" ) async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse: @@ -764,13 +734,6 @@ class MaisakaChatLoopService: await self.ensure_chat_prompt_loaded() selected_history, selection_reason = self._select_llm_context_messages(chat_history) built_messages = self._build_request_messages(selected_history) - input_token_count = self._estimate_request_tokens(built_messages) - prompt_stats_text = self._build_prompt_stats_text( - selected_history_count=len(selected_history), - built_message_count=len(built_messages), - input_token_count=input_token_count, - ) - display_subtitle = f"{selection_reason} | {prompt_stats_text}" def message_factory(_client: BaseClient) -> List[Message]: """返回当前轮次已经构建好的请求消息。 @@ -806,7 +769,7 @@ class MaisakaChatLoopService: Panel( Group(*ordered_panels), title="MaiSaka 大模型请求 - 对话单步", - subtitle=display_subtitle, + subtitle=selection_reason, border_style="cyan", padding=(0, 1), ) @@ -820,7 +783,6 @@ class MaisakaChatLoopService: f"工具数={len(all_tools)} " f"启用打断={self._interrupt_flag is not None}" ) - logger.info(f"??Prompt??: {prompt_stats_text}") generation_result = await self._llm_chat.generate_response_with_messages( message_factory=message_factory, options=LLMGenerationOptions( @@ -833,6 +795,15 @@ class MaisakaChatLoopService: request_elapsed = perf_counter() - request_started_at logger.info(f"规划器请求完成,耗时={request_elapsed:.3f} 秒") + prompt_stats_text = self._build_prompt_stats_text( + selected_history_count=len(selected_history), + built_message_count=len(built_messages), + prompt_tokens=generation_result.prompt_tokens, + completion_tokens=generation_result.completion_tokens, + total_tokens=generation_result.total_tokens, + ) + logger.info(f"本轮Prompt统计: {prompt_stats_text}") + tool_call_summaries = [ { "调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)), diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py deleted file mode 100644 index 3d18187d..00000000 --- a/src/memory_system/chat_history_summarizer.py +++ /dev/null @@ -1,1066 +0,0 @@ -""" -聊天内容概括器 -用于累积、打包和压缩聊天记录 -""" - -import asyncio -import json -import time -import re -import difflib -import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional, Set -from dataclasses import dataclass, field -from json_repair import repair_json - -from src.chat.message_receive.message import SessionMessage -from src.common.logger import get_logger -from src.config.config import global_config -from src.common.data_models.llm_service_data_models import LLMGenerationOptions -from src.services.llm_service import LLMServiceClient -from src.services import message_service as message_api -from src.chat.utils.utils import is_bot_self -from src.person_info.person_info import Person -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.prompt.prompt_manager import prompt_manager - -logger = get_logger("chat_history_summarizer") - -HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorizer" - - -@dataclass -class MessageBatch: - """消息批次(用于触发话题检查的原始消息累积)""" - - messages: List[SessionMessage] - start_time: float - end_time: float - - -@dataclass -class TopicCacheItem: - """ - 话题缓存项 - - Attributes: - topic: 话题标题(一句话描述时间、人物、事件和主题) - messages: 与该话题相关的消息字符串列表(已经通过 build 函数转成可读文本) - participants: 涉及到的发言人昵称集合 - no_update_checks: 连续多少次“检查”没有新增内容 - """ - - topic: str - messages: List[str] = field(default_factory=list) - participants: Set[str] = field(default_factory=set) - no_update_checks: int = 0 - - -class ChatHistorySummarizer: - """聊天内容概括器""" - - def __init__(self, session_id: str, check_interval: int = 60): - """ - 初始化聊天内容概括器 - - Args: - session_id: 会话ID - check_interval: 定期检查间隔(秒),默认60秒 - """ - self.session_id = session_id - self._chat_display_name = self._get_chat_display_name() - self.log_prefix = f"[{self._chat_display_name}]" - - # 记录时间点,用于计算新消息 - self.last_check_time = time.time() - - # 记录上一次话题检查的时间,用于判断是否需要触发检查 - self.last_topic_check_time = time.time() - - # 当前累积的消息批次 - self.current_batch: Optional[MessageBatch] = None - - # 话题缓存:topic_str -> TopicCacheItem - # 在内存中维护,并通过本地文件实时持久化 - self.topic_cache: Dict[str, TopicCacheItem] = {} - self._safe_chat_id = self._sanitize_chat_id(self.session_id) - self._topic_cache_file = HIPPO_CACHE_DIR / f"{self._safe_chat_id}.json" - # 注意:批次加载需要异步查询消息,所以在 start() 中调用 - - # LLM请求器,用于压缩聊天内容 - self.summarizer_llm = LLMServiceClient( - task_name="utils", request_type="chat_history_summarizer" - ) - - # 后台循环相关 - self.check_interval = check_interval # 检查间隔(秒) - self._periodic_task: Optional[asyncio.Task] = None - self._running = False - - def _get_chat_display_name(self) -> str: - """获取聊天显示名称""" - try: - chat_name = _chat_manager.get_session_name(self.session_id) - if chat_name: - return chat_name - # 如果获取失败,使用简化的chat_id显示 - if len(self.session_id) > 20: - return f"{self.session_id[:8]}..." - return self.session_id - except Exception: - # 如果获取失败,使用简化的chat_id显示 - if len(self.session_id) > 20: - return f"{self.session_id[:8]}..." - return self.session_id - - def _sanitize_chat_id(self, chat_id: str) -> str: - """用于生成可作为文件名的 chat_id""" - return re.sub(r"[^a-zA-Z0-9_.-]", "_", chat_id) - - def _load_topic_cache_from_disk(self): - """在启动时加载本地话题缓存(同步部分),支持重启后继续""" - try: - if not self._topic_cache_file.exists(): - return - - with self._topic_cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - self.last_topic_check_time = data.get("last_topic_check_time", self.last_topic_check_time) - topics_data = data.get("topics", {}) - loaded_count = 0 - for topic, payload in topics_data.items(): - self.topic_cache[topic] = TopicCacheItem( - topic=topic, - messages=payload.get("messages", []), - participants=set(payload.get("participants", [])), - no_update_checks=payload.get("no_update_checks", 0), - ) - loaded_count += 1 - - if loaded_count: - logger.info(f"{self.log_prefix} 已加载 {loaded_count} 个话题缓存,继续追踪") - except Exception as e: - logger.error(f"{self.log_prefix} 加载话题缓存失败: {e}") - - async def _load_batch_from_disk(self): - """在启动时加载聊天批次,支持重启后继续""" - try: - if not self._topic_cache_file.exists(): - return - - with self._topic_cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - batch_data = data.get("current_batch") - if not batch_data: - return - - start_time = batch_data.get("start_time") - end_time = batch_data.get("end_time") - if not start_time or not end_time: - return - - # 根据时间范围重新查询消息 - messages = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=start_time, - end_time=end_time, - limit=0, - limit_mode="latest", - filter_mai=False, - filter_command=False, - ) - - if messages: - self.current_batch = MessageBatch( - messages=messages, - start_time=start_time, - end_time=end_time, - ) - logger.info(f"{self.log_prefix} 已恢复聊天批次,包含 {len(messages)} 条消息") - except Exception as e: - logger.error(f"{self.log_prefix} 加载聊天批次失败: {e}") - - def _persist_topic_cache(self): - """实时持久化话题缓存和聊天批次,避免重启后丢失""" - try: - # 如果既没有话题缓存也没有批次,删除缓存文件 - if not self.topic_cache and not self.current_batch: - if self._topic_cache_file.exists(): - self._topic_cache_file.unlink() - return - - HIPPO_CACHE_DIR.mkdir(parents=True, exist_ok=True) - data = { - "chat_id": self.session_id, - "last_topic_check_time": self.last_topic_check_time, - "topics": { - topic: { - "messages": item.messages, - "participants": list(item.participants), - "no_update_checks": item.no_update_checks, - } - for topic, item in self.topic_cache.items() - }, - } - - # 保存当前批次的时间范围(如果有) - if self.current_batch: - data["current_batch"] = { - "start_time": self.current_batch.start_time, - "end_time": self.current_batch.end_time, - } - - with self._topic_cache_file.open("w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) - except Exception as e: - logger.error(f"{self.log_prefix} 持久化话题缓存失败: {e}") - - async def process(self, current_time: Optional[float] = None): - """ - 处理聊天内容概括 - - Args: - current_time: 当前时间戳,如果为None则使用time.time() - """ - if current_time is None: - current_time = time.time() - - try: - # 获取从上次检查时间到当前时间的新消息 - new_messages = message_api.get_messages_by_time_in_chat( - chat_id=self.session_id, - start_time=self.last_check_time, - end_time=current_time, - limit=0, - limit_mode="latest", - filter_mai=False, # 不过滤bot消息,因为需要检查bot是否发言 - filter_command=False, - ) - - if not new_messages: - # 没有新消息,检查是否需要进行“话题检查” - if self.current_batch and self.current_batch.messages: - await self._check_and_run_topic_check(current_time) - self.last_check_time = current_time - return - - logger.debug( - f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}" - ) - - # 有新消息,更新最后检查时间 - self.last_check_time = current_time - - # 如果有当前批次,添加新消息 - if self.current_batch: - before_count = len(self.current_batch.messages) - self.current_batch.messages.extend(new_messages) - self.current_batch.end_time = current_time - logger.info( - f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息" - ) - # 更新批次后持久化 - self._persist_topic_cache() - else: - # 创建新批次 - self.current_batch = MessageBatch( - messages=new_messages, - start_time=new_messages[0].timestamp.timestamp() if new_messages else current_time, - end_time=current_time, - ) - logger.debug(f"{self.log_prefix} 新建聊天检查批次: {len(new_messages)} 条消息") - # 创建批次后持久化 - self._persist_topic_cache() - - # 检查是否需要触发“话题检查” - await self._check_and_run_topic_check(current_time) - - except Exception as e: - logger.error(f"{self.log_prefix} 处理聊天内容概括时出错: {e}") - import traceback - - traceback.print_exc() - - async def _check_and_run_topic_check(self, current_time: float): - """ - 检查是否需要进行一次“话题检查” - - 触发条件: - - 当前批次消息数 >= 100,或者 - - 距离上一次检查的时间 > 3600 秒(1小时) - """ - if not self.current_batch or not self.current_batch.messages: - return - - messages = self.current_batch.messages - message_count = len(messages) - time_since_last_check = current_time - self.last_topic_check_time - - # 格式化时间差显示 - if time_since_last_check < 60: - time_str = f"{time_since_last_check:.1f}秒" - elif time_since_last_check < 3600: - time_str = f"{time_since_last_check / 60:.1f}分钟" - else: - time_str = f"{time_since_last_check / 3600:.1f}小时" - - logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}") - - # 检查"话题检查"触发条件 - should_check = False - - # 从配置中获取阈值 - message_threshold = global_config.memory.chat_history_topic_check_message_threshold - time_threshold_hours = global_config.memory.chat_history_topic_check_time_hours - min_messages = global_config.memory.chat_history_topic_check_min_messages - time_threshold_seconds = time_threshold_hours * 3600 - - # 条件1: 消息数量达到阈值,触发一次检查 - if message_count >= message_threshold: - should_check = True - logger.info( - f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: {message_threshold}条)" - ) - - # 条件2: 距离上一次检查超过时间阈值且消息数量达到最小阈值,触发一次检查 - elif time_since_last_check > time_threshold_seconds and message_count >= min_messages: - should_check = True - logger.info( - f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: {time_threshold_hours}小时)且消息数量达到 {message_count} 条(阈值: {min_messages}条)" - ) - - if should_check: - await self._run_topic_check_and_update_cache(messages) - # 本批次已经被处理为话题信息,可以清空 - self.current_batch = None - # 更新上一次检查时间,并持久化 - self.last_topic_check_time = current_time - self._persist_topic_cache() - - async def _run_topic_check_and_update_cache(self, messages: List[SessionMessage]): - """ - 执行一次“话题检查”: - 1. 首先确认这段消息里是否有 Bot 发言,没有则直接丢弃本次批次; - 2. 将消息编号并转成字符串,构造 LLM Prompt; - 3. 把历史话题标题列表放入 Prompt,要求 LLM: - - 识别当前聊天中的话题(1 个或多个); - - 为每个话题选出相关消息编号; - - 若话题属于历史话题,则沿用原话题标题; - 4. LLM 返回 JSON:多个 {topic, message_indices}; - 5. 更新本地话题缓存,并根据规则触发“话题打包存储”。 - """ - if not messages: - return - - start_time = messages[0].timestamp.timestamp() - end_time = messages[-1].timestamp.timestamp() - - logger.info( - f"{self.log_prefix} 开始话题检查 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - - # 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推) - # 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言, - # 说明 bot 没有参与这段对话,不应该记录 - has_bot_message = any( - is_bot_self(msg.platform, msg.message_info.user_info.user_id) for msg in messages - ) - - if not has_bot_message: - logger.info( - f"{self.log_prefix} 当前批次内无 Bot 发言,丢弃本次检查 | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - return - - # 2. 构造编号后的消息字符串和参与者信息 - numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = ( - self._build_numbered_messages_for_llm(messages) - ) - - # 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次) - existing_topics = list(self.topic_cache.keys()) - max_retries = 3 - attempt = 0 - success = False - topic_to_indices: Dict[str, List[int]] = {} - - while attempt < max_retries: - attempt += 1 - success, topic_to_indices = await self._analyze_topics_with_llm( - numbered_lines=numbered_lines, - existing_topics=existing_topics, - ) - - if success and topic_to_indices: - if attempt > 1: - logger.info( - f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}" - ) - break - - logger.warning( - f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败" - + ("" if attempt >= max_retries else ",准备重试") - ) - - if not success or not topic_to_indices: - logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃") - # 即使识别失败,也认为是一次"检查",但不更新 no_update_checks(保持原状) - return - - # 3.5. 检查新话题是否与历史话题相似(相似度>=90%则使用历史标题) - topic_mapping = self._build_topic_mapping(topic_to_indices, similarity_threshold=0.9) - - # 应用话题映射:将相似的新话题标题替换为历史话题标题 - if topic_mapping: - new_topic_to_indices: Dict[str, List[int]] = {} - for new_topic, indices in topic_to_indices.items(): - # 如果这个新话题需要映射到历史话题 - if new_topic in topic_mapping: - historical_topic = topic_mapping[new_topic] - # 如果历史话题已经存在,合并消息索引 - if historical_topic in new_topic_to_indices: - # 合并索引并去重 - combined_indices = list(set(new_topic_to_indices[historical_topic] + indices)) - new_topic_to_indices[historical_topic] = combined_indices - else: - new_topic_to_indices[historical_topic] = indices - else: - # 不需要映射,保持原样 - new_topic_to_indices[new_topic] = indices - topic_to_indices = new_topic_to_indices - - # 4. 统计哪些话题在本次检查中有新增内容 - updated_topics: Set[str] = set() - - for topic, indices in topic_to_indices.items(): - if not indices: - continue - - item = self.topic_cache.get(topic) - if not item: - # 新话题 - item = TopicCacheItem(topic=topic) - self.topic_cache[topic] = item - - # 收集属于该话题的消息文本(不带编号) - topic_msg_texts: List[str] = [] - new_participants: Set[str] = set() - for idx in indices: - msg_text = index_to_msg_text.get(idx) - if not msg_text: - continue - topic_msg_texts.append(msg_text) - new_participants.update(index_to_participants.get(idx, set())) - - if not topic_msg_texts: - continue - - # 将本次检查中属于该话题的所有消息合并为一个字符串(不带编号) - merged_text = "\n".join(topic_msg_texts) - item.messages.append(merged_text) - item.participants.update(new_participants) - # 本次检查中该话题有更新,重置计数 - item.no_update_checks = 0 - updated_topics.add(topic) - - # 5. 对于本次没有更新的历史话题,no_update_checks + 1 - for topic, item in list(self.topic_cache.items()): - if topic not in updated_topics: - item.no_update_checks += 1 - - # 6. 检查是否有话题需要打包存储 - # 从配置中获取阈值 - no_update_checks_threshold = global_config.memory.chat_history_finalize_no_update_checks - message_count_threshold = global_config.memory.chat_history_finalize_message_count - - topics_to_finalize: List[str] = [] - for topic, item in self.topic_cache.items(): - if item.no_update_checks >= no_update_checks_threshold: - logger.info( - f"{self.log_prefix} 话题[{topic}] 连续 {no_update_checks_threshold} 次检查无新增内容,触发打包存储" - ) - topics_to_finalize.append(topic) - continue - if len(item.messages) > message_count_threshold: - logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 {message_count_threshold},触发打包存储") - topics_to_finalize.append(topic) - - for topic in topics_to_finalize: - item = self.topic_cache.get(topic) - if not item: - continue - try: - await self._finalize_and_store_topic( - topic=topic, - item=item, - # 这里的时间范围尽量覆盖最近一次检查的区间 - start_time=start_time, - end_time=end_time, - ) - finally: - # 无论成功与否,都从缓存中删除,避免重复 - self.topic_cache.pop(topic, None) - - def _find_most_similar_topic( - self, new_topic: str, existing_topics: List[str], similarity_threshold: float = 0.9 - ) -> Optional[tuple[str, float]]: - """ - 查找与给定新话题最相似的历史话题 - - Args: - new_topic: 新话题标题 - existing_topics: 历史话题标题列表 - similarity_threshold: 相似度阈值,默认0.9(90%) - - Returns: - Optional[tuple[str, float]]: 如果找到相似度>=阈值的历史话题,返回(历史话题标题, 相似度), - 否则返回None - """ - if not existing_topics: - return None - - best_match = None - best_similarity = 0.0 - - for existing_topic in existing_topics: - similarity = difflib.SequenceMatcher(None, new_topic, existing_topic).ratio() - if similarity > best_similarity: - best_similarity = similarity - best_match = existing_topic - - # 如果相似度达到阈值,返回匹配结果 - if best_match and best_similarity >= similarity_threshold: - return (best_match, best_similarity) - - return None - - def _build_topic_mapping( - self, topic_to_indices: Dict[str, List[int]], similarity_threshold: float = 0.9 - ) -> Dict[str, str]: - """ - 构建新话题到历史话题的映射(如果相似度>=阈值) - - Args: - topic_to_indices: 新话题到消息索引的映射 - similarity_threshold: 相似度阈值,默认0.9(90%) - - Returns: - Dict[str, str]: 新话题 -> 历史话题的映射字典 - """ - existing_topics_list = list(self.topic_cache.keys()) - topic_mapping: Dict[str, str] = {} - - for new_topic in topic_to_indices.keys(): - # 如果新话题已经在历史话题中,不需要检查 - if new_topic in existing_topics_list: - continue - - # 查找最相似的历史话题 - result = self._find_most_similar_topic(new_topic, existing_topics_list, similarity_threshold) - if result: - historical_topic, similarity = result - topic_mapping[new_topic] = historical_topic - logger.info( - f"{self.log_prefix} 话题相似度检查: '{new_topic}' 与历史话题 '{historical_topic}' 相似度 {similarity:.2%},使用历史标题" - ) - - return topic_mapping - - def _build_numbered_messages_for_llm( - self, messages: List[SessionMessage] - ) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]: - """ - 将消息转为带编号的字符串,供 LLM 选择使用。 - - 返回: - numbered_lines: ["1. xxx", "2. yyy", ...] # 带编号,用于 LLM 选择 - index_to_msg_str: idx -> "idx. xxx" # 带编号,用于 LLM 选择 - index_to_msg_text: idx -> "xxx" # 不带编号,用于最终存储 - index_to_participants: idx -> {nickname1, nickname2, ...} - """ - numbered_lines: List[str] = [] - index_to_msg_str: Dict[int, str] = {} - index_to_msg_text: Dict[int, str] = {} # 不带编号的消息文本 - index_to_participants: Dict[int, Set[str]] = {} - - for idx, msg in enumerate(messages, start=1): - # 使用 build_readable_messages 生成可读文本 - try: - text = message_api.build_readable_messages( - messages=[msg], - replace_bot_name=True, - timestamp_mode="normal_no_YMD", - read_mark=0.0, - truncate=False, - show_actions=False, - ).strip() - except Exception: - # 回退到简单文本 - text = getattr(msg, "processed_plain_text", "") or "" - - # 获取发言人昵称 - participants: Set[str] = set() - try: - platform = msg.platform - user_id = msg.message_info.user_info.user_id - if platform and user_id: - person = Person(platform=platform, user_id=user_id) - if person.person_name: - participants.add(person.person_name) - except Exception: - pass - - # 带编号的字符串(用于 LLM 选择) - line = f"{idx}. {text}" - numbered_lines.append(line) - index_to_msg_str[idx] = line - # 不带编号的文本(用于最终存储) - index_to_msg_text[idx] = text - index_to_participants[idx] = participants - - return numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants - - async def _analyze_topics_with_llm( - self, - numbered_lines: List[str], - existing_topics: List[str], - ) -> tuple[bool, Dict[str, List[int]]]: - """ - 使用 LLM 识别本次检查中的话题,并为每个话题选择相关消息编号。 - - 要求: - - 话题用一句话清晰描述正在发生的事件,包括时间、人物、主要事件和主题; - - 可以有 1 个或多个话题; - - 若某个话题与历史话题列表中的某个话题是同一件事,请直接使用历史话题的字符串; - - 输出 JSON,格式: - [ - { - "topic": "话题标题字符串", - "message_indices": [1, 2, 5] - }, - ... - ] - """ - if not numbered_lines: - return False, {} - - history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)" - messages_block = "\n".join(numbered_lines) - - prompt_template = prompt_manager.get_prompt("hippo_topic_analysis") - prompt_template.add_context("history_topics_block", history_topics_block) - prompt_template.add_context("messages_block", messages_block) - prompt = await prompt_manager.render_prompt(prompt_template) - - try: - generation_result = await self.summarizer_llm.generate_response( - prompt=prompt, - options=LLMGenerationOptions(temperature=0.3), - ) - response = generation_result.response - - logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") - logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") - - # 尝试从响应中提取JSON代码块 - json_str = None - json_pattern = r"```json\s*(.*?)\s*```" - matches = re.findall(json_pattern, response, re.DOTALL) - - if matches: - # 找到JSON代码块,使用第一个匹配 - json_str = matches[0].strip() - else: - # 如果没有找到代码块,尝试查找JSON数组的开始和结束位置 - # 查找第一个 [ 和最后一个 ] - start_idx = response.find("[") - end_idx = response.rfind("]") - if start_idx != -1 and end_idx != -1 and end_idx > start_idx: - json_str = response[start_idx : end_idx + 1].strip() - else: - # 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记) - json_str = response.strip() - json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) - json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) - json_str = json_str.strip() - - # 使用json_repair修复可能的JSON错误 - if json_str: - try: - repaired_json = repair_json(json_str) - result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json - except Exception as repair_error: - # 如果repair失败,尝试直接解析 - logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}") - result = json.loads(json_str) - else: - raise ValueError("无法从响应中提取JSON内容") - - if not isinstance(result, list): - logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}") - return False, {} - - topic_to_indices: Dict[str, List[int]] = {} - for item in result: - if not isinstance(item, dict): - continue - topic = item.get("topic") - indices = item.get("message_indices") or item.get("messages") or [] - if not topic or not isinstance(topic, str): - continue - if isinstance(indices, list): - valid_indices: List[int] = [] - for v in indices: - try: - iv = int(v) - if iv > 0: - valid_indices.append(iv) - except (TypeError, ValueError): - continue - if valid_indices: - topic_to_indices[topic] = valid_indices - - return True, topic_to_indices - - except Exception as e: - logger.error(f"{self.log_prefix} 话题识别 LLM 调用或解析失败: {e}") - logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - return False, {} - - async def _finalize_and_store_topic( - self, - topic: str, - item: TopicCacheItem, - start_time: float, - end_time: float, - ): - """ - 对某个话题进行最终打包存储: - 1. 将 messages(list[str]) 拼接为 original_text; - 2. 使用 LLM 对 original_text 进行总结,得到 summary 和 keywords,theme 直接使用话题字符串; - 3. 写入数据库 ChatHistory; - 4. 完成后,调用方会从缓存中删除该话题。 - """ - if not item.messages: - logger.info(f"{self.log_prefix} 话题[{topic}] 无消息内容,跳过打包") - return - - original_text = "\n".join(item.messages) - - logger.info( - f"{self.log_prefix} 开始将聊天记录构建成记忆:[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}" - ) - - # 使用 LLM 进行总结(基于话题名),带重试机制 - max_retries = 3 - attempt = 0 - success = False - keywords = [] - summary = "" - - while attempt < max_retries: - attempt += 1 - success, keywords, summary = await self._compress_with_llm(original_text, topic) - - if success and keywords and summary: - # 成功获取到有效的 keywords 和 summary - if attempt > 1: - logger.info(f"{self.log_prefix} 话题[{topic}] LLM 概括在第 {attempt} 次重试后成功") - break - - if attempt < max_retries: - logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败(第 {attempt} 次尝试),准备重试") - else: - logger.error(f"{self.log_prefix} 话题[{topic}] LLM 概括连续 {max_retries} 次失败,放弃存储") - - if not success or not keywords or not summary: - logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库") - return - - participants = list(item.participants) - - await self._store_to_database( - start_time=start_time, - end_time=end_time, - original_text=original_text, - participants=participants, - theme=topic, # 主题直接使用话题名 - keywords=keywords, - summary=summary, - ) - - logger.info( - f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}" - ) - - async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str]: - """ - 使用LLM压缩聊天内容(用于单个话题的最终总结) - - Args: - original_text: 聊天记录原文 - topic: 话题名称 - - Returns: - tuple[bool, List[str], str]: (是否成功, 关键词列表, 概括) - """ - prompt_template = prompt_manager.get_prompt("hippo_topic_summary") - prompt_template.add_context("topic", topic) - prompt_template.add_context("original_text", original_text) - prompt = await prompt_manager.render_prompt(prompt_template) - - try: - generation_result = await self.summarizer_llm.generate_response(prompt=prompt) - response = generation_result.response - - # 解析JSON响应 - json_str = response.strip() - json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE) - json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE) - json_str = json_str.strip() - - # 查找JSON对象的开始与结束 - start_idx = json_str.find("{") - if start_idx == -1: - raise ValueError("未找到JSON对象开始标记") - - end_idx = json_str.rfind("}") - if end_idx == -1 or end_idx <= start_idx: - logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复") - extracted_json = json_str[start_idx:] - else: - extracted_json = json_str[start_idx : end_idx + 1] - - def _parse_with_quote_fix(payload: str) -> Dict[str, Any]: - fixed_chars: List[str] = [] - in_string = False - escape_next = False - i = 0 - while i < len(payload): - char = payload[i] - if escape_next: - fixed_chars.append(char) - escape_next = False - elif char == "\\": - fixed_chars.append(char) - escape_next = True - elif char == '"' and not escape_next: - fixed_chars.append(char) - in_string = not in_string - elif in_string and char in {"“", "”"}: - # 在字符串值内部,将中文引号替换为转义的英文引号 - fixed_chars.append('\\"') - else: - fixed_chars.append(char) - i += 1 - - repaired = "".join(fixed_chars) - return json.loads(repaired) - - try: - result = json.loads(extracted_json) - except json.JSONDecodeError: - try: - repaired_json = repair_json(extracted_json) - if isinstance(repaired_json, str): - result = json.loads(repaired_json) - else: - result = repaired_json - except Exception as repair_error: - logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}") - result = _parse_with_quote_fix(extracted_json) - - keywords = result.get("keywords", []) - summary = result.get("summary", "") - - # 检查必需字段是否为空 - if not keywords or not summary: - logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少必需字段,原文\n{response}") - # 返回失败,和模型出错一样,让上层进行重试 - return False, [], "" - - # 确保keywords是列表 - if isinstance(keywords, str): - keywords = [keywords] - - return True, keywords, summary - - except Exception as e: - logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}") - logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}") - # 返回失败标志和默认值 - return False, [], "压缩失败,无法生成概括" - - async def _store_to_database( - self, - start_time: float, - end_time: float, - original_text: str, - participants: List[str], - theme: str, - keywords: List[str], - summary: str, - ): - """存储到数据库""" - try: - from src.common.database.database_model import ChatHistory - from src.services import database_service as database_api - - # 准备数据 - data = { - "session_id": self.session_id, - "start_timestamp": datetime.fromtimestamp(start_time), - "end_timestamp": datetime.fromtimestamp(end_time), - "original_messages": original_text, - "participants": json.dumps(participants, ensure_ascii=False), - "theme": theme, - "keywords": json.dumps(keywords, ensure_ascii=False), - "summary": summary, - "query_count": 0, - "query_forget_count": 0, - } - - saved_record = await database_api.db_save( - ChatHistory, - data=data, - ) - - if saved_record: - logger.debug(f"{self.log_prefix} 成功存储聊天历史记录到数据库") - else: - logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败") - - # 同时导入到LPMM知识库 - if global_config.lpmm_knowledge.enable: - await self._import_to_lpmm_knowledge( - theme=theme, - summary=summary, - participants=participants, - original_text=original_text, - ) - - except Exception as e: - logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}") - import traceback - - traceback.print_exc() - raise - - async def _import_to_lpmm_knowledge( - self, - theme: str, - summary: str, - participants: List[str], - original_text: str, - ): - """ - 将聊天历史总结导入到LPMM知识库 - - Args: - theme: 话题主题 - summary: 概括内容 - participants: 参与者列表 - original_text: 原始文本(可能很长,需要截断) - """ - try: - from src.chat.knowledge.lpmm_ops import lpmm_ops - - # 构造要导入的文本内容 - # 格式:主题 + 概括 + 参与者信息 + 原始内容摘要 - # 注意:使用单换行符连接,确保整个内容作为一段导入,不被LPMM分段 - content_parts = [] - - # 1. 话题主题 - # if theme: - # content_parts.append(f"话题:{theme}") - - # 2. 概括内容 - if summary: - content_parts.append(f"概括:{summary}") - - # 3. 参与者信息 - if participants: - participants_text = "、".join(participants) - content_parts.append(f"参与者:{participants_text}") - - # 4. 原始文本摘要(如果原始文本太长,只取前500字) - # if original_text: - # # 截断原始文本,避免过长 - # max_original_length = 500 - # if len(original_text) > max_original_length: - # truncated_text = original_text[:max_original_length] + "..." - # content_parts.append(f"原始内容摘要:{truncated_text}") - # else: - # content_parts.append(f"原始内容:{original_text}") - - # 将所有部分合并为一个完整段落(使用单换行符,避免被LPMM分段) - # LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段 - content_to_import = "\n".join(content_parts) - - if not content_to_import.strip(): - logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库") - return - - # 调用lpmm_ops导入 - result = await lpmm_ops.add_content(text=content_to_import, auto_split=False) - - if result["status"] == "success": - logger.info( - f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}" - ) - else: - logger.warning( - f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}" - ) - - except Exception as e: - # 导入失败不应该影响数据库存储,只记录错误 - logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True) - - async def start(self): - """启动后台定期检查循环""" - if self._running: - logger.warning(f"{self.log_prefix} 后台循环已在运行,无需重复启动") - return - - # 加载聊天批次(如果有) - await self._load_batch_from_disk() - - self._running = True - self._periodic_task = asyncio.create_task(self._periodic_check_loop()) - logger.info(f"{self.log_prefix} 已启动后台定期检查循环 | 检查间隔: {self.check_interval}秒") - - async def stop(self): - """停止后台定期检查循环""" - self._running = False - if self._periodic_task: - self._periodic_task.cancel() - try: - await self._periodic_task - except asyncio.CancelledError: - pass - self._periodic_task = None - logger.info(f"{self.log_prefix} 已停止后台定期检查循环") - - async def _periodic_check_loop(self): - """后台定期检查循环""" - try: - while self._running: - # 执行一次检查 - await self.process() - - # 等待指定间隔后再次检查 - await asyncio.sleep(self.check_interval) - except asyncio.CancelledError: - logger.info(f"{self.log_prefix} 后台检查循环被取消") - raise - except Exception as e: - logger.error(f"{self.log_prefix} 后台检查循环出错: {e}") - import traceback - - traceback.print_exc() - self._running = False diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py deleted file mode 100644 index 5bc6a3a1..00000000 --- a/src/memory_system/memory_retrieval.py +++ /dev/null @@ -1,1046 +0,0 @@ -import contextlib -import time -import json -import asyncio -from datetime import datetime -from typing import List, Dict, Any, Optional, Tuple, Callable -from src.common.logger import get_logger -from src.config.config import global_config -from src.prompt.prompt_manager import prompt_manager -from src.services import llm_service as llm_api -from sqlmodel import select, col -from src.common.database.database import get_db_session -from src.common.database.database_model import ThinkingQuestion -from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools -from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon - -logger = get_logger("memory_retrieval") - -THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 36000 # 未找到答案记录保留时长 -THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 3000 # 清理频率 -_last_not_found_cleanup_ts: float = 0.0 - - -def _cleanup_stale_not_found_thinking_back() -> None: - """定期清理过期的未找到答案记录""" - global _last_not_found_cleanup_ts - - now = time.time() - if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS: - return - - threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS - try: - with get_db_session() as session: - statement = select(ThinkingQuestion).where( - col(ThinkingQuestion.found_answer).is_(False) - & (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time)) - ) - records = session.exec(statement).all() - for record in records: - session.delete(record) - if records: - logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)} 条") - _last_not_found_cleanup_ts = now - except Exception as e: - logger.error(f"清理未找到答案的thinking_back记录失败: {e}") - - -def init_memory_retrieval_sys(): - """初始化记忆检索相关工具""" - # 注册所有工具 - init_all_tools() - - -def _log_conversation_messages( - conversation_messages: List[Message], - head_prompt: Optional[str] = None, - final_status: Optional[str] = None, -) -> None: - """输出对话消息列表的日志 - - Args: - conversation_messages: 对话消息列表 - head_prompt: 第一条系统消息(head_prompt)的内容,可选 - final_status: 最终结果状态描述(例如:找到答案/未找到答案),可选 - """ - if not global_config.debug.show_memory_prompt: - return - - log_lines: List[str] = [] - - # 如果有head_prompt,先添加为第一条消息 - if head_prompt: - msg_info = "========================================\n[消息 1] 角色: System\n-----------------------------" - msg_info += f"\n{head_prompt}" - log_lines.append(msg_info) - start_idx = 2 - else: - start_idx = 1 - - if not conversation_messages and not head_prompt: - return - - for idx, msg in enumerate(conversation_messages, start_idx): - role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - # 构建单条消息的日志信息 - # msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------" - msg_info = ( - f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------" - ) - - # if full_content: - # msg_info += f"\n{full_content}" - if msg.content: - msg_info += f"\n{msg.content}" - - if msg.tool_calls: - msg_info += f"\n 工具调用: {len(msg.tool_calls)}个" - for tool_call in msg.tool_calls: - msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}" - - # if msg.tool_call_id: - # msg_info += f"\n 工具调用ID: {msg.tool_call_id}" - - log_lines.append(msg_info) - - total_count = len(conversation_messages) + (1 if head_prompt else 0) - log_text = f"消息列表 (共{total_count}条):{''.join(log_lines)}" - if final_status: - log_text += f"\n\n[最终结果] {final_status}" - logger.info(log_text) - - -async def _react_agent_solve_question( - chat_id: str, - max_iterations: int = 5, - timeout: float = 30.0, - initial_info: str = "", - chat_history: str = "", -) -> Tuple[bool, str, List[Dict[str, Any]], bool]: - """使用ReAct架构的Agent来解决问题 - - Args: - chat_id: 聊天ID - max_iterations: 最大迭代次数 - timeout: 超时时间(秒) - initial_info: 初始信息,将作为collected_info的初始值 - chat_history: 聊天记录,将传递给 ReAct Agent prompt - - Returns: - Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时) - """ - start_time = time.time() - collected_info = initial_info or "" - # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流 - try: - chat_name = _chat_manager.get_session_name(chat_id) or chat_id - except Exception: - chat_name = chat_id - react_log_prefix = f"[{chat_name}] " - thinking_steps = [] - is_timeout = False - conversation_messages: List[Message] = [] - first_head_prompt: Optional[str] = None # 保存第一次使用的head_prompt(用于日志显示) - last_tool_name: Optional[str] = None # 记录最后一次使用的工具名称 - - # 使用 while 循环,支持额外迭代 - iteration = 0 - max_iterations_with_extra = max_iterations - while iteration < max_iterations_with_extra: - # 检查超时 - if time.time() - start_time > timeout: - logger.warning(f"ReAct Agent超时,已迭代{iteration}次") - is_timeout = True - break - - # 获取工具注册器 - tool_registry = get_tool_registry() - - # 获取bot_name - bot_name = global_config.bot.nickname - - # 获取当前时间 - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - - # 计算剩余迭代次数 - current_iteration = iteration + 1 - remaining_iterations = max_iterations - current_iteration - - # 提取函数调用中参数的值,支持单引号和双引号 - def extract_quoted_content(text, func_name, param_name): - """从文本中提取函数调用中参数的值,支持单引号和双引号 - - Args: - text: 要搜索的文本 - func_name: 函数名,如 'return_information' - param_name: 参数名,如 'information' - - Returns: - 提取的参数值,如果未找到则返回None - """ - if not text: - return None - - # 查找函数调用位置(不区分大小写) - func_pattern = func_name.lower() - text_lower = text.lower() - func_pos = text_lower.find(func_pattern) - if func_pos == -1: - return None - - # 查找参数名和等号 - param_pattern = f"{param_name}=" - param_pos = text_lower.find(param_pattern, func_pos) - if param_pos == -1: - return None - - # 跳过参数名、等号和空白 - start_pos = param_pos + len(param_pattern) - while start_pos < len(text) and text[start_pos] in " \t\n": - start_pos += 1 - - if start_pos >= len(text): - return None - - # 确定引号类型 - quote_char = text[start_pos] - if quote_char not in ['"', "'"]: - return None - - # 查找匹配的结束引号(考虑转义) - end_pos = start_pos + 1 - while end_pos < len(text): - if text[end_pos] == quote_char: - # 检查是否是转义的引号 - if end_pos > start_pos + 1 and text[end_pos - 1] == "\\": - end_pos += 1 - continue - # 找到匹配的引号 - content = text[start_pos + 1 : end_pos] - # 处理转义字符 - content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\") - return content - end_pos += 1 - - return None - - # 正常迭代:使用head_prompt决定调用哪些工具(包含return_information工具) - tool_definitions = tool_registry.get_tool_definitions() - # tool_names = [tool_def["name"] for tool_def in tool_definitions] - # logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)") - - # head_prompt应该只构建一次,使用初始的collected_info,后续迭代都复用同一个 - if first_head_prompt is None: - # 第一次构建,使用初始的collected_info(即initial_info) - initial_collected_info = initial_info or "" - # 使用 LPMM 知识库检索 prompt - first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_lpmm") - first_head_prompt_template.add_context("bot_name", bot_name) - first_head_prompt_template.add_context("time_now", time_now) - first_head_prompt_template.add_context("chat_history", chat_history) - first_head_prompt_template.add_context("collected_info", initial_collected_info) - first_head_prompt_template.add_context("current_iteration", str(current_iteration)) - first_head_prompt_template.add_context("remaining_iterations", str(remaining_iterations)) - first_head_prompt_template.add_context("max_iterations", str(max_iterations)) - first_head_prompt = await prompt_manager.render_prompt(first_head_prompt_template) - - # 后续迭代都复用第一次构建的head_prompt - head_prompt = first_head_prompt - - def _build_messages( - _client, - *, - _head_prompt: str = head_prompt, - _conversation_messages: List[Message] = conversation_messages, - ): - messages: List[Message] = [] - - system_builder = MessageBuilder() - system_builder.set_role(RoleType.System) - system_builder.add_text_content(_head_prompt) - messages.append(system_builder.build()) - - messages.extend(_conversation_messages) - - return messages - - message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues] - generation_result = await llm_api.generate( - llm_api.LLMServiceRequest( - task_name="utils", - request_type="memory.react", - message_factory=message_factory_fn, # type: ignore[arg-type] - tool_options=tool_definitions, - ) - ) - success = generation_result.success - response = generation_result.completion.response - reasoning_content = generation_result.completion.reasoning - tool_calls = generation_result.completion.tool_calls - - # logger.info( - # f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}" - # ) - - if not success: - logger.error(f"ReAct Agent LLM调用失败: {response}") - break - - # 注意:这里会检查return_information工具调用,如果检测到return_information工具,会根据information参数决定返回信息或退出查询 - - assistant_message: Optional[Message] = None - if tool_calls: - assistant_builder = MessageBuilder() - assistant_builder.set_role(RoleType.Assistant) - if response and response.strip(): - assistant_builder.add_text_content(response) - assistant_builder.set_tool_calls(tool_calls) - assistant_message = assistant_builder.build() - elif response and response.strip(): - assistant_builder = MessageBuilder() - assistant_builder.set_role(RoleType.Assistant) - assistant_builder.add_text_content(response) - assistant_message = assistant_builder.build() - - # 记录思考步骤 - step: Dict[str, Any] = { - "iteration": iteration + 1, - "thought": response, - "actions": [], - "observations": [], - } - - if assistant_message: - conversation_messages.append(assistant_message) - - # 记录思考过程到collected_info中 - if reasoning_content or response: - thought_summary = reasoning_content or (response[:200] if response else "") - if thought_summary: - collected_info += f"\n[思考] {thought_summary}\n" - - # 处理工具调用 - if not tool_calls: - # 如果没有工具调用,检查响应文本中是否包含return_information函数调用格式或JSON格式 - if response and response.strip(): - # 首先尝试解析JSON格式的return_information - def parse_json_return_information(text: str): - """从文本中解析JSON格式的return_information,返回information字符串,如果未找到则返回None""" - if not text: - return None, None - - try: - # 尝试提取JSON对象(可能包含在代码块中或直接是JSON) - json_text = text.strip() - - # 如果包含代码块标记,提取JSON部分 - if "```json" in json_text: - start = json_text.find("```json") + 7 - end = json_text.find("```", start) - if end != -1: - json_text = json_text[start:end].strip() - elif "```" in json_text: - start = json_text.find("```") + 3 - end = json_text.find("```", start) - if end != -1: - json_text = json_text[start:end].strip() - - # 尝试解析JSON - data = json.loads(json_text) - - # 检查是否包含return_information字段 - if isinstance(data, dict) and "return_information" in data: - information = data.get("information", "") - return information - except (json.JSONDecodeError, ValueError, TypeError): - # 如果JSON解析失败,尝试在文本中查找JSON对象 - with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError): - # 查找第一个 { 和最后一个 } 之间的内容(更健壮的JSON提取) - first_brace = text.find("{") - if first_brace != -1: - # 从第一个 { 开始,找到匹配的 } - brace_count = 0 - json_end = -1 - for i in range(first_brace, len(text)): - if text[i] == "{": - brace_count += 1 - elif text[i] == "}": - brace_count -= 1 - if brace_count == 0: - json_end = i + 1 - break - - if json_end != -1: - json_text = text[first_brace:json_end] - data = json.loads(json_text) - if isinstance(data, dict) and "return_information" in data: - information = data.get("information", "") - return information - - return None - - # 尝试从文本中解析return_information函数调用 - def parse_return_information_from_text(text: str): - """从文本中解析return_information函数调用,返回information字符串,如果未找到则返回None""" - if not text: - return None - - # 查找return_information函数调用位置(不区分大小写) - func_pattern = "return_information" - text_lower = text.lower() - func_pos = text_lower.find(func_pattern) - if func_pos == -1: - return None - - # 解析information参数(字符串,使用extract_quoted_content) - information = extract_quoted_content(text, "return_information", "information") - - # 如果information存在(即使是空字符串),也返回它 - return information - - # 首先尝试解析JSON格式 - parsed_information_json = parse_json_return_information(response) - is_json_format = parsed_information_json is not None - - # 如果JSON解析成功,使用JSON结果 - if is_json_format: - parsed_information = parsed_information_json - else: - # 如果JSON解析失败,尝试解析函数调用格式 - parsed_information = parse_return_information_from_text(response) - - if parsed_information is not None or is_json_format: - # 检测到return_information格式(可能是JSON格式或函数调用格式) - format_type = "JSON格式" if is_json_format else "函数调用格式" - # 返回信息(即使为空字符串也返回) - step["actions"].append( - { - "action_type": "return_information", - "action_params": {"information": parsed_information or ""}, - } - ) - parsed_info_text = parsed_information if isinstance(parsed_information, str) else "" - if parsed_info_text.strip(): - step["observations"] = [f"检测到return_information{format_type}调用,返回信息"] - thinking_steps.append(step) - logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_info_text[:100]}..." - ) - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status=f"返回信息:{parsed_info_text}", - ) - - return True, parsed_info_text, thinking_steps, False - else: - # 信息为空,直接退出查询 - step["observations"] = [f"检测到return_information{format_type}调用,信息为空"] - thinking_steps.append(step) - logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}判断信息为空" - ) - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="信息为空:通过return_information文本格式判断信息为空", - ) - - return False, "", thinking_steps, False - - # 如果没有检测到return_information格式,记录思考过程,继续下一轮迭代 - step["observations"] = [f"思考完成,但未调用工具。响应: {response}"] - logger.info(f"{react_log_prefix}第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}") - collected_info += f"思考: {response}" - else: - logger.warning(f"{react_log_prefix}第 {iteration + 1} 次迭代 无工具调用且无响应") - step["observations"] = ["无响应且无工具调用"] - thinking_steps.append(step) - iteration += 1 # 在continue之前增加迭代计数,避免跳过iteration += 1 - continue - - # 处理工具调用 - # 首先检查是否有return_information工具调用,如果有则立即返回,不再处理其他工具 - return_information_info = None - for tool_call in tool_calls: - tool_name = tool_call.func_name - tool_args = tool_call.args or {} - - if tool_name == "return_information": - return_information_info = tool_args.get("information", "") - - # 返回信息(即使为空也返回) - step["actions"].append( - { - "action_type": "return_information", - "action_params": {"information": return_information_info}, - } - ) - if return_information_info and return_information_info.strip(): - # 有信息,返回 - step["observations"] = ["检测到return_information工具调用,返回信息"] - thinking_steps.append(step) - logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information工具返回信息: {return_information_info}" - ) - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status=f"返回信息:{return_information_info}", - ) - - return True, return_information_info, thinking_steps, False - else: - # 信息为空,直接退出查询 - step["observations"] = ["检测到return_information工具调用,信息为空"] - thinking_steps.append(step) - logger.info(f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information工具判断信息为空") - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="信息为空:通过return_information工具判断信息为空", - ) - - return False, "", thinking_steps, False - - # 如果没有return_information工具调用,继续处理其他工具 - tool_tasks = [] - for i, tool_call in enumerate(tool_calls): - tool_name = tool_call.func_name - tool_args = tool_call.args or {} - - logger.debug( - f"{react_log_prefix}第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})" - ) - - # 跳过return_information工具调用(已经在上面处理过了) - if tool_name == "return_information": - continue - - # 记录最后一次使用的工具名称(用于判断是否需要额外迭代) - last_tool_name = tool_name - - # 普通工具调用 - tool = tool_registry.get_tool(tool_name) - if tool: - # 准备工具参数(需要添加chat_id如果工具需要) - import inspect - - sig = inspect.signature(tool.execute_func) - tool_params = tool_args.copy() - if "chat_id" in sig.parameters: - tool_params["chat_id"] = chat_id - - # 创建异步任务 - async def execute_single_tool(tool_instance, params, tool_name_str, iter_num): - try: - observation = await tool_instance.execute(**params) - param_str = ", ".join([f"{k}={v}" for k, v in params.items() if k != "chat_id"]) - return f"查询{tool_name_str}({param_str})的结果:{observation}" - except Exception as e: - error_msg = f"工具执行失败: {str(e)}" - logger.error(f"{react_log_prefix}第 {iter_num + 1} 次迭代 工具 {tool_name_str} {error_msg}") - return f"查询{tool_name_str}失败: {error_msg}" - - tool_tasks.append(execute_single_tool(tool, tool_params, tool_name, iteration)) - step["actions"].append({"action_type": tool_name, "action_params": tool_args}) - else: - error_msg = f"未知的工具类型: {tool_name}" - logger.warning( - f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1}/{len(tool_calls)} {error_msg}" - ) - tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{tool_name}失败: {error_msg}"))) - - # 并行执行所有工具 - if tool_tasks: - observations = await asyncio.gather(*tool_tasks, return_exceptions=True) - - # 处理执行结果 - for i, (tool_call_item, observation) in enumerate(zip(tool_calls, observations, strict=False)): - if isinstance(observation, Exception): - observation = f"工具执行异常: {str(observation)}" - logger.error(f"{react_log_prefix}第 {iteration + 1} 次迭代 工具 {i + 1} 执行异常: {observation}") - - observation_text = observation if isinstance(observation, str) else str(observation) - stripped_observation = observation_text.strip() - step["observations"].append(observation_text) - collected_info += f"\n{observation_text}\n" - if stripped_observation: - # 不再自动检测工具输出中的jargon,改为通过 query_words 工具主动查询 - tool_builder = MessageBuilder() - tool_builder.set_role(RoleType.Tool) - tool_builder.add_text_content(observation_text) - tool_builder.add_tool_call(tool_call_item.call_id) - conversation_messages.append(tool_builder.build()) - - thinking_steps.append(step) - - # 检查是否需要额外迭代:如果最后一次使用的工具是 search_chat_history 且达到最大迭代次数,额外增加一回合 - if iteration + 1 >= max_iterations and last_tool_name == "search_chat_history" and not is_timeout: - max_iterations_with_extra = max_iterations + 1 - logger.info( - f"{react_log_prefix}达到最大迭代次数(已迭代{iteration + 1}次),最后一次使用工具为 search_chat_history,额外增加一回合尝试" - ) - - iteration += 1 - - # 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估 - # 最终评估单独处理,不算在迭代中 - should_do_final_evaluation = False - if is_timeout: - should_do_final_evaluation = True - logger.warning(f"{react_log_prefix}超时,已迭代{iteration}次,进入最终评估") - elif iteration >= max_iterations: - should_do_final_evaluation = True - logger.info(f"{react_log_prefix}达到最大迭代次数(已迭代{iteration}次),进入最终评估") - - if should_do_final_evaluation: - # 获取必要变量用于最终评估 - tool_registry = get_tool_registry() - bot_name = global_config.bot.nickname - time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - current_iteration = iteration + 1 - remaining_iterations = 0 - - # 提取函数调用中参数的值,支持单引号和双引号 - def extract_quoted_content(text, func_name, param_name): - """从文本中提取函数调用中参数的值,支持单引号和双引号 - - Args: - text: 要搜索的文本 - func_name: 函数名,如 'return_information' - param_name: 参数名,如 'information' - - Returns: - 提取的参数值,如果未找到则返回None - """ - if not text: - return None - - # 查找函数调用位置(不区分大小写) - func_pattern = func_name.lower() - text_lower = text.lower() - func_pos = text_lower.find(func_pattern) - if func_pos == -1: - return None - - # 查找参数名和等号 - param_pattern = f"{param_name}=" - param_pos = text_lower.find(param_pattern, func_pos) - if param_pos == -1: - return None - - # 跳过参数名、等号和空白 - start_pos = param_pos + len(param_pattern) - while start_pos < len(text) and text[start_pos] in " \t\n": - start_pos += 1 - - if start_pos >= len(text): - return None - - # 确定引号类型 - quote_char = text[start_pos] - if quote_char not in ['"', "'"]: - return None - - # 查找匹配的结束引号(考虑转义) - end_pos = start_pos + 1 - while end_pos < len(text): - if text[end_pos] == quote_char: - # 检查是否是转义的引号 - if end_pos > start_pos + 1 and text[end_pos - 1] == "\\": - end_pos += 1 - continue - # 找到匹配的引号 - content = text[start_pos + 1 : end_pos] - # 处理转义字符 - content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\") - return content - end_pos += 1 - - return None - - # 执行最终评估 - evaluation_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_final") - evaluation_prompt_template.add_context("bot_name", bot_name) - evaluation_prompt_template.add_context("time_now", time_now) - evaluation_prompt_template.add_context("chat_history", chat_history) - evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息") - evaluation_prompt_template.add_context("current_iteration", str(current_iteration)) - evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations)) - evaluation_prompt_template.add_context("max_iterations", str(max_iterations)) - evaluation_prompt = await prompt_manager.render_prompt(evaluation_prompt_template) - - evaluation_result = await llm_api.generate( - llm_api.LLMServiceRequest( - task_name="utils", - request_type="memory.react.final", - prompt=evaluation_prompt, - tool_options=[], - ) - ) - eval_success = evaluation_result.success - eval_response = evaluation_result.completion.response - - if not eval_success: - logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="未找到答案:最终评估阶段LLM调用失败", - ) - return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout - - if global_config.debug.show_memory_prompt: - logger.info(f"{react_log_prefix}最终评估Prompt: {evaluation_prompt}") - logger.info(f"{react_log_prefix}最终评估响应: {eval_response}") - - # 从最终评估响应中提取return_information - return_information_content = None - - if eval_response: - return_information_content = extract_quoted_content(eval_response, "return_information", "information") - - # 如果提取到信息,返回(无论是否超时,都视为成功完成) - if return_information_content is not None: - eval_step = { - "iteration": current_iteration, - "thought": f"[最终评估] {eval_response}", - "actions": [ - {"action_type": "return_information", "action_params": {"information": return_information_content}} - ], - "observations": ["最终评估阶段检测到return_information"], - } - thinking_steps.append(eval_step) - if return_information_content and return_information_content.strip(): - logger.info(f"ReAct Agent 最终评估阶段返回信息: {return_information_content}") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status=f"返回信息:{return_information_content}", - ) - return True, return_information_content, thinking_steps, False - else: - logger.info("ReAct Agent 最终评估阶段判断信息为空") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="信息为空:最终评估阶段判断信息为空", - ) - return False, "", thinking_steps, False - - # 如果没有明确判断,视为not_enough_info,返回空字符串(不返回任何信息) - eval_step = { - "iteration": current_iteration, - "thought": f"[最终评估] {eval_response}", - "actions": [{"action_type": "return_information", "action_params": {"information": ""}}], - "observations": ["已到达最大迭代次数,信息为空"], - } - thinking_steps.append(eval_step) - logger.info("ReAct Agent 已到达最大迭代次数,信息为空") - - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="未找到答案:已到达最大迭代次数,无法找到答案", - ) - - return False, "", thinking_steps, is_timeout - - # 如果正常迭代过程中提前找到答案返回,不会到达这里 - # 如果正常迭代结束但没有触发最终评估(理论上不应该发生),直接返回 - logger.warning("ReAct Agent正常迭代结束,但未触发最终评估") - _log_conversation_messages( - conversation_messages, - head_prompt=first_head_prompt, - final_status="未找到答案:正常迭代结束", - ) - - return False, "", thinking_steps, is_timeout - - -def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) -> str: - """获取最近一段时间内的查询历史(用于避免重复查询) - - Args: - chat_id: 聊天ID - time_window_seconds: 时间窗口(秒),默认10分钟 - - Returns: - str: 格式化的查询历史字符串 - """ - try: - _current_time = time.time() - - with get_db_session() as session: - statement = ( - select(ThinkingQuestion) - .where(col(ThinkingQuestion.context) == chat_id) - .order_by(col(ThinkingQuestion.updated_timestamp).desc()) - .limit(5) - ) - records = session.exec(statement).all() - - if not records: - return "" - - history_lines = ["最近已查询的问题和结果:"] - - for record in records: - status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" - answer_preview = "" - # 只有找到答案时才显示答案内容 - if record.found_answer and record.answer: - # 截取答案前100字符 - answer_preview = record.answer[:100] - if len(record.answer) > 100: - answer_preview += "..." - - history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"]) - if answer_preview: - history_lines.append(f" 答案:{answer_preview}") - history_lines.append("") # 空行分隔 - - return "\n".join(history_lines) - - except Exception as e: - logger.error(f"获取查询历史失败: {e}") - return "" - - -def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) -> List[str]: - """获取最近一段时间内已找到答案的查询记录(用于返回给 replyer) - - Args: - chat_id: 聊天ID - time_window_seconds: 时间窗口(秒),默认10分钟 - - Returns: - List[str]: 格式化的答案列表,每个元素格式为 "问题:xxx\n答案:xxx" - """ - try: - _current_time = time.time() - - # 查询最近时间窗口内已找到答案的记录,按更新时间倒序 - with get_db_session() as session: - statement = ( - select(ThinkingQuestion) - .where(col(ThinkingQuestion.context) == chat_id) - .where(col(ThinkingQuestion.found_answer)) - .where(col(ThinkingQuestion.answer).is_not(None)) - .where(col(ThinkingQuestion.answer) != "") - .order_by(col(ThinkingQuestion.updated_timestamp).desc()) - .limit(3) - ) - records = session.exec(statement).all() - - if not records: - return [] - - return [f"问题:{record.question}\n答案:{record.answer}" for record in records if record.answer] - - except Exception as e: - logger.error(f"获取最近已找到答案的记录失败: {e}") - return [] - - -def _store_thinking_back( - chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]] -) -> None: - """存储或更新思考过程到数据库(如果已存在则更新,否则创建) - - Args: - chat_id: 聊天ID - question: 问题 - context: 上下文信息 - found_answer: 是否找到答案 - answer: 答案内容 - thinking_steps: 思考步骤列表 - """ - try: - now = time.time() - - # 先查询是否已存在相同chat_id和问题的记录 - with get_db_session() as session: - statement = ( - select(ThinkingQuestion) - .where(col(ThinkingQuestion.context) == chat_id) - .where(col(ThinkingQuestion.question) == question) - .order_by(col(ThinkingQuestion.updated_timestamp).desc()) - .limit(1) - ) - if record := session.exec(statement).first(): - record.context = context - record.found_answer = found_answer - record.answer = answer - record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False) - record.updated_timestamp = datetime.fromtimestamp(now) - session.add(record) - logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...") - return - - new_record = ThinkingQuestion( - question=question, - context=chat_id, - found_answer=found_answer, - answer=answer, - thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), - created_timestamp=datetime.fromtimestamp(now), - updated_timestamp=datetime.fromtimestamp(now), - ) - session.add(new_record) - except Exception as e: - logger.error(f"存储思考过程失败: {e}") - - -async def _process_memory_retrieval( - chat_id: str, - context: str, - initial_info: str = "", - max_iterations: Optional[int] = None, - chat_history: str = "", -) -> Optional[str]: - """处理记忆检索 - - Args: - chat_id: 聊天ID - context: 上下文信息 - initial_info: 初始信息,将传递给ReAct Agent - max_iterations: 最大迭代次数 - chat_history: 聊天记录,将传递给 ReAct Agent - - Returns: - Optional[str]: 如果找到答案,返回答案内容,否则返回None - """ - _cleanup_stale_not_found_thinking_back() - - question_initial_info = initial_info or "" - - # 直接使用ReAct Agent进行记忆检索 - # 如果未指定max_iterations,使用配置的默认值 - if max_iterations is None: - max_iterations = global_config.memory.max_agent_iterations - - found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( - chat_id=chat_id, - max_iterations=max_iterations, - timeout=global_config.memory.agent_timeout_seconds, - initial_info=question_initial_info, - chat_history=chat_history, - ) - - # 不再存储到数据库,直接返回答案 - if is_timeout: - logger.info("ReAct Agent超时,不返回结果") - - return answer if found_answer and answer else None - - -async def build_memory_retrieval_prompt( - message: str, - sender: str, - target: str, - chat_stream, - think_level: int = 1, - unknown_words: Optional[List[str]] = None, -) -> str: - """构建记忆检索提示 - Args: - message: 聊天历史记录 - sender: 发送者名称 - target: 目标消息内容 - chat_stream: 聊天流对象 - think_level: 思考深度等级 - unknown_words: Planner 提供的未知词语列表,优先使用此列表而不是从聊天记录匹配 - - Returns: - str: 记忆检索结果字符串 - """ - start_time = time.time() - - # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流(优先群名称/用户昵称) - try: - group_info = chat_stream.group_info - user_info = chat_stream.user_info - # 群聊优先使用群名称 - if group_info is not None and getattr(group_info, "group_name", None): - stream_name = group_info.group_name.strip() or str(group_info.group_id) - # 私聊使用用户昵称 - elif user_info is not None and getattr(user_info, "user_nickname", None): - stream_name = user_info.user_nickname.strip() or str(user_info.user_id) - # 兜底使用 stream_id - else: - stream_name = chat_stream.stream_id - except Exception: - stream_name = chat_stream.stream_id - log_prefix = f"[{stream_name}] " if stream_name else "" - - logger.info(f"{log_prefix}检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}") - try: - chat_id = chat_stream.stream_id - - # 初始阶段:使用 Planner 提供的 unknown_words 进行检索(如果提供) - initial_info = "" - if unknown_words and len(unknown_words) > 0: - # 清理和去重 unknown_words - cleaned_concepts = [] - for word in unknown_words: - if isinstance(word, str): - if cleaned := word.strip(): - cleaned_concepts.append(cleaned) - if cleaned_concepts: - # 对匹配到的概念进行jargon检索,作为初始信息 - concept_info = await retrieve_concepts_with_jargon(cleaned_concepts, chat_id) - if concept_info: - initial_info += concept_info - logger.info( - f"{log_prefix}使用 Planner 提供的 unknown_words,共 {len(cleaned_concepts)} 个概念,检索结果: {concept_info[:100]}..." - ) - else: - logger.debug(f"{log_prefix}unknown_words 检索未找到任何结果") - - # 直接使用 ReAct Agent 进行记忆检索(跳过问题生成步骤) - base_max_iterations = global_config.memory.max_agent_iterations - # 根据think_level调整迭代次数:think_level=1时不变,think_level=0时减半 - if think_level == 0: - max_iterations = max(1, base_max_iterations // 2) # 至少为1 - else: - max_iterations = base_max_iterations - timeout_seconds = global_config.memory.agent_timeout_seconds - logger.debug( - f"{log_prefix}直接使用 ReAct Agent 进行记忆检索,think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒" - ) - - # 直接调用 ReAct Agent 处理记忆检索 - try: - result = await _process_memory_retrieval( - chat_id=chat_id, - context=message, - initial_info=initial_info, - max_iterations=max_iterations, - chat_history=message, - ) - except Exception as e: - logger.error(f"{log_prefix}处理记忆检索时发生异常: {e}") - result = None - - end_time = time.time() - - if result: - logger.info(f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}秒") - return f"你回忆起了以下信息:\n{result}\n如果与回复内容相关,可以参考这些回忆的信息。\n" - else: - logger.debug(f"{log_prefix}记忆检索未找到相关信息") - return "" - - except Exception as e: - logger.error(f"{log_prefix}记忆检索时发生异常: {str(e)}") - return "" diff --git a/src/memory_system/memory_utils.py b/src/memory_system/memory_utils.py deleted file mode 100644 index 9886142c..00000000 --- a/src/memory_system/memory_utils.py +++ /dev/null @@ -1,98 +0,0 @@ -# -*- coding: utf-8 -*- -""" -记忆系统工具函数 -包含模糊查找、相似度计算等工具函数 -""" - -import json -import re -from datetime import datetime -from typing import Tuple -from typing import List -from json_repair import repair_json - -from src.common.logger import get_logger - - -logger = get_logger("memory_utils") - - -def parse_questions_json(response: str) -> Tuple[List[str], List[str]]: - """解析问题JSON,返回概念列表和问题列表 - - Args: - response: LLM返回的响应 - - Returns: - Tuple[List[str], List[str]]: (概念列表, 问题列表) - """ - try: - # 尝试提取JSON(可能包含在```json代码块中) - json_pattern = r"```json\s*(.*?)\s*```" - matches = re.findall(json_pattern, response, re.DOTALL) - - if matches: - json_str = matches[0] - else: - # 尝试直接解析整个响应 - json_str = response.strip() - - # 修复可能的JSON错误 - repaired_json = repair_json(json_str) - - # 解析JSON - parsed = json.loads(repaired_json) - - # 只支持新格式:包含concepts和questions的对象 - if not isinstance(parsed, dict): - logger.warning(f"解析的JSON不是对象格式: {parsed}") - return [], [] - - concepts_raw = parsed.get("concepts", []) - questions_raw = parsed.get("questions", []) - - # 确保是列表 - if not isinstance(concepts_raw, list): - concepts_raw = [] - if not isinstance(questions_raw, list): - questions_raw = [] - - # 确保所有元素都是字符串 - concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()] - questions = [q for q in questions_raw if isinstance(q, str) and q.strip()] - - return concepts, questions - - except Exception as e: - logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...") - return [], [] - - -def parse_datetime_to_timestamp(value: str) -> float: - """ - 接受多种常见格式并转换为时间戳(秒) - 支持示例: - - 2025-09-29 - - 2025-09-29 00:00:00 - - 2025/09/29 00:00 - - 2025-09-29T00:00:00 - """ - value = value.strip() - fmts = [ - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%d %H:%M", - "%Y/%m/%d %H:%M:%S", - "%Y/%m/%d %H:%M", - "%Y-%m-%d", - "%Y/%m/%d", - "%Y-%m-%dT%H:%M:%S", - "%Y-%m-%dT%H:%M", - ] - last_err = None - for fmt in fmts: - try: - dt = datetime.strptime(value, fmt) - return dt.timestamp() - except Exception as e: - last_err = e - raise ValueError(f"无法解析时间: {value} ({last_err})") diff --git a/src/memory_system/retrieval_tools/__init__.py b/src/memory_system/retrieval_tools/__init__.py deleted file mode 100644 index 9f2673b2..00000000 --- a/src/memory_system/retrieval_tools/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -记忆检索工具模块 -提供统一的工具注册和管理系统 -""" - -from .tool_registry import ( - MemoryRetrievalTool, - MemoryRetrievalToolRegistry, - register_memory_retrieval_tool, - get_tool_registry, -) - -# 导入所有工具的注册函数 -from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge -from .query_words import register_tool as register_query_words -from .return_information import register_tool as register_return_information -from src.config.config import global_config - - -def init_all_tools(): - """初始化并注册所有记忆检索工具""" - register_query_words() - register_return_information() - - # LPMM知识库检索工具 - if global_config.lpmm_knowledge.lpmm_mode == "agent": - register_lpmm_knowledge() - - -__all__ = [ - "MemoryRetrievalTool", - "MemoryRetrievalToolRegistry", - "register_memory_retrieval_tool", - "get_tool_registry", - "init_all_tools", -] diff --git a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py b/src/memory_system/retrieval_tools/query_lpmm_knowledge.py deleted file mode 100644 index eed01af1..00000000 --- a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -通过LPMM知识库查询信息 - 工具实现 -""" - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.knowledge import get_qa_manager -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def query_lpmm_knowledge(query: str, limit: int = 5) -> str: - """在LPMM知识库中查询相关信息 - - Args: - query: 查询关键词 - - Returns: - str: 查询结果 - """ - try: - content = str(query).strip() - if not content: - return "查询关键词为空" - - try: - limit_value = int(limit) - except (TypeError, ValueError): - limit_value = 5 - limit_value = max(1, limit_value) - - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用") - return "LPMM知识库未启用" - - qa_manager = get_qa_manager() - if qa_manager is None: - logger.debug("LPMM知识库未初始化,跳过查询") - return "LPMM知识库未初始化" - - knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value) - logger.debug(f"LPMM知识库查询结果: {knowledge_info}") - - if knowledge_info: - return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}" - - return f"在LPMM知识库中未找到与“{content}”相关的信息" - - except Exception as e: - logger.error(f"LPMM知识库查询失败: {e}") - return f"LPMM知识库查询失败:{str(e)}" - - -def register_tool(): - """注册LPMM知识库查询工具""" - register_memory_retrieval_tool( - name="lpmm_search_knowledge", - description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索", - parameters=[ - { - "name": "query", - "type": "string", - "description": "需要查询的问题,使用一句疑问句提问,例如:什么是AI?", - "required": True, - }, - { - "name": "limit", - "type": "integer", - "description": "希望返回的相关知识条数,默认为5", - "required": False, - }, - ], - execute_func=query_lpmm_knowledge, - ) diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py deleted file mode 100644 index ee28b934..00000000 --- a/src/memory_system/retrieval_tools/query_words.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -查询黑话/概念含义 - 工具实现 -用于在记忆检索过程中主动查询未知词语或黑话的含义 -""" - -from src.common.logger import get_logger -from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def query_words(chat_id: str, words: str) -> str: - """查询词语或黑话的含义 - - Args: - chat_id: 聊天ID - words: 要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔) - - Returns: - str: 查询结果,包含词语的含义解释 - """ - try: - if not words or not words.strip(): - return "未提供要查询的词语" - - # 解析词语列表(支持逗号、空格等分隔符) - words_list = [] - for separator in [",", ",", " ", "\n", "\t"]: - if separator in words: - words_list = [w.strip() for w in words.split(separator) if w.strip()] - break - - # 如果没有找到分隔符,整个字符串作为一个词语 - if not words_list: - words_list = [words.strip()] - - # 去重 - unique_words = [] - seen = set() - for word in words_list: - if word and word not in seen: - unique_words.append(word) - seen.add(word) - - if not unique_words: - return "未提供有效的词语" - - logger.info(f"查询词语含义: {unique_words}") - - # 调用检索函数 - result = await retrieve_concepts_with_jargon(unique_words, chat_id) - - if result: - return result - else: - return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释" - - except Exception as e: - logger.error(f"查询词语含义失败: {e}") - return f"查询失败: {str(e)}" - - -def register_tool(): - """注册工具""" - register_memory_retrieval_tool( - name="query_words", - description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。", - parameters=[ - { - "name": "words", - "type": "string", - "description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS' 或 'YYDS,内卷,996')", - "required": True, - }, - ], - execute_func=query_words, - ) diff --git a/src/memory_system/retrieval_tools/return_information.py b/src/memory_system/retrieval_tools/return_information.py deleted file mode 100644 index bf368083..00000000 --- a/src/memory_system/retrieval_tools/return_information.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询 -""" - -from src.common.logger import get_logger -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def return_information(information: str) -> str: - """返回总结信息并结束查询 - - Args: - information: 基于已收集信息总结出的相关信息,用于帮助回复。如果收集的信息对当前聊天没有帮助,可以返回空字符串。 - - Returns: - str: 确认信息 - """ - if information and information.strip(): - logger.info(f"返回总结信息: {information}") - return f"已确认返回信息: {information}" - else: - logger.info("未收集到相关信息,结束查询") - return "未收集到相关信息,查询结束" - - -def register_tool(): - """注册return_information工具""" - register_memory_retrieval_tool( - name="return_information", - description="当你决定结束查询时,调用此工具。基于已收集的信息,总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助,在information参数中提供总结信息;如果信息无关或没有帮助,可以提供空字符串。", - parameters=[ - { - "name": "information", - "type": "string", - "description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。", - "required": True, - }, - ], - execute_func=return_information, - ) diff --git a/src/memory_system/retrieval_tools/tool_registry.py b/src/memory_system/retrieval_tools/tool_registry.py deleted file mode 100644 index f2dd1f0d..00000000 --- a/src/memory_system/retrieval_tools/tool_registry.py +++ /dev/null @@ -1,167 +0,0 @@ -"""工具注册系统。 - -提供统一的工具注册和管理接口。 -""" - -from typing import Any, Awaitable, Callable, Dict, List, Optional - -from src.common.logger import get_logger -from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option - -logger = get_logger("memory_retrieval_tools") - - -class MemoryRetrievalTool: - """记忆检索工具基类""" - - def __init__( - self, - name: str, - description: str, - parameters: List[Dict[str, Any]], - execute_func: Callable[..., Awaitable[str]], - ) -> None: - """初始化工具。 - - Args: - name: 工具名称。 - description: 工具描述。 - parameters: 参数定义列表。 - execute_func: 执行函数,必须是异步函数。 - """ - self.name = name - self.description = description - self.parameters = parameters - self.execute_func = execute_func - - def get_tool_description(self) -> str: - """获取工具的文本描述,用于prompt""" - param_descriptions = [] - for param in self.parameters: - param_name = param.get("name", "") - param_type = param.get("type", "string") - param_desc = param.get("description", "") - required = param.get("required", True) - required_str = "必填" if required else "可选" - param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}") - - params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数" - return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}" - - async def execute(self, **kwargs: Any) -> str: - """执行工具。""" - return await self.execute_func(**kwargs) - - def get_tool_definition(self) -> Dict[str, Any]: - """获取规范化的工具定义。 - - Returns: - Dict[str, Any]: 统一工具定义字典。 - """ - legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [] - - for param in self.parameters: - param_name = param.get("name", "") - param_type_str = param.get("type", "string").lower() - param_desc = param.get("description", "") - is_required = param.get("required", False) - enum_values = param.get("enum", None) - - # 转换类型字符串到ToolParamType - type_mapping = { - "string": ToolParamType.STRING, - "integer": ToolParamType.INTEGER, - "int": ToolParamType.INTEGER, - "float": ToolParamType.FLOAT, - "boolean": ToolParamType.BOOLEAN, - "bool": ToolParamType.BOOLEAN, - } - param_type = type_mapping.get(param_type_str, ToolParamType.STRING) - - legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values)) - - normalized_option = normalize_tool_option( - { - "name": self.name, - "description": self.description, - "parameters": legacy_parameters, - } - ) - return { - "name": normalized_option.name, - "description": normalized_option.description, - "parameters_schema": normalized_option.parameters_schema, - } - - -class MemoryRetrievalToolRegistry: - """工具注册器""" - - def __init__(self) -> None: - """初始化工具注册器。""" - self.tools: Dict[str, MemoryRetrievalTool] = {} - - def register_tool(self, tool: MemoryRetrievalTool) -> None: - """注册工具""" - if tool.name in self.tools: - logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册") - return - self.tools[tool.name] = tool - logger.info(f"注册记忆检索工具: {tool.name}") - - def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]: - """获取工具""" - return self.tools.get(name) - - def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]: - """获取所有工具""" - return self.tools.copy() - - def get_tools_description(self) -> str: - """获取所有工具的描述,用于prompt""" - descriptions = [] - for i, tool in enumerate(self.tools.values(), 1): - descriptions.append(f"{i}. {tool.get_tool_description()}") - return "\n".join(descriptions) - - def get_action_types_list(self) -> str: - """获取所有动作类型的列表,用于prompt(已废弃,保留用于兼容)""" - action_types = [tool.name for tool in self.tools.values()] - action_types.append("final_answer") - action_types.append("no_answer") - return " 或 ".join([f'"{at}"' for at in action_types]) - - def get_tool_definitions(self) -> List[Dict[str, Any]]: - """获取所有工具的定义列表,用于LLM function calling - - Returns: - List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典 - """ - return [tool.get_tool_definition() for tool in self.tools.values()] - - -# 全局工具注册器实例 -_tool_registry = MemoryRetrievalToolRegistry() - - -def register_memory_retrieval_tool( - name: str, - description: str, - parameters: List[Dict[str, Any]], - execute_func: Callable[..., Awaitable[str]], -) -> None: - """注册记忆检索工具的便捷函数。 - - Args: - name: 工具名称。 - description: 工具描述。 - parameters: 参数定义列表。 - execute_func: 执行函数。 - """ - tool = MemoryRetrievalTool(name, description, parameters, execute_func) - _tool_registry.register_tool(tool) - - -def get_tool_registry() -> MemoryRetrievalToolRegistry: - """获取工具注册器实例""" - return _tool_registry