better 更好的记忆抽取策略,并且移除了无用选项
This commit is contained in:
@@ -18,6 +18,7 @@ from ..chat.utils import (
|
||||
)
|
||||
from ..models.utils_model import LLM_request
|
||||
from src.common.logger import get_module_logger, LogConfig, MEMORY_STYLE_CONFIG
|
||||
from src.plugins.memory_system.sample_distribution import MemoryBuildScheduler
|
||||
|
||||
# 定义日志配置
|
||||
memory_config = LogConfig(
|
||||
@@ -195,19 +196,9 @@ class Hippocampus:
|
||||
return hash(f"{nodes[0]}:{nodes[1]}")
|
||||
|
||||
def random_get_msg_snippet(self, target_timestamp: float, chat_size: int, max_memorized_time_per_msg: int) -> list:
|
||||
"""随机抽取一段时间内的消息片段
|
||||
Args:
|
||||
- target_timestamp: 目标时间戳
|
||||
- chat_size: 抽取的消息数量
|
||||
- max_memorized_time_per_msg: 每条消息的最大记忆次数
|
||||
|
||||
Returns:
|
||||
- list: 抽取出的消息记录列表
|
||||
|
||||
"""
|
||||
try_count = 0
|
||||
# 最多尝试三次抽取
|
||||
while try_count < 3:
|
||||
# 最多尝试2次抽取
|
||||
while try_count < 2:
|
||||
messages = get_closest_chat_from_db(length=chat_size, timestamp=target_timestamp)
|
||||
if messages:
|
||||
# 检查messages是否均没有达到记忆次数限制
|
||||
@@ -224,54 +215,37 @@ class Hippocampus:
|
||||
)
|
||||
return messages
|
||||
try_count += 1
|
||||
# 三次尝试均失败
|
||||
return None
|
||||
|
||||
def get_memory_sample(self, chat_size=20, time_frequency=None):
|
||||
"""获取记忆样本
|
||||
|
||||
Returns:
|
||||
list: 消息记录列表,每个元素是一个消息记录字典列表
|
||||
"""
|
||||
def get_memory_sample(self):
|
||||
# 硬编码:每条消息最大记忆次数
|
||||
# 如有需求可写入global_config
|
||||
if time_frequency is None:
|
||||
time_frequency = {"near": 2, "mid": 4, "far": 3}
|
||||
max_memorized_time_per_msg = 3
|
||||
|
||||
current_timestamp = datetime.datetime.now().timestamp()
|
||||
# 创建双峰分布的记忆调度器
|
||||
scheduler = MemoryBuildScheduler(
|
||||
n_hours1=global_config.memory_build_distribution[0], # 第一个分布均值(4小时前)
|
||||
std_hours1=global_config.memory_build_distribution[1], # 第一个分布标准差
|
||||
weight1=global_config.memory_build_distribution[2], # 第一个分布权重 60%
|
||||
n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值(24小时前)
|
||||
std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差
|
||||
weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40%
|
||||
total_samples=global_config.build_memory_sample_num # 总共生成10个时间点
|
||||
)
|
||||
|
||||
# 生成时间戳数组
|
||||
timestamps = scheduler.get_timestamp_array()
|
||||
logger.debug(f"生成的时间戳数组: {timestamps}")
|
||||
|
||||
chat_samples = []
|
||||
|
||||
# 短期:1h 中期:4h 长期:24h
|
||||
logger.debug("正在抽取短期消息样本")
|
||||
for i in range(time_frequency.get("near")):
|
||||
random_time = current_timestamp - random.randint(1, 3600)
|
||||
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||
for timestamp in timestamps:
|
||||
messages = self.random_get_msg_snippet(timestamp, global_config.build_memory_sample_length, max_memorized_time_per_msg)
|
||||
if messages:
|
||||
logger.debug(f"成功抽取短期消息样本{len(messages)}条")
|
||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||
logger.debug(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条")
|
||||
chat_samples.append(messages)
|
||||
else:
|
||||
logger.warning(f"第{i}次短期消息样本抽取失败")
|
||||
|
||||
logger.debug("正在抽取中期消息样本")
|
||||
for i in range(time_frequency.get("mid")):
|
||||
random_time = current_timestamp - random.randint(3600, 3600 * 4)
|
||||
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||
if messages:
|
||||
logger.debug(f"成功抽取中期消息样本{len(messages)}条")
|
||||
chat_samples.append(messages)
|
||||
else:
|
||||
logger.warning(f"第{i}次中期消息样本抽取失败")
|
||||
|
||||
logger.debug("正在抽取长期消息样本")
|
||||
for i in range(time_frequency.get("far")):
|
||||
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
|
||||
messages = self.random_get_msg_snippet(random_time, chat_size, max_memorized_time_per_msg)
|
||||
if messages:
|
||||
logger.debug(f"成功抽取长期消息样本{len(messages)}条")
|
||||
chat_samples.append(messages)
|
||||
else:
|
||||
logger.warning(f"第{i}次长期消息样本抽取失败")
|
||||
logger.warning(f"时间戳 {timestamp} 的消息样本抽取失败")
|
||||
|
||||
return chat_samples
|
||||
|
||||
@@ -372,9 +346,8 @@ class Hippocampus:
|
||||
)
|
||||
return topic_num
|
||||
|
||||
async def operation_build_memory(self, chat_size=20):
|
||||
time_frequency = {"near": 1, "mid": 4, "far": 4}
|
||||
memory_samples = self.get_memory_sample(chat_size, time_frequency)
|
||||
async def operation_build_memory(self):
|
||||
memory_samples = self.get_memory_sample()
|
||||
|
||||
for i, messages in enumerate(memory_samples, 1):
|
||||
all_topics = []
|
||||
|
||||
Reference in New Issue
Block a user