414 lines
13 KiB
Python
414 lines
13 KiB
Python
from typing import Dict, TypeVar, List, Optional
|
||
import traceback
|
||
from json_repair import repair_json
|
||
from rich.traceback import install
|
||
from src.common.logger import get_logger
|
||
from src.llm_models.utils_model import LLMRequest
|
||
from src.config.config import global_config
|
||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||
import json # 添加json模块导入
|
||
|
||
|
||
install(extra_lines=3)
|
||
logger = get_logger("working_memory")
|
||
|
||
T = TypeVar("T")
|
||
|
||
|
||
class MemoryManager:
|
||
def __init__(self, chat_id: str):
|
||
"""
|
||
初始化工作记忆
|
||
|
||
Args:
|
||
chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天
|
||
"""
|
||
# 关联的聊天ID
|
||
self._chat_id = chat_id
|
||
|
||
# 记忆项列表
|
||
self._memories: List[MemoryItem] = []
|
||
|
||
# ID到记忆项的映射
|
||
self._id_map: Dict[str, MemoryItem] = {}
|
||
|
||
self.llm_summarizer = LLMRequest(
|
||
model=global_config.model.memory,
|
||
temperature=0.3,
|
||
request_type="working_memory",
|
||
)
|
||
|
||
@property
|
||
def chat_id(self) -> str:
|
||
"""获取关联的聊天ID"""
|
||
return self._chat_id
|
||
|
||
@chat_id.setter
|
||
def chat_id(self, value: str):
|
||
"""设置关联的聊天ID"""
|
||
self._chat_id = value
|
||
|
||
def push_item(self, memory_item: MemoryItem) -> str:
|
||
"""
|
||
推送一个已创建的记忆项到工作记忆中
|
||
|
||
Args:
|
||
memory_item: 要存储的记忆项
|
||
|
||
Returns:
|
||
记忆项的ID
|
||
"""
|
||
# 添加到内存和ID映射
|
||
self._memories.append(memory_item)
|
||
self._id_map[memory_item.id] = memory_item
|
||
|
||
return memory_item.id
|
||
|
||
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
||
"""
|
||
通过ID获取记忆项
|
||
|
||
Args:
|
||
memory_id: 记忆项ID
|
||
|
||
Returns:
|
||
找到的记忆项,如果不存在则返回None
|
||
"""
|
||
memory_item = self._id_map.get(memory_id)
|
||
if memory_item:
|
||
# 检查记忆强度,如果小于1则删除
|
||
if not memory_item.is_memory_valid():
|
||
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
||
self.delete(memory_id)
|
||
return None
|
||
|
||
return memory_item
|
||
|
||
def get_all_items(self) -> List[MemoryItem]:
|
||
"""获取所有记忆项"""
|
||
return list(self._id_map.values())
|
||
|
||
def find_items(
|
||
self,
|
||
source: Optional[str] = None,
|
||
start_time: Optional[float] = None,
|
||
end_time: Optional[float] = None,
|
||
memory_id: Optional[str] = None,
|
||
limit: Optional[int] = None,
|
||
newest_first: bool = False,
|
||
min_strength: float = 0.0,
|
||
) -> List[MemoryItem]:
|
||
"""
|
||
按条件查找记忆项
|
||
|
||
Args:
|
||
source: 数据来源
|
||
start_time: 开始时间戳
|
||
end_time: 结束时间戳
|
||
memory_id: 特定记忆项ID
|
||
limit: 返回结果的最大数量
|
||
newest_first: 是否按最新优先排序
|
||
min_strength: 最小记忆强度
|
||
|
||
Returns:
|
||
符合条件的记忆项列表
|
||
"""
|
||
# 如果提供了特定ID,直接查找
|
||
if memory_id:
|
||
item = self.get_by_id(memory_id)
|
||
return [item] if item else []
|
||
|
||
results = []
|
||
|
||
# 获取所有项目
|
||
items = self._memories
|
||
|
||
# 如果需要最新优先,则反转遍历顺序
|
||
if newest_first:
|
||
items_to_check = list(reversed(items))
|
||
else:
|
||
items_to_check = items
|
||
|
||
# 遍历项目
|
||
for item in items_to_check:
|
||
# 检查来源是否匹配
|
||
if source is not None and not item.matches_source(source):
|
||
continue
|
||
|
||
# 检查时间范围
|
||
if start_time is not None and item.timestamp < start_time:
|
||
continue
|
||
if end_time is not None and item.timestamp > end_time:
|
||
continue
|
||
|
||
# 检查记忆强度
|
||
if min_strength > 0 and item.memory_strength < min_strength:
|
||
continue
|
||
|
||
# 所有条件都满足,添加到结果中
|
||
results.append(item)
|
||
|
||
# 如果达到限制数量,提前返回
|
||
if limit is not None and len(results) >= limit:
|
||
return results
|
||
|
||
return results
|
||
|
||
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
|
||
"""
|
||
使用LLM总结记忆项
|
||
|
||
Args:
|
||
content: 需要总结的内容
|
||
|
||
Returns:
|
||
包含brief和summary的字典
|
||
"""
|
||
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
|
||
1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
|
||
2. 记忆内容概括:对内容进行概括,保留重要信息,200字以内
|
||
|
||
内容:
|
||
{content}
|
||
|
||
请按以下JSON格式输出:
|
||
{{
|
||
"brief": "记忆内容主题",
|
||
"summary": "记忆内容概括"
|
||
}}
|
||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||
"""
|
||
default_summary = {
|
||
"brief": "主题未知的记忆",
|
||
"summary": "无法概括的记忆内容",
|
||
}
|
||
|
||
try:
|
||
# 调用LLM生成总结
|
||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||
|
||
# 使用repair_json解析响应
|
||
try:
|
||
# 使用repair_json修复JSON格式
|
||
fixed_json_string = repair_json(response)
|
||
|
||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||
if isinstance(fixed_json_string, str):
|
||
try:
|
||
json_result = json.loads(fixed_json_string)
|
||
except json.JSONDecodeError as decode_error:
|
||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||
return default_summary
|
||
else:
|
||
# 如果repair_json直接返回了字典对象,直接使用
|
||
json_result = fixed_json_string
|
||
|
||
# 进行额外的类型检查
|
||
if not isinstance(json_result, dict):
|
||
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
|
||
return default_summary
|
||
|
||
# 确保所有必要字段都存在且类型正确
|
||
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
||
json_result["brief"] = "主题未知的记忆"
|
||
|
||
if "summary" not in json_result or not isinstance(json_result["summary"], str):
|
||
json_result["summary"] = "无法概括的记忆内容"
|
||
|
||
return json_result
|
||
|
||
except Exception as json_error:
|
||
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
||
return default_summary
|
||
|
||
except Exception as e:
|
||
logger.error(f"生成总结时出错: {str(e)}")
|
||
return default_summary
|
||
|
||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||
"""
|
||
使单个记忆衰减
|
||
|
||
Args:
|
||
memory_id: 记忆ID
|
||
decay_factor: 衰减因子(0-1之间)
|
||
|
||
Returns:
|
||
是否成功衰减
|
||
"""
|
||
memory_item = self.get_by_id(memory_id)
|
||
if not memory_item:
|
||
return False
|
||
|
||
# 计算衰减量(当前强度 * (1-衰减因子))
|
||
old_strength = memory_item.memory_strength
|
||
decay_amount = old_strength * (1 - decay_factor)
|
||
|
||
# 更新强度
|
||
memory_item.memory_strength = decay_amount
|
||
|
||
return True
|
||
|
||
def delete(self, memory_id: str) -> bool:
|
||
"""
|
||
删除指定ID的记忆项
|
||
|
||
Args:
|
||
memory_id: 要删除的记忆项ID
|
||
|
||
Returns:
|
||
是否成功删除
|
||
"""
|
||
if memory_id not in self._id_map:
|
||
return False
|
||
|
||
# 获取要删除的项
|
||
self._id_map[memory_id]
|
||
|
||
# 从内存中删除
|
||
self._memories = [i for i in self._memories if i.id != memory_id]
|
||
|
||
# 从ID映射中删除
|
||
del self._id_map[memory_id]
|
||
|
||
return True
|
||
|
||
def clear(self) -> None:
|
||
"""清除所有记忆"""
|
||
self._memories.clear()
|
||
self._id_map.clear()
|
||
|
||
async def merge_memories(
|
||
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||
) -> MemoryItem:
|
||
"""
|
||
合并两个记忆项
|
||
|
||
Args:
|
||
memory_id1: 第一个记忆项ID
|
||
memory_id2: 第二个记忆项ID
|
||
reason: 合并原因
|
||
delete_originals: 是否删除原始记忆,默认为True
|
||
|
||
Returns:
|
||
合并后的记忆项
|
||
"""
|
||
# 获取两个记忆项
|
||
memory_item1 = self.get_by_id(memory_id1)
|
||
memory_item2 = self.get_by_id(memory_id2)
|
||
|
||
if not memory_item1 or not memory_item2:
|
||
raise ValueError("无法找到指定的记忆项")
|
||
|
||
# 构建合并提示
|
||
prompt = f"""
|
||
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
||
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
|
||
|
||
合并原因:{reason}
|
||
|
||
记忆1主题:{memory_item1.brief}
|
||
记忆1内容:{memory_item1.summary}
|
||
|
||
记忆2主题:{memory_item2.brief}
|
||
记忆2内容:{memory_item2.summary}
|
||
|
||
请按以下JSON格式输出合并结果:
|
||
{{
|
||
"brief": "合并后的主题(20字以内)",
|
||
"summary": "合并后的内容概括(200字以内)"
|
||
}}
|
||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||
"""
|
||
|
||
# 默认合并结果
|
||
default_merged = {
|
||
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
|
||
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
|
||
}
|
||
|
||
try:
|
||
# 调用LLM合并记忆
|
||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||
|
||
# 处理LLM返回的合并结果
|
||
try:
|
||
# 修复JSON格式
|
||
fixed_json_string = repair_json(response)
|
||
|
||
# 将修复后的字符串解析为Python对象
|
||
if isinstance(fixed_json_string, str):
|
||
try:
|
||
merged_data = json.loads(fixed_json_string)
|
||
except json.JSONDecodeError as decode_error:
|
||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||
merged_data = default_merged
|
||
else:
|
||
# 如果repair_json直接返回了字典对象,直接使用
|
||
merged_data = fixed_json_string
|
||
|
||
# 确保是字典类型
|
||
if not isinstance(merged_data, dict):
|
||
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
|
||
merged_data = default_merged
|
||
|
||
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
||
merged_data["brief"] = default_merged["brief"]
|
||
|
||
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
|
||
merged_data["summary"] = default_merged["summary"]
|
||
|
||
except Exception as e:
|
||
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
||
traceback.print_exc()
|
||
merged_data = default_merged
|
||
except Exception as e:
|
||
logger.error(f"合并记忆调用LLM出错: {str(e)}")
|
||
traceback.print_exc()
|
||
merged_data = default_merged
|
||
|
||
# 创建新的记忆项
|
||
# 取两个记忆项中更强的来源
|
||
merged_source = (
|
||
memory_item1.from_source
|
||
if memory_item1.memory_strength >= memory_item2.memory_strength
|
||
else memory_item2.from_source
|
||
)
|
||
|
||
# 创建新的记忆项
|
||
merged_memory = MemoryItem(
|
||
summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"]
|
||
)
|
||
|
||
# 记忆强度取两者最大值
|
||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||
|
||
# 添加到存储中
|
||
self.push_item(merged_memory)
|
||
|
||
# 如果需要,删除原始记忆
|
||
if delete_originals:
|
||
self.delete(memory_id1)
|
||
self.delete(memory_id2)
|
||
|
||
return merged_memory
|
||
|
||
def delete_earliest_memory(self) -> bool:
|
||
"""
|
||
删除最早的记忆项
|
||
|
||
Returns:
|
||
是否成功删除
|
||
"""
|
||
# 获取所有记忆项
|
||
all_memories = self.get_all_items()
|
||
|
||
if not all_memories:
|
||
return False
|
||
|
||
# 按时间戳排序,找到最早的记忆项
|
||
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
|
||
|
||
# 删除最早的记忆项
|
||
return self.delete(earliest_memory.id)
|