fix:优化记忆提取,修复破损的tool信息

This commit is contained in:
SengokuCola
2025-05-27 18:21:05 +08:00
parent 548a583cc7
commit 52f7cc3762
9 changed files with 110 additions and 50 deletions

View File

@@ -4,24 +4,58 @@ from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservati
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger_manager import get_logger
from src.chat.utils.prompt_builder import Prompt
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from datetime import datetime
from src.chat.memory_system.Hippocampus import HippocampusManager
from typing import List, Dict
import difflib
import json
from json_repair import repair_json
logger = get_logger("memory_activator")
def get_keywords_from_json(json_str):
    """
    Extract the keyword list from a (possibly malformed) JSON string.

    The text is first passed through ``repair_json`` and, if the result is a
    string, parsed with ``json.loads``. The parsed value must be a JSON object
    whose ``"keywords"`` entry is a list; anything else is treated as
    "no keywords" so callers can always iterate the result safely.

    Args:
        json_str: JSON-formatted text expected to contain an object with a
            ``"keywords"`` array.

    Returns:
        List[str]: The non-blank string keywords in their original order.
        An empty list is returned when the text cannot be repaired/parsed,
        when the top-level value is not an object, or when ``"keywords"`` is
        missing or has an unusable type.
    """
    try:
        fixed_json = repair_json(json_str)
        # repair_json may return repaired JSON text or an already-parsed
        # object; only the text form still needs decoding.
        result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
    except Exception as e:
        # Best-effort parsing of model output: log and fall back to "no keywords".
        logger.error(f"解析关键词JSON失败: {e}")
        return []

    if not isinstance(result, dict):
        # e.g. the model produced a bare list or string instead of an object.
        logger.error(f"解析关键词JSON失败: 期望JSON对象, 实际为 {type(result).__name__}")
        return []

    keywords = result.get("keywords")
    if isinstance(keywords, str):
        # A lone string is taken as a single keyword; returning it as-is
        # would make callers that iterate the result see single characters.
        keywords = [keywords]
    if not isinstance(keywords, list):
        # Covers a missing key, null, numbers, nested objects, ...
        return []

    # Drop non-string and blank entries so the result can be joined into
    # text or stored in a set without further checks.
    return [kw for kw in keywords if isinstance(kw, str) and kw.strip()]
def init_prompt():
# --- Group Chat Prompt ---
memory_activator_prompt = """
你是一个记忆分析器,你需要根据以下信息来进行会议
你是一个记忆分析器,你需要根据以下信息来进行回忆
以下是一场聊天中的信息,请根据这些信息,总结出几个关键词作为记忆回忆的触发词
{obs_info_text}
历史关键词(请避免重复提取这些关键词):
{cached_keywords}
请输出一个json格式包含以下字段
{{
"keywords": ["关键词1", "关键词2", "关键词3",......]
@@ -39,6 +73,7 @@ class MemoryActivator:
model=global_config.model.memory_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
)
self.running_memory = []
self.cached_keywords = set() # 用于缓存历史关键词
async def activate_memory(self, observations) -> List[Dict]:
"""
@@ -61,31 +96,47 @@ class MemoryActivator:
elif isinstance(observation, HFCloopObservation):
obs_info_text += observation.get_observe_info()
logger.debug(f"回忆待检索内容obs_info_text: {obs_info_text}")
# logger.debug(f"回忆待检索内容obs_info_text: {obs_info_text}")
# prompt = await global_prompt_manager.format_prompt(
# "memory_activator_prompt",
# obs_info_text=obs_info_text,
# )
# logger.debug(f"prompt: {prompt}")
# response = await self.summary_model.generate_response(prompt)
# logger.debug(f"response: {response}")
# # 只取response的第一个元素字符串
# response_str = response[0]
# keywords = list(get_keywords_from_json(response_str))
# #调用记忆系统获取相关记忆
# related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
# valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
# )
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=True
# 将缓存的关键词转换为字符串用于prompt
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
prompt = await global_prompt_manager.format_prompt(
"memory_activator_prompt",
obs_info_text=obs_info_text,
cached_keywords=cached_keywords_str,
)
logger.debug(f"prompt: {prompt}")
response = await self.summary_model.generate_response(prompt)
logger.debug(f"response: {response}")
# 只取response的第一个元素字符串
response_str = response[0]
keywords = list(get_keywords_from_json(response_str))
# 更新关键词缓存
if keywords:
# 限制缓存大小最多保留10个关键词
if len(self.cached_keywords) > 10:
# 转换为列表,移除最早的关键词
cached_list = list(self.cached_keywords)
self.cached_keywords = set(cached_list[-8:])
# 添加新的关键词到缓存
self.cached_keywords.update(keywords)
logger.debug(f"更新关键词缓存: {self.cached_keywords}")
#调用记忆系统获取相关记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
)
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
# text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=False
# )
# logger.debug(f"获取到的记忆: {related_memory}")
# 激活时所有已有记忆的duration+1达到3则移除