remove:冗余的sbhf代码和focus代码
This commit is contained in:
84
src/chat/working_memory/memory_item.py
Normal file
84
src/chat/working_memory/memory_item.py
Normal file
@@ -0,0 +1,84 @@
|
||||
from typing import Tuple
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
"""记忆项类,用于存储单个记忆的所有相关信息"""
|
||||
|
||||
def __init__(self, summary: str, from_source: str = "", brief: str = ""):
|
||||
"""
|
||||
初始化记忆项
|
||||
|
||||
Args:
|
||||
summary: 记忆内容概括
|
||||
from_source: 数据来源
|
||||
brief: 记忆内容主题
|
||||
"""
|
||||
# 生成可读ID:时间戳_随机字符串
|
||||
timestamp = int(time.time())
|
||||
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||
self.id = f"{timestamp}_{random_str}"
|
||||
self.from_source = from_source
|
||||
self.brief = brief
|
||||
self.timestamp = time.time()
|
||||
|
||||
# 记忆内容概括
|
||||
self.summary = summary
|
||||
|
||||
# 记忆精简次数
|
||||
self.compress_count = 0
|
||||
|
||||
# 记忆提取次数
|
||||
self.retrieval_count = 0
|
||||
|
||||
# 记忆强度 (初始为10)
|
||||
self.memory_strength = 10.0
|
||||
|
||||
# 记忆操作历史记录
|
||||
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
|
||||
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
|
||||
|
||||
def matches_source(self, source: str) -> bool:
|
||||
"""检查来源是否匹配"""
|
||||
return self.from_source == source
|
||||
|
||||
def increase_strength(self, amount: float) -> None:
|
||||
"""增加记忆强度"""
|
||||
self.memory_strength = min(10.0, self.memory_strength + amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("strengthen")
|
||||
|
||||
def decrease_strength(self, amount: float) -> None:
|
||||
"""减少记忆强度"""
|
||||
self.memory_strength = max(0.1, self.memory_strength - amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("weaken")
|
||||
|
||||
def increase_compress_count(self) -> None:
|
||||
"""增加精简次数并减弱记忆强度"""
|
||||
self.compress_count += 1
|
||||
# 记录操作历史
|
||||
self.record_operation("compress")
|
||||
|
||||
def record_retrieval(self) -> None:
|
||||
"""记录记忆被提取的情况"""
|
||||
self.retrieval_count += 1
|
||||
# 提取后强度翻倍
|
||||
self.memory_strength = min(10.0, self.memory_strength * 2)
|
||||
# 记录操作历史
|
||||
self.record_operation("retrieval")
|
||||
|
||||
def record_operation(self, operation_type: str) -> None:
|
||||
"""记录操作历史"""
|
||||
current_time = time.time()
|
||||
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
|
||||
|
||||
def to_tuple(self) -> Tuple[str, str, float, str]:
|
||||
"""转换为元组格式(为了兼容性)"""
|
||||
return (self.summary, self.from_source, self.timestamp, self.id)
|
||||
|
||||
def is_memory_valid(self) -> bool:
|
||||
"""检查记忆是否有效(强度是否大于等于1)"""
|
||||
return self.memory_strength >= 1.0
|
||||
413
src/chat/working_memory/memory_manager.py
Normal file
413
src/chat/working_memory/memory_manager.py
Normal file
@@ -0,0 +1,413 @@
|
||||
from typing import Dict, TypeVar, List, Optional
|
||||
import traceback
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
import json # 添加json模块导入
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("working_memory")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化工作记忆
|
||||
|
||||
Args:
|
||||
chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天
|
||||
"""
|
||||
# 关联的聊天ID
|
||||
self._chat_id = chat_id
|
||||
|
||||
# 记忆项列表
|
||||
self._memories: List[MemoryItem] = []
|
||||
|
||||
# ID到记忆项的映射
|
||||
self._id_map: Dict[str, MemoryItem] = {}
|
||||
|
||||
self.llm_summarizer = LLMRequest(
|
||||
model=global_config.model.focus_working_memory,
|
||||
temperature=0.3,
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str:
|
||||
"""获取关联的聊天ID"""
|
||||
return self._chat_id
|
||||
|
||||
@chat_id.setter
|
||||
def chat_id(self, value: str):
|
||||
"""设置关联的聊天ID"""
|
||||
self._chat_id = value
|
||||
|
||||
def push_item(self, memory_item: MemoryItem) -> str:
|
||||
"""
|
||||
推送一个已创建的记忆项到工作记忆中
|
||||
|
||||
Args:
|
||||
memory_item: 要存储的记忆项
|
||||
|
||||
Returns:
|
||||
记忆项的ID
|
||||
"""
|
||||
# 添加到内存和ID映射
|
||||
self._memories.append(memory_item)
|
||||
self._id_map[memory_item.id] = memory_item
|
||||
|
||||
return memory_item.id
|
||||
|
||||
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
通过ID获取记忆项
|
||||
|
||||
Args:
|
||||
memory_id: 记忆项ID
|
||||
|
||||
Returns:
|
||||
找到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self._id_map.get(memory_id)
|
||||
if memory_item:
|
||||
# 检查记忆强度,如果小于1则删除
|
||||
if not memory_item.is_memory_valid():
|
||||
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
||||
self.delete(memory_id)
|
||||
return None
|
||||
|
||||
return memory_item
|
||||
|
||||
def get_all_items(self) -> List[MemoryItem]:
|
||||
"""获取所有记忆项"""
|
||||
return list(self._id_map.values())
|
||||
|
||||
def find_items(
|
||||
self,
|
||||
source: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
memory_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
newest_first: bool = False,
|
||||
min_strength: float = 0.0,
|
||||
) -> List[MemoryItem]:
|
||||
"""
|
||||
按条件查找记忆项
|
||||
|
||||
Args:
|
||||
source: 数据来源
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
memory_id: 特定记忆项ID
|
||||
limit: 返回结果的最大数量
|
||||
newest_first: 是否按最新优先排序
|
||||
min_strength: 最小记忆强度
|
||||
|
||||
Returns:
|
||||
符合条件的记忆项列表
|
||||
"""
|
||||
# 如果提供了特定ID,直接查找
|
||||
if memory_id:
|
||||
item = self.get_by_id(memory_id)
|
||||
return [item] if item else []
|
||||
|
||||
results = []
|
||||
|
||||
# 获取所有项目
|
||||
items = self._memories
|
||||
|
||||
# 如果需要最新优先,则反转遍历顺序
|
||||
if newest_first:
|
||||
items_to_check = list(reversed(items))
|
||||
else:
|
||||
items_to_check = items
|
||||
|
||||
# 遍历项目
|
||||
for item in items_to_check:
|
||||
# 检查来源是否匹配
|
||||
if source is not None and not item.matches_source(source):
|
||||
continue
|
||||
|
||||
# 检查时间范围
|
||||
if start_time is not None and item.timestamp < start_time:
|
||||
continue
|
||||
if end_time is not None and item.timestamp > end_time:
|
||||
continue
|
||||
|
||||
# 检查记忆强度
|
||||
if min_strength > 0 and item.memory_strength < min_strength:
|
||||
continue
|
||||
|
||||
# 所有条件都满足,添加到结果中
|
||||
results.append(item)
|
||||
|
||||
# 如果达到限制数量,提前返回
|
||||
if limit is not None and len(results) >= limit:
|
||||
return results
|
||||
|
||||
return results
|
||||
|
||||
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
|
||||
"""
|
||||
使用LLM总结记忆项
|
||||
|
||||
Args:
|
||||
content: 需要总结的内容
|
||||
|
||||
Returns:
|
||||
包含brief和summary的字典
|
||||
"""
|
||||
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
|
||||
1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
|
||||
2. 记忆内容概括:对内容进行概括,保留重要信息,200字以内
|
||||
|
||||
内容:
|
||||
{content}
|
||||
|
||||
请按以下JSON格式输出:
|
||||
{{
|
||||
"brief": "记忆内容主题",
|
||||
"summary": "记忆内容概括"
|
||||
}}
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
default_summary = {
|
||||
"brief": "主题未知的记忆",
|
||||
"summary": "无法概括的记忆内容",
|
||||
}
|
||||
|
||||
try:
|
||||
# 调用LLM生成总结
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
# 使用repair_json解析响应
|
||||
try:
|
||||
# 使用repair_json修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
json_result = json.loads(fixed_json_string)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||||
return default_summary
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
json_result = fixed_json_string
|
||||
|
||||
# 进行额外的类型检查
|
||||
if not isinstance(json_result, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
|
||||
return default_summary
|
||||
|
||||
# 确保所有必要字段都存在且类型正确
|
||||
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
||||
json_result["brief"] = "主题未知的记忆"
|
||||
|
||||
if "summary" not in json_result or not isinstance(json_result["summary"], str):
|
||||
json_result["summary"] = "无法概括的记忆内容"
|
||||
|
||||
return json_result
|
||||
|
||||
except Exception as json_error:
|
||||
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
||||
return default_summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成总结时出错: {str(e)}")
|
||||
return default_summary
|
||||
|
||||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||||
"""
|
||||
使单个记忆衰减
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
|
||||
Returns:
|
||||
是否成功衰减
|
||||
"""
|
||||
memory_item = self.get_by_id(memory_id)
|
||||
if not memory_item:
|
||||
return False
|
||||
|
||||
# 计算衰减量(当前强度 * (1-衰减因子))
|
||||
old_strength = memory_item.memory_strength
|
||||
decay_amount = old_strength * (1 - decay_factor)
|
||||
|
||||
# 更新强度
|
||||
memory_item.memory_strength = decay_amount
|
||||
|
||||
return True
|
||||
|
||||
def delete(self, memory_id: str) -> bool:
|
||||
"""
|
||||
删除指定ID的记忆项
|
||||
|
||||
Args:
|
||||
memory_id: 要删除的记忆项ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
if memory_id not in self._id_map:
|
||||
return False
|
||||
|
||||
# 获取要删除的项
|
||||
self._id_map[memory_id]
|
||||
|
||||
# 从内存中删除
|
||||
self._memories = [i for i in self._memories if i.id != memory_id]
|
||||
|
||||
# 从ID映射中删除
|
||||
del self._id_map[memory_id]
|
||||
|
||||
return True
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有记忆"""
|
||||
self._memories.clear()
|
||||
self._id_map.clear()
|
||||
|
||||
async def merge_memories(
|
||||
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||||
) -> MemoryItem:
|
||||
"""
|
||||
合并两个记忆项
|
||||
|
||||
Args:
|
||||
memory_id1: 第一个记忆项ID
|
||||
memory_id2: 第二个记忆项ID
|
||||
reason: 合并原因
|
||||
delete_originals: 是否删除原始记忆,默认为True
|
||||
|
||||
Returns:
|
||||
合并后的记忆项
|
||||
"""
|
||||
# 获取两个记忆项
|
||||
memory_item1 = self.get_by_id(memory_id1)
|
||||
memory_item2 = self.get_by_id(memory_id2)
|
||||
|
||||
if not memory_item1 or not memory_item2:
|
||||
raise ValueError("无法找到指定的记忆项")
|
||||
|
||||
# 构建合并提示
|
||||
prompt = f"""
|
||||
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
||||
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
|
||||
|
||||
合并原因:{reason}
|
||||
|
||||
记忆1主题:{memory_item1.brief}
|
||||
记忆1内容:{memory_item1.summary}
|
||||
|
||||
记忆2主题:{memory_item2.brief}
|
||||
记忆2内容:{memory_item2.summary}
|
||||
|
||||
请按以下JSON格式输出合并结果:
|
||||
{{
|
||||
"brief": "合并后的主题(20字以内)",
|
||||
"summary": "合并后的内容概括(200字以内)"
|
||||
}}
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
|
||||
# 默认合并结果
|
||||
default_merged = {
|
||||
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
|
||||
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
|
||||
}
|
||||
|
||||
try:
|
||||
# 调用LLM合并记忆
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
# 处理LLM返回的合并结果
|
||||
try:
|
||||
# 修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
# 将修复后的字符串解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
merged_data = json.loads(fixed_json_string)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||||
merged_data = default_merged
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
merged_data = fixed_json_string
|
||||
|
||||
# 确保是字典类型
|
||||
if not isinstance(merged_data, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
|
||||
merged_data = default_merged
|
||||
|
||||
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
||||
merged_data["brief"] = default_merged["brief"]
|
||||
|
||||
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
|
||||
merged_data["summary"] = default_merged["summary"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆调用LLM出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
|
||||
# 创建新的记忆项
|
||||
# 取两个记忆项中更强的来源
|
||||
merged_source = (
|
||||
memory_item1.from_source
|
||||
if memory_item1.memory_strength >= memory_item2.memory_strength
|
||||
else memory_item2.from_source
|
||||
)
|
||||
|
||||
# 创建新的记忆项
|
||||
merged_memory = MemoryItem(
|
||||
summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"]
|
||||
)
|
||||
|
||||
# 记忆强度取两者最大值
|
||||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||||
|
||||
# 添加到存储中
|
||||
self.push_item(merged_memory)
|
||||
|
||||
# 如果需要,删除原始记忆
|
||||
if delete_originals:
|
||||
self.delete(memory_id1)
|
||||
self.delete(memory_id2)
|
||||
|
||||
return merged_memory
|
||||
|
||||
def delete_earliest_memory(self) -> bool:
|
||||
"""
|
||||
删除最早的记忆项
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
# 获取所有记忆项
|
||||
all_memories = self.get_all_items()
|
||||
|
||||
if not all_memories:
|
||||
return False
|
||||
|
||||
# 按时间戳排序,找到最早的记忆项
|
||||
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
|
||||
|
||||
# 删除最早的记忆项
|
||||
return self.delete(earliest_memory.id)
|
||||
156
src/chat/working_memory/working_memory.py
Normal file
156
src/chat/working_memory/working_memory.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from typing import List, Any, Optional
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
工作记忆,负责协调和运作记忆
|
||||
从属于特定的流,用chat_id来标识
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||
"""
|
||||
初始化工作记忆管理器
|
||||
|
||||
Args:
|
||||
max_memories_per_chat: 每个聊天的最大记忆数量
|
||||
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
|
||||
"""
|
||||
self.memory_manager = MemoryManager(chat_id)
|
||||
|
||||
# 记忆容量上限
|
||||
self.max_memories_per_chat = max_memories_per_chat
|
||||
|
||||
# 自动衰减间隔
|
||||
self.auto_decay_interval = auto_decay_interval
|
||||
|
||||
# 衰减任务
|
||||
self.decay_task = None
|
||||
|
||||
# 只有在工作记忆处理器启用时才启动自动衰减任务
|
||||
if global_config.focus_chat_processor.working_memory_processor:
|
||||
self._start_auto_decay()
|
||||
else:
|
||||
logger.debug(f"工作记忆处理器已禁用,跳过启动自动衰减任务 (chat_id: {chat_id})")
|
||||
|
||||
def _start_auto_decay(self):
|
||||
"""启动自动衰减任务"""
|
||||
if self.decay_task is None:
|
||||
self.decay_task = asyncio.create_task(self._auto_decay_loop())
|
||||
|
||||
async def _auto_decay_loop(self):
|
||||
"""自动衰减循环"""
|
||||
while True:
|
||||
await asyncio.sleep(self.auto_decay_interval)
|
||||
try:
|
||||
await self.decay_all_memories()
|
||||
except Exception as e:
|
||||
print(f"自动衰减记忆时出错: {str(e)}")
|
||||
|
||||
async def add_memory(self, summary: Any, from_source: str = "", brief: str = ""):
|
||||
"""
|
||||
添加一段记忆到指定聊天
|
||||
|
||||
Args:
|
||||
summary: 记忆内容
|
||||
from_source: 数据来源
|
||||
|
||||
Returns:
|
||||
记忆项
|
||||
"""
|
||||
# 如果是字符串类型,生成总结
|
||||
|
||||
memory = MemoryItem(summary, from_source, brief)
|
||||
|
||||
# 添加到管理器
|
||||
self.memory_manager.push_item(memory)
|
||||
|
||||
# 如果超过最大记忆数量,删除最早的记忆
|
||||
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
|
||||
self.remove_earliest_memory()
|
||||
|
||||
return memory
|
||||
|
||||
def remove_earliest_memory(self):
|
||||
"""
|
||||
删除最早的记忆
|
||||
"""
|
||||
return self.memory_manager.delete_earliest_memory()
|
||||
|
||||
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
检索记忆
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
检索到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self.memory_manager.get_by_id(memory_id)
|
||||
if memory_item:
|
||||
memory_item.retrieval_count += 1
|
||||
memory_item.increase_strength(5)
|
||||
return memory_item
|
||||
return None
|
||||
|
||||
async def decay_all_memories(self, decay_factor: float = 0.5):
|
||||
"""
|
||||
对所有聊天的所有记忆进行衰减
|
||||
衰减:对记忆进行refine压缩,强度会变为原先的0.5
|
||||
|
||||
Args:
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
"""
|
||||
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
|
||||
|
||||
all_memories = self.memory_manager.get_all_items()
|
||||
|
||||
for memory_item in all_memories:
|
||||
# 如果压缩完小于1会被删除
|
||||
memory_id = memory_item.id
|
||||
self.memory_manager.decay_memory(memory_id, decay_factor)
|
||||
if memory_item.memory_strength < 1:
|
||||
self.memory_manager.delete(memory_id)
|
||||
continue
|
||||
# 计算衰减量
|
||||
# if memory_item.memory_strength < 5:
|
||||
# await self.memory_manager.refine_memory(
|
||||
# memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
|
||||
# )
|
||||
|
||||
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
|
||||
"""合并记忆
|
||||
|
||||
Args:
|
||||
memory_str: 记忆内容
|
||||
"""
|
||||
return await self.memory_manager.merge_memories(
|
||||
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""关闭管理器,停止所有任务"""
|
||||
if self.decay_task and not self.decay_task.done():
|
||||
self.decay_task.cancel()
|
||||
try:
|
||||
await self.decay_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def get_all_memories(self) -> List[MemoryItem]:
|
||||
"""
|
||||
获取所有记忆项目
|
||||
|
||||
Returns:
|
||||
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
|
||||
"""
|
||||
return self.memory_manager.get_all_items()
|
||||
261
src/chat/working_memory/working_memory_processor.py
Normal file
261
src/chat/working_memory/working_memory_processor.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from src.chat.focus_chat.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.observation.observation import Observation
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from typing import List
|
||||
from src.chat.focus_chat.observation.working_observation import WorkingMemoryObservation
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
memory_proces_prompt = """
|
||||
你的名字是{bot_name}
|
||||
|
||||
现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆:
|
||||
{memory_str}
|
||||
|
||||
观察聊天内容和已经总结的记忆,思考如果有相近的记忆,请合并记忆,输出merge_memory,
|
||||
合并记忆的格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容
|
||||
|
||||
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下:
|
||||
```json
|
||||
{{
|
||||
"selected_memory_ids": ["id1", "id2", ...]
|
||||
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
|
||||
}}
|
||||
```
|
||||
"""
|
||||
Prompt(memory_proces_prompt, "prompt_memory_proces")
|
||||
|
||||
|
||||
class WorkingMemoryProcessor:
|
||||
log_prefix = "工作记忆"
|
||||
|
||||
def __init__(self, subheartflow_id: str):
|
||||
self.subheartflow_id = subheartflow_id
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
|
||||
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
*infos: 可变数量的InfoBase类型的信息对象
|
||||
|
||||
Returns:
|
||||
List[InfoBase]: 处理后的结构化信息列表
|
||||
"""
|
||||
working_memory = None
|
||||
chat_info = ""
|
||||
chat_obs = None
|
||||
try:
|
||||
for observation in observations:
|
||||
if isinstance(observation, WorkingMemoryObservation):
|
||||
working_memory = observation.get_observe_info()
|
||||
if isinstance(observation, ChattingObservation):
|
||||
chat_info = observation.get_observe_info()
|
||||
chat_obs = observation
|
||||
# 检查是否有待压缩内容
|
||||
if chat_obs and chat_obs.compressor_prompt:
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆")
|
||||
await self.compress_chat_memory(working_memory, chat_obs)
|
||||
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆观察,跳过处理")
|
||||
return []
|
||||
|
||||
all_memory = working_memory.get_all_memories()
|
||||
if not all_memory:
|
||||
logger.debug(f"{self.log_prefix} 目前没有工作记忆,跳过提取")
|
||||
return []
|
||||
|
||||
memory_prompts = []
|
||||
for memory in all_memory:
|
||||
memory_id = memory.id
|
||||
memory_brief = memory.brief
|
||||
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
||||
memory_prompts.append(memory_single_prompt)
|
||||
|
||||
memory_choose_str = "".join(memory_prompts)
|
||||
|
||||
# 使用提示模板进行处理
|
||||
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
chat_observe_info=chat_info,
|
||||
memory_str=memory_choose_str,
|
||||
)
|
||||
|
||||
# 调用LLM处理记忆
|
||||
content = ""
|
||||
try:
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# print(f"prompt: {prompt}---------------------------------")
|
||||
# print(f"content: {content}---------------------------------")
|
||||
|
||||
if not content:
|
||||
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
# 解析LLM返回的JSON
|
||||
try:
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}")
|
||||
return []
|
||||
|
||||
selected_memory_ids = result.get("selected_memory_ids", [])
|
||||
merge_memory = result.get("merge_memory", [])
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}"
|
||||
)
|
||||
|
||||
# 根据selected_memory_ids,调取记忆
|
||||
memory_str = ""
|
||||
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
|
||||
|
||||
# 遍历所有记忆
|
||||
for memory in all_memory:
|
||||
if memory.id in selected_ids:
|
||||
# 选中的记忆显示详细内容
|
||||
memory = await working_memory.retrieve_memory(memory.id)
|
||||
if memory:
|
||||
memory_str += f"{memory.summary}\n"
|
||||
else:
|
||||
# 未选中的记忆显示梗概
|
||||
memory_str += f"{memory.brief}\n"
|
||||
|
||||
working_memory_info = WorkingMemoryInfo()
|
||||
if memory_str:
|
||||
working_memory_info.add_working_memory(memory_str)
|
||||
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
|
||||
|
||||
if merge_memory:
|
||||
for merge_pairs in merge_memory:
|
||||
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
|
||||
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
|
||||
if memory1 and memory2:
|
||||
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
|
||||
|
||||
return [working_memory_info]
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def compress_chat_memory(self, working_memory: WorkingMemory, obs: ChattingObservation):
|
||||
"""压缩聊天记忆
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆对象
|
||||
obs: 聊天观察对象
|
||||
"""
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.warning(f"{self.log_prefix} 工作记忆对象为None,无法压缩聊天记忆")
|
||||
return
|
||||
|
||||
try:
|
||||
summary_result, _ = await self.llm_model.generate_response_async(obs.compressor_prompt)
|
||||
if not summary_result:
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
|
||||
return
|
||||
|
||||
print(f"compressor_prompt: {obs.compressor_prompt}")
|
||||
print(f"summary_result: {summary_result}")
|
||||
|
||||
# 修复并解析JSON
|
||||
try:
|
||||
fixed_json = repair_json(summary_result)
|
||||
summary_data = json.loads(fixed_json)
|
||||
|
||||
if not isinstance(summary_data, dict):
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
|
||||
return
|
||||
|
||||
theme = summary_data.get("theme", "")
|
||||
content = summary_data.get("content", "")
|
||||
|
||||
if not theme or not content:
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
|
||||
return
|
||||
|
||||
# 创建新记忆
|
||||
await working_memory.add_memory(from_source="chat_compress", summary=content, brief=theme)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
# 清理压缩状态
|
||||
obs.compressor_prompt = ""
|
||||
obs.oldest_messages = []
|
||||
obs.oldest_messages_str = ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
|
||||
"""异步合并记忆,不阻塞主流程
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆对象
|
||||
memory_id1: 第一个记忆ID
|
||||
memory_id2: 第二个记忆ID
|
||||
"""
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.warning(f"{self.log_prefix} 工作记忆对象为None,无法合并记忆")
|
||||
return
|
||||
|
||||
try:
|
||||
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
|
||||
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.brief}")
|
||||
logger.debug(f"{self.log_prefix} 合并后的记忆内容: {merged_memory.summary}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
init_prompt()
|
||||
Reference in New Issue
Block a user