fix:ruff
This commit is contained in:
@@ -1,23 +1,16 @@
|
||||
from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple
|
||||
from typing import Dict, Any, List, Optional, Set, Tuple
|
||||
import time
|
||||
import uuid
|
||||
import traceback
|
||||
import random
|
||||
import string
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
"""记忆项类,用于存储单个记忆的所有相关信息"""
|
||||
|
||||
|
||||
def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None):
|
||||
"""
|
||||
初始化记忆项
|
||||
|
||||
|
||||
Args:
|
||||
data: 记忆数据
|
||||
from_source: 数据来源
|
||||
@@ -25,7 +18,7 @@ class MemoryItem:
|
||||
"""
|
||||
# 生成可读ID:时间戳_随机字符串
|
||||
timestamp = int(time.time())
|
||||
random_str = ''.join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||
self.id = f"{timestamp}_{random_str}"
|
||||
self.data = data
|
||||
self.data_type = type(data)
|
||||
@@ -40,63 +33,63 @@ class MemoryItem:
|
||||
# "events": ["事件1", "事件2"]
|
||||
# }
|
||||
self.summary = None
|
||||
|
||||
|
||||
# 记忆精简次数
|
||||
self.compress_count = 0
|
||||
|
||||
|
||||
# 记忆提取次数
|
||||
self.retrieval_count = 0
|
||||
|
||||
|
||||
# 记忆强度 (初始为10)
|
||||
self.memory_strength = 10.0
|
||||
|
||||
|
||||
# 记忆操作历史记录
|
||||
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
|
||||
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
|
||||
|
||||
|
||||
def add_tag(self, tag: str) -> None:
|
||||
"""添加标签"""
|
||||
self.tags.add(tag)
|
||||
|
||||
|
||||
def remove_tag(self, tag: str) -> None:
|
||||
"""移除标签"""
|
||||
if tag in self.tags:
|
||||
self.tags.remove(tag)
|
||||
|
||||
|
||||
def has_tag(self, tag: str) -> bool:
|
||||
"""检查是否有特定标签"""
|
||||
return tag in self.tags
|
||||
|
||||
|
||||
def has_all_tags(self, tags: List[str]) -> bool:
|
||||
"""检查是否有所有指定的标签"""
|
||||
return all(tag in self.tags for tag in tags)
|
||||
|
||||
|
||||
def matches_source(self, source: str) -> bool:
|
||||
"""检查来源是否匹配"""
|
||||
return self.from_source == source
|
||||
|
||||
|
||||
def set_summary(self, summary: Dict[str, Any]) -> None:
|
||||
"""设置总结信息"""
|
||||
self.summary = summary
|
||||
|
||||
|
||||
def increase_strength(self, amount: float) -> None:
|
||||
"""增加记忆强度"""
|
||||
self.memory_strength = min(10.0, self.memory_strength + amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("strengthen")
|
||||
|
||||
|
||||
def decrease_strength(self, amount: float) -> None:
|
||||
"""减少记忆强度"""
|
||||
self.memory_strength = max(0.1, self.memory_strength - amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("weaken")
|
||||
|
||||
|
||||
def increase_compress_count(self) -> None:
|
||||
"""增加精简次数并减弱记忆强度"""
|
||||
self.compress_count += 1
|
||||
# 记录操作历史
|
||||
self.record_operation("compress")
|
||||
|
||||
|
||||
def record_retrieval(self) -> None:
|
||||
"""记录记忆被提取的情况"""
|
||||
self.retrieval_count += 1
|
||||
@@ -104,16 +97,16 @@ class MemoryItem:
|
||||
self.memory_strength = min(10.0, self.memory_strength * 2)
|
||||
# 记录操作历史
|
||||
self.record_operation("retrieval")
|
||||
|
||||
|
||||
def record_operation(self, operation_type: str) -> None:
|
||||
"""记录操作历史"""
|
||||
current_time = time.time()
|
||||
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
|
||||
|
||||
|
||||
def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]:
|
||||
"""转换为元组格式(为了兼容性)"""
|
||||
return (self.data, self.from_source, self.tags, self.timestamp, self.id)
|
||||
|
||||
|
||||
def is_memory_valid(self) -> bool:
|
||||
"""检查记忆是否有效(强度是否大于等于1)"""
|
||||
return self.memory_strength >= 1.0
|
||||
return self.memory_strength >= 1.0
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Type, TypeVar, List, Optional
|
||||
import traceback
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
@@ -14,74 +12,71 @@ import json # 添加json模块导入
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("working_memory")
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化工作记忆
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天
|
||||
"""
|
||||
# 关联的聊天ID
|
||||
self._chat_id = chat_id
|
||||
|
||||
|
||||
# 主存储: 数据类型 -> 记忆项列表
|
||||
self._memory: Dict[Type, List[MemoryItem]] = {}
|
||||
|
||||
|
||||
# ID到记忆项的映射
|
||||
self._id_map: Dict[str, MemoryItem] = {}
|
||||
|
||||
|
||||
self.llm_summarizer = LLMRequest(
|
||||
model=global_config.llm_summary,
|
||||
temperature=0.3,
|
||||
max_tokens=512,
|
||||
request_type="memory_summarization"
|
||||
model=global_config.llm_summary, temperature=0.3, max_tokens=512, request_type="memory_summarization"
|
||||
)
|
||||
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str:
|
||||
"""获取关联的聊天ID"""
|
||||
return self._chat_id
|
||||
|
||||
|
||||
@chat_id.setter
|
||||
def chat_id(self, value: str):
|
||||
"""设置关联的聊天ID"""
|
||||
self._chat_id = value
|
||||
|
||||
|
||||
def push_item(self, memory_item: MemoryItem) -> str:
|
||||
"""
|
||||
推送一个已创建的记忆项到工作记忆中
|
||||
|
||||
|
||||
Args:
|
||||
memory_item: 要存储的记忆项
|
||||
|
||||
|
||||
Returns:
|
||||
记忆项的ID
|
||||
"""
|
||||
data_type = memory_item.data_type
|
||||
|
||||
|
||||
# 确保存在该类型的存储列表
|
||||
if data_type not in self._memory:
|
||||
self._memory[data_type] = []
|
||||
|
||||
|
||||
# 添加到内存和ID映射
|
||||
self._memory[data_type].append(memory_item)
|
||||
self._id_map[memory_item.id] = memory_item
|
||||
|
||||
|
||||
return memory_item.id
|
||||
|
||||
|
||||
async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem:
|
||||
"""
|
||||
推送一段有类型的信息到工作记忆中,并自动生成总结
|
||||
|
||||
|
||||
Args:
|
||||
data: 要存储的数据
|
||||
from_source: 数据来源
|
||||
tags: 数据标签列表
|
||||
|
||||
|
||||
Returns:
|
||||
包含原始数据和总结信息的字典
|
||||
"""
|
||||
@@ -89,65 +84,66 @@ class MemoryManager:
|
||||
if isinstance(data, str):
|
||||
# 先生成总结
|
||||
summary = await self.summarize_memory_item(data)
|
||||
|
||||
|
||||
# 准备标签
|
||||
memory_tags = list(tags) if tags else []
|
||||
|
||||
|
||||
# 创建记忆项
|
||||
memory_item = MemoryItem(data, from_source, memory_tags)
|
||||
|
||||
|
||||
# 将总结信息保存到记忆项中
|
||||
memory_item.set_summary(summary)
|
||||
|
||||
|
||||
# 推送记忆项
|
||||
self.push_item(memory_item)
|
||||
|
||||
|
||||
return memory_item
|
||||
else:
|
||||
# 非字符串类型,直接创建并推送记忆项
|
||||
memory_item = MemoryItem(data, from_source, tags)
|
||||
self.push_item(memory_item)
|
||||
|
||||
|
||||
return memory_item
|
||||
|
||||
|
||||
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
通过ID获取记忆项
|
||||
|
||||
|
||||
Args:
|
||||
memory_id: 记忆项ID
|
||||
|
||||
|
||||
Returns:
|
||||
找到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self._id_map.get(memory_id)
|
||||
if memory_item:
|
||||
|
||||
# 检查记忆强度,如果小于1则删除
|
||||
if not memory_item.is_memory_valid():
|
||||
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
||||
self.delete(memory_id)
|
||||
return None
|
||||
|
||||
|
||||
return memory_item
|
||||
|
||||
|
||||
def get_all_items(self) -> List[MemoryItem]:
|
||||
"""获取所有记忆项"""
|
||||
return list(self._id_map.values())
|
||||
|
||||
def find_items(self,
|
||||
data_type: Optional[Type] = None,
|
||||
source: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
memory_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
newest_first: bool = False,
|
||||
min_strength: float = 0.0) -> List[MemoryItem]:
|
||||
|
||||
def find_items(
|
||||
self,
|
||||
data_type: Optional[Type] = None,
|
||||
source: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
memory_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
newest_first: bool = False,
|
||||
min_strength: float = 0.0,
|
||||
) -> List[MemoryItem]:
|
||||
"""
|
||||
按条件查找记忆项
|
||||
|
||||
|
||||
Args:
|
||||
data_type: 要查找的数据类型
|
||||
source: 数据来源
|
||||
@@ -158,7 +154,7 @@ class MemoryManager:
|
||||
limit: 返回结果的最大数量
|
||||
newest_first: 是否按最新优先排序
|
||||
min_strength: 最小记忆强度
|
||||
|
||||
|
||||
Returns:
|
||||
符合条件的记忆项列表
|
||||
"""
|
||||
@@ -166,62 +162,62 @@ class MemoryManager:
|
||||
if memory_id:
|
||||
item = self.get_by_id(memory_id)
|
||||
return [item] if item else []
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
# 确定要搜索的类型列表
|
||||
types_to_search = [data_type] if data_type else list(self._memory.keys())
|
||||
|
||||
|
||||
# 对每个类型进行搜索
|
||||
for typ in types_to_search:
|
||||
if typ not in self._memory:
|
||||
continue
|
||||
|
||||
|
||||
# 获取该类型的所有项目
|
||||
items = self._memory[typ]
|
||||
|
||||
|
||||
# 如果需要最新优先,则反转遍历顺序
|
||||
if newest_first:
|
||||
items_to_check = list(reversed(items))
|
||||
else:
|
||||
items_to_check = items
|
||||
|
||||
|
||||
# 遍历项目
|
||||
for item in items_to_check:
|
||||
# 检查来源是否匹配
|
||||
if source is not None and not item.matches_source(source):
|
||||
continue
|
||||
|
||||
|
||||
# 检查标签是否匹配
|
||||
if tags is not None and not item.has_all_tags(tags):
|
||||
continue
|
||||
|
||||
|
||||
# 检查时间范围
|
||||
if start_time is not None and item.timestamp < start_time:
|
||||
continue
|
||||
if end_time is not None and item.timestamp > end_time:
|
||||
continue
|
||||
|
||||
|
||||
# 检查记忆强度
|
||||
if min_strength > 0 and item.memory_strength < min_strength:
|
||||
continue
|
||||
|
||||
|
||||
# 所有条件都满足,添加到结果中
|
||||
results.append(item)
|
||||
|
||||
|
||||
# 如果达到限制数量,提前返回
|
||||
if limit is not None and len(results) >= limit:
|
||||
return results
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
使用LLM总结记忆项
|
||||
|
||||
|
||||
Args:
|
||||
content: 需要总结的内容
|
||||
|
||||
|
||||
Returns:
|
||||
包含总结、概括、关键概念和事件的字典
|
||||
"""
|
||||
@@ -257,18 +253,18 @@ class MemoryManager:
|
||||
"brief": "主题未知的记忆",
|
||||
"detailed": "大致内容未知的记忆",
|
||||
"keypoints": ["未知的概念"],
|
||||
"events": ["未知的事件"]
|
||||
"events": ["未知的事件"],
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 调用LLM生成总结
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
|
||||
# 使用repair_json解析响应
|
||||
try:
|
||||
# 使用repair_json修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
@@ -279,68 +275,60 @@ class MemoryManager:
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
json_result = fixed_json_string
|
||||
|
||||
|
||||
# 进行额外的类型检查
|
||||
if not isinstance(json_result, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
|
||||
return default_summary
|
||||
|
||||
|
||||
# 确保所有必要字段都存在且类型正确
|
||||
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
||||
json_result["brief"] = "主题未知的记忆"
|
||||
|
||||
|
||||
if "detailed" not in json_result or not isinstance(json_result["detailed"], str):
|
||||
json_result["detailed"] = "大致内容未知的记忆"
|
||||
|
||||
|
||||
# 处理关键概念
|
||||
if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list):
|
||||
json_result["keypoints"] = ["未知的概念"]
|
||||
else:
|
||||
# 确保keypoints中的每个项目都是字符串
|
||||
json_result["keypoints"] = [
|
||||
str(point) for point in json_result["keypoints"]
|
||||
if point is not None
|
||||
]
|
||||
json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None]
|
||||
if not json_result["keypoints"]:
|
||||
json_result["keypoints"] = ["未知的概念"]
|
||||
|
||||
|
||||
# 处理事件
|
||||
if "events" not in json_result or not isinstance(json_result["events"], list):
|
||||
json_result["events"] = ["未知的事件"]
|
||||
else:
|
||||
# 确保events中的每个项目都是字符串
|
||||
json_result["events"] = [
|
||||
str(event) for event in json_result["events"]
|
||||
if event is not None
|
||||
]
|
||||
json_result["events"] = [str(event) for event in json_result["events"] if event is not None]
|
||||
if not json_result["events"]:
|
||||
json_result["events"] = ["未知的事件"]
|
||||
|
||||
|
||||
# 兼容旧版,将keypoints和events合并到key_points中
|
||||
json_result["key_points"] = json_result["keypoints"] + json_result["events"]
|
||||
|
||||
|
||||
return json_result
|
||||
|
||||
|
||||
except Exception as json_error:
|
||||
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
||||
# 返回默认结构
|
||||
return default_summary
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 出错时返回简单的结构
|
||||
logger.error(f"生成总结时出错: {str(e)}")
|
||||
return default_summary
|
||||
|
||||
async def refine_memory(self,
|
||||
memory_id: str,
|
||||
requirements: str = "") -> Dict[str, Any]:
|
||||
|
||||
async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
|
||||
"""
|
||||
对记忆进行精简操作,根据要求修改要点、总结和概括
|
||||
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
|
||||
|
||||
|
||||
Returns:
|
||||
修改后的记忆总结字典
|
||||
"""
|
||||
@@ -349,12 +337,12 @@ class MemoryManager:
|
||||
memory_item = self.get_by_id(memory_id)
|
||||
if not memory_item:
|
||||
raise ValueError(f"未找到ID为{memory_id}的记忆项")
|
||||
|
||||
|
||||
# 增加精简次数
|
||||
memory_item.increase_compress_count()
|
||||
|
||||
|
||||
summary = memory_item.summary
|
||||
|
||||
|
||||
# 使用LLM根据要求对总结、概括和要点进行精简修改
|
||||
prompt = f"""
|
||||
请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程:
|
||||
@@ -396,15 +384,15 @@ class MemoryManager:
|
||||
halfway = len(key_points) // 2
|
||||
summary["keypoints"] = key_points[:halfway] or ["未知的概念"]
|
||||
summary["events"] = key_points[halfway:] or ["未知的事件"]
|
||||
|
||||
|
||||
# 定义默认的精简结果
|
||||
default_refined = {
|
||||
"brief": summary["brief"],
|
||||
"detailed": summary["detailed"],
|
||||
"keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念
|
||||
"events": summary.get("events", ["未知的事件"])[:1] # 默认只保留第一个事件
|
||||
"events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# 调用LLM修改总结、概括和要点
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
@@ -413,7 +401,7 @@ class MemoryManager:
|
||||
try:
|
||||
# 修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
|
||||
# 将修复后的字符串解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
@@ -424,16 +412,16 @@ class MemoryManager:
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
refined_data = fixed_json_string
|
||||
|
||||
|
||||
# 确保是字典类型
|
||||
if not isinstance(refined_data, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
|
||||
refined_data = default_refined
|
||||
|
||||
|
||||
# 更新总结、概括
|
||||
summary["brief"] = refined_data.get("brief", "主题未知的记忆")
|
||||
summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆")
|
||||
|
||||
|
||||
# 更新关键概念
|
||||
keypoints = refined_data.get("keypoints", [])
|
||||
if isinstance(keypoints, list) and keypoints:
|
||||
@@ -442,7 +430,7 @@ class MemoryManager:
|
||||
else:
|
||||
# 如果keypoints不是列表或为空,使用默认值
|
||||
summary["keypoints"] = ["主要概念已遗忘"]
|
||||
|
||||
|
||||
# 更新事件
|
||||
events = refined_data.get("events", [])
|
||||
if isinstance(events, list) and events:
|
||||
@@ -451,84 +439,83 @@ class MemoryManager:
|
||||
else:
|
||||
# 如果events不是列表或为空,使用默认值
|
||||
summary["events"] = ["事件细节已遗忘"]
|
||||
|
||||
|
||||
# 兼容旧版,维护key_points
|
||||
summary["key_points"] = summary["keypoints"] + summary["events"]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"精简记忆出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 出错时使用简化的默认精简
|
||||
summary["brief"] = summary["brief"] + " (已简化)"
|
||||
summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1]
|
||||
summary["events"] = summary.get("events", ["未知的事件"])[:1]
|
||||
summary["key_points"] = summary["keypoints"] + summary["events"]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"精简记忆调用LLM出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# 更新原记忆项的总结
|
||||
memory_item.set_summary(summary)
|
||||
|
||||
|
||||
return memory_item
|
||||
|
||||
|
||||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||||
"""
|
||||
使单个记忆衰减
|
||||
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功衰减
|
||||
"""
|
||||
memory_item = self.get_by_id(memory_id)
|
||||
if not memory_item:
|
||||
return False
|
||||
|
||||
|
||||
# 计算衰减量(当前强度 * (1-衰减因子))
|
||||
old_strength = memory_item.memory_strength
|
||||
decay_amount = old_strength * (1 - decay_factor)
|
||||
|
||||
|
||||
# 更新强度
|
||||
memory_item.memory_strength = decay_amount
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def delete(self, memory_id: str) -> bool:
|
||||
"""
|
||||
删除指定ID的记忆项
|
||||
|
||||
|
||||
Args:
|
||||
memory_id: 要删除的记忆项ID
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
if memory_id not in self._id_map:
|
||||
return False
|
||||
|
||||
|
||||
# 获取要删除的项
|
||||
item = self._id_map[memory_id]
|
||||
|
||||
|
||||
# 从内存中删除
|
||||
data_type = item.data_type
|
||||
if data_type in self._memory:
|
||||
self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id]
|
||||
|
||||
|
||||
# 从ID映射中删除
|
||||
del self._id_map[memory_id]
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def clear(self, data_type: Optional[Type] = None) -> None:
|
||||
"""
|
||||
清除记忆中的数据
|
||||
|
||||
|
||||
Args:
|
||||
data_type: 要清除的数据类型,如果为None则清除所有数据
|
||||
"""
|
||||
@@ -542,34 +529,36 @@ class MemoryManager:
|
||||
if item.id in self._id_map:
|
||||
del self._id_map[item.id]
|
||||
del self._memory[data_type]
|
||||
|
||||
async def merge_memories(self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True) -> MemoryItem:
|
||||
|
||||
async def merge_memories(
|
||||
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||||
) -> MemoryItem:
|
||||
"""
|
||||
合并两个记忆项
|
||||
|
||||
|
||||
Args:
|
||||
memory_id1: 第一个记忆项ID
|
||||
memory_id2: 第二个记忆项ID
|
||||
reason: 合并原因
|
||||
delete_originals: 是否删除原始记忆,默认为True
|
||||
|
||||
|
||||
Returns:
|
||||
包含合并后的记忆信息的字典
|
||||
"""
|
||||
# 获取两个记忆项
|
||||
memory_item1 = self.get_by_id(memory_id1)
|
||||
memory_item2 = self.get_by_id(memory_id2)
|
||||
|
||||
|
||||
if not memory_item1 or not memory_item2:
|
||||
raise ValueError("无法找到指定的记忆项")
|
||||
|
||||
|
||||
content1 = memory_item1.data
|
||||
content2 = memory_item2.data
|
||||
|
||||
|
||||
# 获取记忆的摘要信息(如果有)
|
||||
summary1 = memory_item1.summary
|
||||
summary2 = memory_item2.summary
|
||||
|
||||
|
||||
# 构建合并提示
|
||||
prompt = f"""
|
||||
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
||||
@@ -577,32 +566,32 @@ class MemoryManager:
|
||||
|
||||
合并原因:{reason}
|
||||
"""
|
||||
|
||||
|
||||
# 如果有摘要信息,添加到提示中
|
||||
if summary1:
|
||||
prompt += f"记忆1主题:{summary1['brief']}\n"
|
||||
prompt += f"记忆1概括:{summary1['detailed']}\n"
|
||||
|
||||
|
||||
if "keypoints" in summary1:
|
||||
prompt += f"记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1['keypoints']]) + "\n\n"
|
||||
|
||||
prompt += "记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n"
|
||||
|
||||
if "events" in summary1:
|
||||
prompt += f"记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1['events']]) + "\n\n"
|
||||
prompt += "记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n"
|
||||
elif "key_points" in summary1:
|
||||
prompt += f"记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1['key_points']]) + "\n\n"
|
||||
|
||||
prompt += "记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n"
|
||||
|
||||
if summary2:
|
||||
prompt += f"记忆2主题:{summary2['brief']}\n"
|
||||
prompt += f"记忆2概括:{summary2['detailed']}\n"
|
||||
|
||||
|
||||
if "keypoints" in summary2:
|
||||
prompt += f"记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2['keypoints']]) + "\n\n"
|
||||
|
||||
prompt += "记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n"
|
||||
|
||||
if "events" in summary2:
|
||||
prompt += f"记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2['events']]) + "\n\n"
|
||||
prompt += "记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n"
|
||||
elif "key_points" in summary2:
|
||||
prompt += f"记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2['key_points']]) + "\n\n"
|
||||
|
||||
prompt += "记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n"
|
||||
|
||||
# 添加记忆原始内容
|
||||
prompt += f"""
|
||||
记忆1原始内容:
|
||||
@@ -630,16 +619,16 @@ class MemoryManager:
|
||||
```
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
|
||||
|
||||
# 默认合并结果
|
||||
default_merged = {
|
||||
"content": f"{content1}\n\n{content2}",
|
||||
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
|
||||
"detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}",
|
||||
"keypoints": [],
|
||||
"events": []
|
||||
"events": [],
|
||||
}
|
||||
|
||||
|
||||
# 合并旧版key_points
|
||||
if "key_points" in summary1:
|
||||
default_merged["keypoints"].extend(summary1.get("keypoints", []))
|
||||
@@ -650,7 +639,7 @@ class MemoryManager:
|
||||
halfway = len(key_points) // 2
|
||||
default_merged["keypoints"].extend(key_points[:halfway])
|
||||
default_merged["events"].extend(key_points[halfway:])
|
||||
|
||||
|
||||
if "key_points" in summary2:
|
||||
default_merged["keypoints"].extend(summary2.get("keypoints", []))
|
||||
default_merged["events"].extend(summary2.get("events", []))
|
||||
@@ -660,25 +649,25 @@ class MemoryManager:
|
||||
halfway = len(key_points) // 2
|
||||
default_merged["keypoints"].extend(key_points[:halfway])
|
||||
default_merged["events"].extend(key_points[halfway:])
|
||||
|
||||
|
||||
# 确保列表不为空
|
||||
if not default_merged["keypoints"]:
|
||||
default_merged["keypoints"] = ["合并的关键概念"]
|
||||
if not default_merged["events"]:
|
||||
default_merged["events"] = ["合并的事件"]
|
||||
|
||||
|
||||
# 添加key_points兼容
|
||||
default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"]
|
||||
|
||||
|
||||
try:
|
||||
# 调用LLM合并记忆
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
|
||||
# 处理LLM返回的合并结果
|
||||
try:
|
||||
# 修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
|
||||
# 将修复后的字符串解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
@@ -689,49 +678,43 @@ class MemoryManager:
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
merged_data = fixed_json_string
|
||||
|
||||
|
||||
# 确保是字典类型
|
||||
if not isinstance(merged_data, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
|
||||
merged_data = default_merged
|
||||
|
||||
|
||||
# 确保所有必要字段都存在且类型正确
|
||||
if "content" not in merged_data or not isinstance(merged_data["content"], str):
|
||||
merged_data["content"] = default_merged["content"]
|
||||
|
||||
|
||||
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
||||
merged_data["brief"] = default_merged["brief"]
|
||||
|
||||
|
||||
if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str):
|
||||
merged_data["detailed"] = default_merged["detailed"]
|
||||
|
||||
|
||||
# 处理关键概念
|
||||
if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list):
|
||||
merged_data["keypoints"] = default_merged["keypoints"]
|
||||
else:
|
||||
# 确保keypoints中的每个项目都是字符串
|
||||
merged_data["keypoints"] = [
|
||||
str(point) for point in merged_data["keypoints"]
|
||||
if point is not None
|
||||
]
|
||||
merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None]
|
||||
if not merged_data["keypoints"]:
|
||||
merged_data["keypoints"] = ["合并的关键概念"]
|
||||
|
||||
|
||||
# 处理事件
|
||||
if "events" not in merged_data or not isinstance(merged_data["events"], list):
|
||||
merged_data["events"] = default_merged["events"]
|
||||
else:
|
||||
# 确保events中的每个项目都是字符串
|
||||
merged_data["events"] = [
|
||||
str(event) for event in merged_data["events"]
|
||||
if event is not None
|
||||
]
|
||||
merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None]
|
||||
if not merged_data["events"]:
|
||||
merged_data["events"] = ["合并的事件"]
|
||||
|
||||
|
||||
# 添加key_points兼容
|
||||
merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"]
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
@@ -740,59 +723,59 @@ class MemoryManager:
|
||||
logger.error(f"合并记忆调用LLM出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
|
||||
|
||||
# 创建新的记忆项
|
||||
# 合并记忆项的标签
|
||||
merged_tags = memory_item1.tags.union(memory_item2.tags)
|
||||
|
||||
|
||||
# 取两个记忆项中更强的来源
|
||||
merged_source = memory_item1.from_source if memory_item1.memory_strength >= memory_item2.memory_strength else memory_item2.from_source
|
||||
|
||||
# 创建新的记忆项
|
||||
merged_memory = MemoryItem(
|
||||
data=merged_data["content"],
|
||||
from_source=merged_source,
|
||||
tags=list(merged_tags)
|
||||
merged_source = (
|
||||
memory_item1.from_source
|
||||
if memory_item1.memory_strength >= memory_item2.memory_strength
|
||||
else memory_item2.from_source
|
||||
)
|
||||
|
||||
|
||||
# 创建新的记忆项
|
||||
merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags))
|
||||
|
||||
# 设置合并后的摘要
|
||||
summary = {
|
||||
"brief": merged_data["brief"],
|
||||
"detailed": merged_data["detailed"],
|
||||
"keypoints": merged_data["keypoints"],
|
||||
"events": merged_data["events"],
|
||||
"key_points": merged_data["key_points"]
|
||||
"key_points": merged_data["key_points"],
|
||||
}
|
||||
merged_memory.set_summary(summary)
|
||||
|
||||
|
||||
# 记忆强度取两者最大值
|
||||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||||
|
||||
|
||||
# 添加到存储中
|
||||
self.push_item(merged_memory)
|
||||
|
||||
|
||||
# 如果需要,删除原始记忆
|
||||
if delete_originals:
|
||||
self.delete(memory_id1)
|
||||
self.delete(memory_id2)
|
||||
|
||||
|
||||
return merged_memory
|
||||
|
||||
|
||||
def delete_earliest_memory(self) -> bool:
|
||||
"""
|
||||
删除最早的记忆项
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
# 获取所有记忆项
|
||||
all_memories = self.get_all_items()
|
||||
|
||||
|
||||
if not all_memories:
|
||||
return False
|
||||
|
||||
|
||||
# 按时间戳排序,找到最早的记忆项
|
||||
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
|
||||
|
||||
|
||||
# 删除最早的记忆项
|
||||
return self.delete(earliest_memory.id)
|
||||
return self.delete(earliest_memory.id)
|
||||
|
||||
@@ -1,169 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("memory_loader")
|
||||
|
||||
class MemoryFileLoader:
|
||||
"""从文件加载记忆内容的工具类"""
|
||||
|
||||
def __init__(self, working_memory: WorkingMemory):
|
||||
"""
|
||||
初始化记忆文件加载器
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆实例
|
||||
"""
|
||||
self.working_memory = working_memory
|
||||
|
||||
async def load_from_directory(self,
|
||||
directory_path: str,
|
||||
file_pattern: str = "*.txt",
|
||||
common_tags: List[str] = None,
|
||||
source_prefix: str = "文件") -> List[MemoryItem]:
|
||||
"""
|
||||
从指定目录加载符合模式的文件作为记忆
|
||||
|
||||
Args:
|
||||
directory_path: 目录路径
|
||||
file_pattern: 文件模式(默认为*.txt)
|
||||
common_tags: 所有记忆共有的标签
|
||||
source_prefix: 来源前缀
|
||||
|
||||
Returns:
|
||||
加载的记忆项列表
|
||||
"""
|
||||
directory = Path(directory_path)
|
||||
if not directory.exists() or not directory.is_dir():
|
||||
logger.error(f"目录不存在或不是有效目录: {directory_path}")
|
||||
return []
|
||||
|
||||
# 获取文件列表
|
||||
files = list(directory.glob(file_pattern))
|
||||
if not files:
|
||||
logger.warning(f"在目录 {directory_path} 中没有找到符合 {file_pattern} 的文件")
|
||||
return []
|
||||
|
||||
logger.info(f"在目录 {directory_path} 中找到 {len(files)} 个符合条件的文件")
|
||||
|
||||
# 加载文件内容为记忆
|
||||
loaded_memories = []
|
||||
for file_path in files:
|
||||
try:
|
||||
memory_item = await self._load_single_file(
|
||||
file_path=str(file_path),
|
||||
common_tags=common_tags,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
if memory_item:
|
||||
loaded_memories.append(memory_item)
|
||||
logger.info(f"成功加载记忆: {file_path.name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
|
||||
|
||||
logger.info(f"完成加载,共加载了 {len(loaded_memories)} 个记忆")
|
||||
return loaded_memories
|
||||
|
||||
async def _load_single_file(self,
|
||||
file_path: str,
|
||||
common_tags: Optional[List[str]] = None,
|
||||
source_prefix: str = "文件") -> Optional[MemoryItem]:
|
||||
"""
|
||||
加载单个文件作为记忆
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
common_tags: 记忆共有的标签
|
||||
source_prefix: 来源前缀
|
||||
|
||||
Returns:
|
||||
记忆项,加载失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 读取文件内容
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
if not content.strip():
|
||||
logger.warning(f"文件 {file_path} 内容为空")
|
||||
return None
|
||||
|
||||
# 准备标签和来源
|
||||
file_name = os.path.basename(file_path)
|
||||
tags = list(common_tags) if common_tags else []
|
||||
tags.append(file_name) # 添加文件名作为标签
|
||||
|
||||
source = f"{source_prefix}_{file_name}"
|
||||
|
||||
# 添加到工作记忆
|
||||
memory = await self.working_memory.add_memory(
|
||||
content=content,
|
||||
from_source=source,
|
||||
tags=tags
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def main():
|
||||
"""示例使用"""
|
||||
# 初始化工作记忆
|
||||
chat_id = "demo_chat"
|
||||
working_memory = WorkingMemory(chat_id=chat_id)
|
||||
|
||||
try:
|
||||
# 初始化加载器
|
||||
loader = MemoryFileLoader(working_memory)
|
||||
|
||||
# 加载当前目录中的txt文件
|
||||
current_dir = Path(__file__).parent
|
||||
memories = await loader.load_from_directory(
|
||||
directory_path=str(current_dir),
|
||||
file_pattern="*.txt",
|
||||
common_tags=["测试数据", "自动加载"],
|
||||
source_prefix="测试文件"
|
||||
)
|
||||
|
||||
# 显示加载结果
|
||||
print(f"共加载了 {len(memories)} 个记忆")
|
||||
|
||||
# 获取并显示所有记忆的概要
|
||||
all_memories = working_memory.memory_manager.get_all_items()
|
||||
for memory in all_memories:
|
||||
print("\n" + "=" * 40)
|
||||
print(f"记忆ID: {memory.id}")
|
||||
print(f"来源: {memory.from_source}")
|
||||
print(f"标签: {', '.join(memory.tags)}")
|
||||
|
||||
if memory.summary:
|
||||
print(f"\n主题: {memory.summary.get('brief', '无主题')}")
|
||||
print(f"概述: {memory.summary.get('detailed', '无概述')}")
|
||||
print("\n要点:")
|
||||
for point in memory.summary.get('key_points', []):
|
||||
print(f"- {point}")
|
||||
else:
|
||||
print("\n无摘要信息")
|
||||
|
||||
print("=" * 40)
|
||||
|
||||
finally:
|
||||
# 关闭工作记忆
|
||||
await working_memory.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行示例
|
||||
asyncio.run(main())
|
||||
@@ -1,92 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到系统路径
|
||||
current_dir = Path(__file__).parent
|
||||
project_root = current_dir.parent.parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
|
||||
async def test_load_memories_from_files():
|
||||
"""测试从文件加载记忆的功能"""
|
||||
print("开始测试从文件加载记忆...")
|
||||
|
||||
# 初始化工作记忆
|
||||
chat_id = "test_memory_load"
|
||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
||||
|
||||
try:
|
||||
# 获取测试文件列表
|
||||
test_dir = Path(__file__).parent
|
||||
test_files = [
|
||||
os.path.join(test_dir, f)
|
||||
for f in os.listdir(test_dir)
|
||||
if f.endswith(".txt")
|
||||
]
|
||||
|
||||
print(f"找到 {len(test_files)} 个测试文件")
|
||||
|
||||
# 从每个文件加载记忆
|
||||
for file_path in test_files:
|
||||
file_name = os.path.basename(file_path)
|
||||
print(f"从文件 {file_name} 加载记忆...")
|
||||
|
||||
# 读取文件内容
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 添加记忆
|
||||
memory = await working_memory.add_memory(
|
||||
content=content,
|
||||
from_source=f"文件_{file_name}",
|
||||
tags=["测试文件", file_name]
|
||||
)
|
||||
|
||||
print(f"已添加记忆: ID={memory.id}")
|
||||
if memory.summary:
|
||||
print(f"记忆概要: {memory.summary.get('brief', '无概要')}")
|
||||
print(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点']))}")
|
||||
print("-" * 50)
|
||||
|
||||
# 获取所有记忆
|
||||
all_memories = working_memory.memory_manager.get_all_items()
|
||||
print(f"\n成功加载 {len(all_memories)} 个记忆")
|
||||
|
||||
# 测试检索记忆
|
||||
if all_memories:
|
||||
print("\n测试检索第一个记忆...")
|
||||
first_memory = all_memories[0]
|
||||
retrieved = await working_memory.retrieve_memory(first_memory.id)
|
||||
|
||||
if retrieved:
|
||||
print(f"成功检索记忆: ID={retrieved.id}")
|
||||
print(f"检索后强度: {retrieved.memory_strength} (初始为10.0)")
|
||||
print(f"检索次数: {retrieved.retrieval_count}")
|
||||
else:
|
||||
print("检索失败")
|
||||
|
||||
# 测试记忆衰减
|
||||
print("\n测试记忆衰减...")
|
||||
for memory in all_memories:
|
||||
print(f"记忆 {memory.id} 衰减前强度: {memory.memory_strength}")
|
||||
|
||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
||||
|
||||
all_memories_after = working_memory.memory_manager.get_all_items()
|
||||
for memory in all_memories_after:
|
||||
print(f"记忆 {memory.id} 衰减后强度: {memory.memory_strength}")
|
||||
|
||||
finally:
|
||||
# 关闭工作记忆
|
||||
await working_memory.shutdown()
|
||||
print("\n测试完成,已关闭工作记忆")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 运行测试
|
||||
asyncio.run(test_load_memories_from_files())
|
||||
@@ -1,197 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import random
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到系统路径
|
||||
current_dir = Path(__file__).parent
|
||||
project_root = current_dir.parent.parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("real_usage_simulation")
|
||||
|
||||
class WorkingMemorySimulator:
|
||||
"""模拟工作记忆的真实使用场景"""
|
||||
|
||||
def __init__(self, chat_id="real_usage_test", cycle_interval=20):
|
||||
"""
|
||||
初始化模拟器
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
cycle_interval: 循环间隔时间(秒)
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.cycle_interval = cycle_interval
|
||||
self.working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=20, auto_decay_interval=60)
|
||||
self.cycle_count = 0
|
||||
self.running = False
|
||||
|
||||
# 获取测试文件路径
|
||||
self.test_files = self._get_test_files()
|
||||
if not self.test_files:
|
||||
raise FileNotFoundError("找不到测试文件,请确保test目录中有.txt文件")
|
||||
|
||||
# 存储所有添加的记忆ID
|
||||
self.memory_ids = []
|
||||
|
||||
async def start(self, total_cycles=5):
|
||||
"""
|
||||
开始模拟循环
|
||||
|
||||
Args:
|
||||
total_cycles: 总循环次数,设为None表示无限循环
|
||||
"""
|
||||
self.running = True
|
||||
logger.info(f"开始模拟真实使用场景,循环间隔: {self.cycle_interval}秒")
|
||||
|
||||
try:
|
||||
while self.running and (total_cycles is None or self.cycle_count < total_cycles):
|
||||
self.cycle_count += 1
|
||||
logger.info(f"\n===== 开始第 {self.cycle_count} 次循环 =====")
|
||||
|
||||
# 执行一次循环
|
||||
await self._run_one_cycle()
|
||||
|
||||
# 如果还有更多循环,则等待
|
||||
if self.running and (total_cycles is None or self.cycle_count < total_cycles):
|
||||
wait_time = self.cycle_interval
|
||||
logger.info(f"等待 {wait_time} 秒后开始下一循环...")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
logger.info(f"模拟完成,共执行了 {self.cycle_count} 次循环")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("接收到中断信号,停止模拟")
|
||||
except Exception as e:
|
||||
logger.error(f"模拟过程中出错: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
# 关闭工作记忆
|
||||
await self.working_memory.shutdown()
|
||||
|
||||
def stop(self):
|
||||
"""停止模拟循环"""
|
||||
self.running = False
|
||||
logger.info("正在停止模拟...")
|
||||
|
||||
async def _run_one_cycle(self):
|
||||
"""运行一次完整循环:先检索记忆,再添加新记忆"""
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 先检索已有记忆(如果有)
|
||||
await self._retrieve_memories()
|
||||
|
||||
# 2. 添加新记忆
|
||||
await self._add_new_memory()
|
||||
|
||||
# 3. 显示工作记忆状态
|
||||
await self._show_memory_status()
|
||||
|
||||
# 计算循环耗时
|
||||
cycle_duration = time.time() - start_time
|
||||
logger.info(f"第 {self.cycle_count} 次循环完成,耗时: {cycle_duration:.2f}秒")
|
||||
|
||||
async def _retrieve_memories(self):
|
||||
"""检索现有记忆"""
|
||||
# 如果有已保存的记忆ID,随机选择1-3个进行检索
|
||||
if self.memory_ids:
|
||||
num_to_retrieve = min(len(self.memory_ids), random.randint(1, 3))
|
||||
retrieval_ids = random.sample(self.memory_ids, num_to_retrieve)
|
||||
|
||||
logger.info(f"正在检索 {num_to_retrieve} 条记忆...")
|
||||
|
||||
for memory_id in retrieval_ids:
|
||||
memory = await self.working_memory.retrieve_memory(memory_id)
|
||||
if memory:
|
||||
logger.info(f"成功检索记忆 ID: {memory_id}")
|
||||
logger.info(f" - 强度: {memory.memory_strength:.2f},检索次数: {memory.retrieval_count}")
|
||||
if memory.summary:
|
||||
logger.info(f" - 主题: {memory.summary.get('brief', '无主题')}")
|
||||
else:
|
||||
logger.warning(f"记忆 ID: {memory_id} 不存在或已被移除")
|
||||
# 从ID列表中移除
|
||||
if memory_id in self.memory_ids:
|
||||
self.memory_ids.remove(memory_id)
|
||||
else:
|
||||
logger.info("当前没有可检索的记忆")
|
||||
|
||||
async def _add_new_memory(self):
|
||||
"""添加新记忆"""
|
||||
# 随机选择一个测试文件作为记忆内容
|
||||
file_path = random.choice(self.test_files)
|
||||
file_name = os.path.basename(file_path)
|
||||
|
||||
try:
|
||||
# 读取文件内容
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 添加时间戳,模拟不同内容
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
content_with_timestamp = f"[{timestamp}] {content}"
|
||||
|
||||
# 添加记忆
|
||||
logger.info(f"正在添加新记忆,来源: {file_name}")
|
||||
memory = await self.working_memory.add_memory(
|
||||
content=content_with_timestamp,
|
||||
from_source=f"模拟_{file_name}",
|
||||
tags=["模拟测试", f"循环{self.cycle_count}", file_name]
|
||||
)
|
||||
|
||||
# 保存记忆ID
|
||||
self.memory_ids.append(memory.id)
|
||||
|
||||
# 显示记忆信息
|
||||
logger.info(f"已添加新记忆 ID: {memory.id}")
|
||||
if memory.summary:
|
||||
logger.info(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
||||
logger.info(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点'])[:2])}...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加记忆失败: {str(e)}")
|
||||
|
||||
async def _show_memory_status(self):
|
||||
"""显示当前工作记忆状态"""
|
||||
all_memories = self.working_memory.memory_manager.get_all_items()
|
||||
|
||||
logger.info(f"\n当前工作记忆状态:")
|
||||
logger.info(f"记忆总数: {len(all_memories)}")
|
||||
|
||||
# 按强度排序
|
||||
sorted_memories = sorted(all_memories, key=lambda x: x.memory_strength, reverse=True)
|
||||
|
||||
logger.info("记忆强度排名 (前5项):")
|
||||
for i, memory in enumerate(sorted_memories[:5], 1):
|
||||
logger.info(f"{i}. ID: {memory.id}, 强度: {memory.memory_strength:.2f}, "
|
||||
f"检索次数: {memory.retrieval_count}, "
|
||||
f"主题: {memory.summary.get('brief', '无主题') if memory.summary else '无摘要'}")
|
||||
|
||||
def _get_test_files(self):
|
||||
"""获取测试文件列表"""
|
||||
test_dir = Path(__file__).parent
|
||||
return [
|
||||
os.path.join(test_dir, f)
|
||||
for f in os.listdir(test_dir)
|
||||
if f.endswith(".txt")
|
||||
]
|
||||
|
||||
async def main():
|
||||
"""主函数"""
|
||||
# 创建模拟器
|
||||
simulator = WorkingMemorySimulator(cycle_interval=20) # 设置20秒的循环间隔
|
||||
|
||||
# 设置运行5个循环
|
||||
await simulator.start(total_cycles=5)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,323 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到系统路径
|
||||
current_dir = Path(__file__).parent
|
||||
project_root = current_dir.parent.parent.parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.working_memory.test.memory_file_loader import MemoryFileLoader
|
||||
from src.common.logger_manager import get_logger
|
||||
|
||||
logger = get_logger("memory_decay_test")
|
||||
|
||||
async def test_manual_decay_until_removal():
|
||||
"""测试手动衰减直到记忆被自动移除"""
|
||||
print("\n===== 测试手动衰减直到记忆被自动移除 =====")
|
||||
|
||||
# 初始化工作记忆,设置较大的衰减间隔,避免自动衰减影响测试
|
||||
chat_id = "decay_test_manual"
|
||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=3600)
|
||||
|
||||
try:
|
||||
# 创建加载器并加载测试文件
|
||||
loader = MemoryFileLoader(working_memory)
|
||||
test_dir = current_dir
|
||||
|
||||
# 加载第一个测试文件作为记忆
|
||||
memories = await loader.load_from_directory(
|
||||
directory_path=str(test_dir),
|
||||
file_pattern="test1.txt", # 只加载test1.txt
|
||||
common_tags=["测试", "衰减", "自动移除"],
|
||||
source_prefix="衰减测试"
|
||||
)
|
||||
|
||||
if not memories:
|
||||
print("未能加载记忆文件,测试结束")
|
||||
return
|
||||
|
||||
# 获取加载的记忆
|
||||
memory = memories[0]
|
||||
memory_id = memory.id
|
||||
print(f"已加载测试记忆,ID: {memory_id}")
|
||||
print(f"初始强度: {memory.memory_strength}")
|
||||
if memory.summary:
|
||||
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
||||
|
||||
# 执行多次衰减,直到记忆被移除
|
||||
decay_count = 0
|
||||
decay_factor = 0.5 # 每次衰减为原来的一半
|
||||
|
||||
while True:
|
||||
# 获取当前记忆
|
||||
current_memory = working_memory.memory_manager.get_by_id(memory_id)
|
||||
|
||||
# 如果记忆已被移除,退出循环
|
||||
if current_memory is None:
|
||||
print(f"记忆已在第 {decay_count} 次衰减后被自动移除!")
|
||||
break
|
||||
|
||||
# 输出当前强度
|
||||
print(f"衰减 {decay_count} 次后强度: {current_memory.memory_strength}")
|
||||
|
||||
# 执行衰减
|
||||
await working_memory.decay_all_memories(decay_factor=decay_factor)
|
||||
decay_count += 1
|
||||
|
||||
# 输出衰减后的详细信息
|
||||
after_memory = working_memory.memory_manager.get_by_id(memory_id)
|
||||
if after_memory:
|
||||
print(f"第 {decay_count} 次衰减结果: 强度={after_memory.memory_strength},压缩次数={after_memory.compress_count}")
|
||||
if after_memory.summary:
|
||||
print(f"记忆概要: {after_memory.summary.get('brief', '无概要')}")
|
||||
print(f"记忆要点数量: {len(after_memory.summary.get('key_points', []))}")
|
||||
else:
|
||||
print(f"第 {decay_count} 次衰减结果: 记忆已被移除")
|
||||
|
||||
# 防止无限循环
|
||||
if decay_count > 20:
|
||||
print("达到最大衰减次数(20),退出测试。")
|
||||
break
|
||||
|
||||
# 短暂等待
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 验证记忆是否真的被移除
|
||||
all_memories = working_memory.memory_manager.get_all_items()
|
||||
print(f"剩余记忆数量: {len(all_memories)}")
|
||||
if len(all_memories) == 0:
|
||||
print("测试通过: 记忆在强度低于阈值后被成功移除。")
|
||||
else:
|
||||
print("测试失败: 记忆应该被移除但仍然存在。")
|
||||
|
||||
finally:
|
||||
await working_memory.shutdown()
|
||||
|
||||
async def test_auto_decay():
|
||||
"""测试自动衰减功能"""
|
||||
print("\n===== 测试自动衰减功能 =====")
|
||||
|
||||
# 初始化工作记忆,设置短的衰减间隔,便于测试
|
||||
chat_id = "decay_test_auto"
|
||||
decay_interval = 3 # 3秒
|
||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=decay_interval)
|
||||
|
||||
try:
|
||||
# 创建加载器并加载测试文件
|
||||
loader = MemoryFileLoader(working_memory)
|
||||
test_dir = current_dir
|
||||
|
||||
# 加载第二个测试文件作为记忆
|
||||
memories = await loader.load_from_directory(
|
||||
directory_path=str(test_dir),
|
||||
file_pattern="test1.txt", # 只加载test2.txt
|
||||
common_tags=["测试", "自动衰减"],
|
||||
source_prefix="自动衰减测试"
|
||||
)
|
||||
|
||||
if not memories:
|
||||
print("未能加载记忆文件,测试结束")
|
||||
return
|
||||
|
||||
# 获取加载的记忆
|
||||
memory = memories[0]
|
||||
memory_id = memory.id
|
||||
print(f"已加载测试记忆,ID: {memory_id}")
|
||||
print(f"初始强度: {memory.memory_strength}")
|
||||
if memory.summary:
|
||||
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
||||
print(f"记忆概要: {memory.summary.get('detailed', '无概要')}")
|
||||
print(f"记忆要点: {memory.summary.get('keypoints', '无要点')}")
|
||||
print(f"记忆事件: {memory.summary.get('events', '无事件')}")
|
||||
# 观察自动衰减
|
||||
print(f"等待自动衰减任务执行 (间隔 {decay_interval} 秒)...")
|
||||
|
||||
for i in range(3): # 观察3次自动衰减
|
||||
# 等待自动衰减发生
|
||||
await asyncio.sleep(decay_interval + 1) # 多等1秒确保任务执行
|
||||
|
||||
# 获取当前记忆
|
||||
current_memory = working_memory.memory_manager.get_by_id(memory_id)
|
||||
|
||||
# 如果记忆已被移除,退出循环
|
||||
if current_memory is None:
|
||||
print(f"记忆已在第 {i+1} 次自动衰减后被移除!")
|
||||
break
|
||||
|
||||
# 输出当前强度和详细信息
|
||||
print(f"第 {i+1} 次自动衰减后强度: {current_memory.memory_strength}")
|
||||
print(f"自动衰减详细结果: 压缩次数={current_memory.compress_count}, 提取次数={current_memory.retrieval_count}")
|
||||
if current_memory.summary:
|
||||
print(f"记忆概要: {current_memory.summary.get('brief', '无概要')}")
|
||||
|
||||
print(f"\n自动衰减测试结束。")
|
||||
|
||||
# 验证自动衰减是否发生
|
||||
final_memory = working_memory.memory_manager.get_by_id(memory_id)
|
||||
if final_memory is None:
|
||||
print("记忆已被自动衰减移除。")
|
||||
elif final_memory.memory_strength < memory.memory_strength:
|
||||
print(f"自动衰减有效:初始强度 {memory.memory_strength} -> 最终强度 {final_memory.memory_strength}")
|
||||
print(f"衰减历史记录: {final_memory.history}")
|
||||
else:
|
||||
print("测试失败:记忆强度未减少,自动衰减可能未生效。")
|
||||
|
||||
finally:
|
||||
await working_memory.shutdown()
|
||||
|
||||
async def test_decay_and_retrieval_balance():
|
||||
"""测试记忆衰减和检索的平衡"""
|
||||
print("\n===== 测试记忆衰减和检索的平衡 =====")
|
||||
|
||||
# 初始化工作记忆
|
||||
chat_id = "decay_retrieval_balance"
|
||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
||||
|
||||
try:
|
||||
# 创建加载器并加载测试文件
|
||||
loader = MemoryFileLoader(working_memory)
|
||||
test_dir = current_dir
|
||||
|
||||
# 加载第三个测试文件作为记忆
|
||||
memories = await loader.load_from_directory(
|
||||
directory_path=str(test_dir),
|
||||
file_pattern="test3.txt", # 只加载test3.txt
|
||||
common_tags=["测试", "衰减", "检索"],
|
||||
source_prefix="平衡测试"
|
||||
)
|
||||
|
||||
if not memories:
|
||||
print("未能加载记忆文件,测试结束")
|
||||
return
|
||||
|
||||
# 获取加载的记忆
|
||||
memory = memories[0]
|
||||
memory_id = memory.id
|
||||
print(f"已加载测试记忆,ID: {memory_id}")
|
||||
print(f"初始强度: {memory.memory_strength}")
|
||||
if memory.summary:
|
||||
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
||||
|
||||
# 先衰减几次
|
||||
print("\n开始衰减:")
|
||||
for i in range(3):
|
||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
||||
current = working_memory.memory_manager.get_by_id(memory_id)
|
||||
if current:
|
||||
print(f"衰减 {i+1} 次后强度: {current.memory_strength}")
|
||||
print(f"衰减详细信息: 压缩次数={current.compress_count}, 历史操作数={len(current.history)}")
|
||||
if current.summary:
|
||||
print(f"记忆概要: {current.summary.get('brief', '无概要')}")
|
||||
else:
|
||||
print(f"记忆已在第 {i+1} 次衰减后被移除。")
|
||||
break
|
||||
|
||||
# 如果记忆还存在,则检索几次增强它
|
||||
current = working_memory.memory_manager.get_by_id(memory_id)
|
||||
if current:
|
||||
print("\n开始检索增强:")
|
||||
for i in range(2):
|
||||
retrieved = await working_memory.retrieve_memory(memory_id)
|
||||
print(f"检索 {i+1} 次后强度: {retrieved.memory_strength}")
|
||||
print(f"检索后详细信息: 提取次数={retrieved.retrieval_count}, 历史记录长度={len(retrieved.history)}")
|
||||
|
||||
# 再次衰减几次,测试是否会被移除
|
||||
print("\n再次衰减:")
|
||||
for i in range(5):
|
||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
||||
current = working_memory.memory_manager.get_by_id(memory_id)
|
||||
if current:
|
||||
print(f"最终衰减 {i+1} 次后强度: {current.memory_strength}")
|
||||
print(f"衰减详细结果: 压缩次数={current.compress_count}")
|
||||
else:
|
||||
print(f"记忆已在最终衰减第 {i+1} 次后被移除。")
|
||||
break
|
||||
|
||||
print("\n测试结束。")
|
||||
|
||||
finally:
|
||||
await working_memory.shutdown()
|
||||
|
||||
async def test_multi_memories_decay():
|
||||
"""测试多条记忆同时衰减"""
|
||||
print("\n===== 测试多条记忆同时衰减 =====")
|
||||
|
||||
# 初始化工作记忆
|
||||
chat_id = "multi_decay_test"
|
||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
||||
|
||||
try:
|
||||
# 创建加载器并加载所有测试文件
|
||||
loader = MemoryFileLoader(working_memory)
|
||||
test_dir = current_dir
|
||||
|
||||
# 加载所有测试文件作为记忆
|
||||
memories = await loader.load_from_directory(
|
||||
directory_path=str(test_dir),
|
||||
file_pattern="*.txt",
|
||||
common_tags=["测试", "多记忆衰减"],
|
||||
source_prefix="多记忆测试"
|
||||
)
|
||||
|
||||
if not memories or len(memories) < 2:
|
||||
print("未能加载足够的记忆文件,测试结束")
|
||||
return
|
||||
|
||||
# 显示已加载的记忆
|
||||
print(f"已加载 {len(memories)} 条记忆:")
|
||||
for idx, mem in enumerate(memories):
|
||||
print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 来源: {mem.from_source}")
|
||||
if mem.summary:
|
||||
print(f" 主题: {mem.summary.get('brief', '无主题')}")
|
||||
|
||||
# 进行多次衰减测试
|
||||
print("\n开始多记忆衰减测试:")
|
||||
for decay_round in range(5):
|
||||
# 执行衰减
|
||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
||||
|
||||
# 获取并显示所有记忆
|
||||
all_memories = working_memory.memory_manager.get_all_items()
|
||||
print(f"\n第 {decay_round+1} 次衰减后,剩余记忆数量: {len(all_memories)}")
|
||||
|
||||
for idx, mem in enumerate(all_memories):
|
||||
print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 压缩次数: {mem.compress_count}")
|
||||
if mem.summary:
|
||||
print(f" 概要: {mem.summary.get('brief', '无概要')[:30]}...")
|
||||
|
||||
# 如果所有记忆都被移除,退出循环
|
||||
if not all_memories:
|
||||
print("所有记忆已被移除,测试结束。")
|
||||
break
|
||||
|
||||
# 等待一下
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
print("\n多记忆衰减测试结束。")
|
||||
|
||||
finally:
|
||||
await working_memory.shutdown()
|
||||
|
||||
async def main():
|
||||
"""运行所有测试"""
|
||||
# 测试手动衰减直到移除
|
||||
await test_manual_decay_until_removal()
|
||||
|
||||
# 测试自动衰减
|
||||
await test_auto_decay()
|
||||
|
||||
# 测试衰减和检索的平衡
|
||||
await test_decay_and_retrieval_balance()
|
||||
|
||||
# 测试多条记忆同时衰减
|
||||
await test_multi_memories_decay()
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,121 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
import unittest
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
|
||||
class TestWorkingMemory(unittest.TestCase):
|
||||
"""工作记忆测试类"""
|
||||
|
||||
def setUp(self):
|
||||
"""测试前准备"""
|
||||
self.chat_id = "test_chat_123"
|
||||
self.working_memory = WorkingMemory(chat_id=self.chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
||||
self.test_dir = Path(__file__).parent
|
||||
|
||||
def tearDown(self):
|
||||
"""测试后清理"""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.working_memory.shutdown())
|
||||
|
||||
def test_init(self):
|
||||
"""测试初始化"""
|
||||
self.assertEqual(self.working_memory.max_memories_per_chat, 10)
|
||||
self.assertEqual(self.working_memory.auto_decay_interval, 60)
|
||||
|
||||
def test_add_memory_from_files(self):
|
||||
"""从文件添加记忆"""
|
||||
loop = asyncio.get_event_loop()
|
||||
test_files = self._get_test_files()
|
||||
|
||||
# 添加记忆
|
||||
memories = []
|
||||
for file_path in test_files:
|
||||
content = self._read_file_content(file_path)
|
||||
file_name = os.path.basename(file_path)
|
||||
source = f"test_file_{file_name}"
|
||||
tags = ["测试", f"文件_{file_name}"]
|
||||
|
||||
memory = loop.run_until_complete(
|
||||
self.working_memory.add_memory(
|
||||
content=content,
|
||||
from_source=source,
|
||||
tags=tags
|
||||
)
|
||||
)
|
||||
memories.append(memory)
|
||||
|
||||
# 验证记忆数量
|
||||
all_items = self.working_memory.memory_manager.get_all_items()
|
||||
self.assertEqual(len(all_items), len(test_files))
|
||||
|
||||
# 验证每个记忆的内容和标签
|
||||
for i, memory in enumerate(memories):
|
||||
file_name = os.path.basename(test_files[i])
|
||||
retrieved_memory = loop.run_until_complete(
|
||||
self.working_memory.retrieve_memory(memory.id)
|
||||
)
|
||||
|
||||
self.assertIsNotNone(retrieved_memory)
|
||||
self.assertTrue(retrieved_memory.has_tag("测试"))
|
||||
self.assertTrue(retrieved_memory.has_tag(f"文件_{file_name}"))
|
||||
self.assertEqual(retrieved_memory.from_source, f"test_file_{file_name}")
|
||||
|
||||
# 验证检索后强度增加
|
||||
self.assertGreater(retrieved_memory.memory_strength, 10.0) # 原始强度为10.0,检索后增加1.5倍
|
||||
self.assertEqual(retrieved_memory.retrieval_count, 1)
|
||||
|
||||
def test_decay_memories(self):
|
||||
"""测试记忆衰减"""
|
||||
loop = asyncio.get_event_loop()
|
||||
test_files = self._get_test_files()[:1] # 只使用一个文件测试衰减
|
||||
|
||||
# 添加记忆
|
||||
for file_path in test_files:
|
||||
content = self._read_file_content(file_path)
|
||||
loop.run_until_complete(
|
||||
self.working_memory.add_memory(
|
||||
content=content,
|
||||
from_source="decay_test",
|
||||
tags=["衰减测试"]
|
||||
)
|
||||
)
|
||||
|
||||
# 获取添加后的记忆项
|
||||
all_items_before = self.working_memory.memory_manager.get_all_items()
|
||||
self.assertEqual(len(all_items_before), 1)
|
||||
|
||||
# 记录原始强度
|
||||
original_strength = all_items_before[0].memory_strength
|
||||
|
||||
# 执行衰减
|
||||
loop.run_until_complete(
|
||||
self.working_memory.decay_all_memories(decay_factor=0.5)
|
||||
)
|
||||
|
||||
# 获取衰减后的记忆项
|
||||
all_items_after = self.working_memory.memory_manager.get_all_items()
|
||||
|
||||
# 验证强度衰减
|
||||
self.assertEqual(len(all_items_after), 1)
|
||||
self.assertLess(all_items_after[0].memory_strength, original_strength)
|
||||
|
||||
def _get_test_files(self) -> List[str]:
|
||||
"""获取测试文件列表"""
|
||||
test_dir = self.test_dir
|
||||
return [
|
||||
os.path.join(test_dir, f)
|
||||
for f in os.listdir(test_dir)
|
||||
if f.endswith(".txt")
|
||||
]
|
||||
|
||||
def _read_file_content(self, file_path: str) -> str:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
from typing import List, Any, Optional
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime
|
||||
from src.common.logger_manager import get_logger
|
||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||
|
||||
@@ -9,39 +8,40 @@ logger = get_logger(__name__)
|
||||
|
||||
# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
工作记忆,负责协调和运作记忆
|
||||
从属于特定的流,用chat_id来标识
|
||||
"""
|
||||
|
||||
def __init__(self,chat_id:str , max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||
|
||||
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||
"""
|
||||
初始化工作记忆管理器
|
||||
|
||||
|
||||
Args:
|
||||
max_memories_per_chat: 每个聊天的最大记忆数量
|
||||
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
|
||||
"""
|
||||
self.memory_manager = MemoryManager(chat_id)
|
||||
|
||||
|
||||
# 记忆容量上限
|
||||
self.max_memories_per_chat = max_memories_per_chat
|
||||
|
||||
|
||||
# 自动衰减间隔
|
||||
self.auto_decay_interval = auto_decay_interval
|
||||
|
||||
|
||||
# 衰减任务
|
||||
self.decay_task = None
|
||||
|
||||
|
||||
# 启动自动衰减任务
|
||||
self._start_auto_decay()
|
||||
|
||||
|
||||
def _start_auto_decay(self):
|
||||
"""启动自动衰减任务"""
|
||||
if self.decay_task is None:
|
||||
self.decay_task = asyncio.create_task(self._auto_decay_loop())
|
||||
|
||||
|
||||
async def _auto_decay_loop(self):
|
||||
"""自动衰减循环"""
|
||||
while True:
|
||||
@@ -50,43 +50,39 @@ class WorkingMemory:
|
||||
await self.decay_all_memories()
|
||||
except Exception as e:
|
||||
print(f"自动衰减记忆时出错: {str(e)}")
|
||||
|
||||
|
||||
async def add_memory(self,
|
||||
content: Any,
|
||||
from_source: str = "",
|
||||
tags: Optional[List[str]] = None):
|
||||
|
||||
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
|
||||
"""
|
||||
添加一段记忆到指定聊天
|
||||
|
||||
|
||||
Args:
|
||||
content: 记忆内容
|
||||
from_source: 数据来源
|
||||
tags: 数据标签列表
|
||||
|
||||
|
||||
Returns:
|
||||
包含记忆信息的字典
|
||||
"""
|
||||
memory = await self.memory_manager.push_with_summary(content, from_source, tags)
|
||||
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
|
||||
self.remove_earliest_memory()
|
||||
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
def remove_earliest_memory(self):
|
||||
"""
|
||||
删除最早的记忆
|
||||
"""
|
||||
return self.memory_manager.delete_earliest_memory()
|
||||
|
||||
|
||||
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
检索记忆
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_id: 记忆ID
|
||||
|
||||
|
||||
Returns:
|
||||
检索到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
@@ -97,19 +93,18 @@ class WorkingMemory:
|
||||
return memory_item
|
||||
return None
|
||||
|
||||
|
||||
async def decay_all_memories(self, decay_factor: float = 0.5):
|
||||
"""
|
||||
对所有聊天的所有记忆进行衰减
|
||||
衰减:对记忆进行refine压缩,强度会变为原先的0.5
|
||||
|
||||
|
||||
Args:
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
"""
|
||||
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
|
||||
|
||||
|
||||
all_memories = self.memory_manager.get_all_items()
|
||||
|
||||
|
||||
for memory_item in all_memories:
|
||||
# 如果压缩完小于1会被删除
|
||||
memory_id = memory_item.id
|
||||
@@ -119,45 +114,47 @@ class WorkingMemory:
|
||||
continue
|
||||
# 计算衰减量
|
||||
if memory_item.memory_strength < 5:
|
||||
await self.memory_manager.refine_memory(memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩")
|
||||
|
||||
await self.memory_manager.refine_memory(
|
||||
memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
|
||||
)
|
||||
|
||||
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
|
||||
"""合并记忆
|
||||
|
||||
|
||||
Args:
|
||||
memory_str: 记忆内容
|
||||
"""
|
||||
return await self.memory_manager.merge_memories(memory_id1 = memory_id1, memory_id2 = memory_id2,reason = "两端记忆有重复的内容")
|
||||
|
||||
|
||||
|
||||
return await self.memory_manager.merge_memories(
|
||||
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
|
||||
)
|
||||
|
||||
# 暂时没用,先留着
|
||||
async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
|
||||
"""
|
||||
模拟记忆模糊过程,随机选择一部分记忆进行精简
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简
|
||||
"""
|
||||
memory = self.get_memory(chat_id)
|
||||
|
||||
|
||||
# 获取所有字符串类型且有总结的记忆
|
||||
all_summarized_memories = []
|
||||
for type_items in memory._memory.values():
|
||||
for item in type_items:
|
||||
if isinstance(item.data, str) and hasattr(item, 'summary') and item.summary:
|
||||
if isinstance(item.data, str) and hasattr(item, "summary") and item.summary:
|
||||
all_summarized_memories.append(item)
|
||||
|
||||
|
||||
if not all_summarized_memories:
|
||||
return
|
||||
|
||||
|
||||
# 计算要模糊的记忆数量
|
||||
blur_count = max(1, int(len(all_summarized_memories) * blur_rate))
|
||||
|
||||
|
||||
# 随机选择要模糊的记忆
|
||||
memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories)))
|
||||
|
||||
|
||||
# 对选中的记忆进行精简
|
||||
for memory_item in memories_to_blur:
|
||||
try:
|
||||
@@ -168,16 +165,14 @@ class WorkingMemory:
|
||||
requirement = "保留核心要点,适度精简细节"
|
||||
else:
|
||||
requirement = "只保留最关键的1-2个要点,大幅精简内容"
|
||||
|
||||
|
||||
# 进行精简
|
||||
await memory.refine_memory(memory_item.id, requirement)
|
||||
print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")
|
||||
|
||||
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""关闭管理器,停止所有任务"""
|
||||
if self.decay_task and not self.decay_task.done():
|
||||
@@ -185,13 +180,13 @@ class WorkingMemory:
|
||||
try:
|
||||
await self.decay_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
pass
|
||||
|
||||
def get_all_memories(self) -> List[MemoryItem]:
|
||||
"""
|
||||
获取所有记忆项目
|
||||
|
||||
|
||||
Returns:
|
||||
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
|
||||
"""
|
||||
return self.memory_manager.get_all_items()
|
||||
return self.memory_manager.get_all_items()
|
||||
|
||||
Reference in New Issue
Block a user