feat:记忆遗忘,优化记忆提取
This commit is contained in:
@@ -21,7 +21,6 @@ from src.person_info.person_info import Person
|
|||||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||||
from src.plugin_system.core import events_manager
|
from src.plugin_system.core import events_manager
|
||||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||||
from src.memory_system.Memory_chest import global_memory_chest
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
|
|||||||
@@ -63,14 +63,14 @@ class FrequencyControl:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if time.time() - self.last_frequency_adjust_time < 120 or len(msg_list) <= 5:
|
if time.time() - self.last_frequency_adjust_time < 160 or len(msg_list) <= 20:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
new_msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_frequency_adjust_time,
|
timestamp_start=self.last_frequency_adjust_time,
|
||||||
timestamp_end=time.time(),
|
timestamp_end=time.time(),
|
||||||
limit=5,
|
limit=20,
|
||||||
limit_mode="latest",
|
limit_mode="latest",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,15 +19,12 @@ from src.chat.planner_actions.action_manager import ActionManager
|
|||||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||||
from src.express.expression_learner import expression_learner_manager
|
from src.express.expression_learner import expression_learner_manager
|
||||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||||
from src.memory_system.question_maker import QuestionMaker
|
|
||||||
from src.memory_system.questions import global_conflict_tracker
|
|
||||||
from src.memory_system.curious import check_and_make_question
|
from src.memory_system.curious import check_and_make_question
|
||||||
from src.jargon import extract_and_store_jargon
|
from src.jargon import extract_and_store_jargon
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||||
from src.plugin_system.core import events_manager
|
from src.plugin_system.core import events_manager
|
||||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||||
from src.memory_system.Memory_chest import global_memory_chest
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ import re
|
|||||||
|
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
from typing import List, Optional, Dict, Any, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.memory_system.Memory_chest import global_memory_chest
|
|
||||||
from src.memory_system.questions import global_conflict_tracker
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
|
|||||||
376
src/chat/utils/memory_forget_task.py
Normal file
376
src/chat/utils/memory_forget_task.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""
|
||||||
|
记忆遗忘任务
|
||||||
|
每5分钟进行一次遗忘检查,根据不同的遗忘阶段删除记忆
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database_model import ChatHistory
|
||||||
|
from src.manager.async_task_manager import AsyncTask
|
||||||
|
|
||||||
|
logger = get_logger("memory_forget_task")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryForgetTask(AsyncTask):
|
||||||
|
"""记忆遗忘任务,每5分钟执行一次"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 每5分钟执行一次(300秒)
|
||||||
|
super().__init__(task_name="Memory Forget Task", wait_before_start=0, run_interval=300)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""执行遗忘检查"""
|
||||||
|
try:
|
||||||
|
current_time = time.time()
|
||||||
|
logger.info("[记忆遗忘] 开始遗忘检查...")
|
||||||
|
|
||||||
|
# 执行4个阶段的遗忘检查
|
||||||
|
await self._forget_stage_1(current_time)
|
||||||
|
await self._forget_stage_2(current_time)
|
||||||
|
await self._forget_stage_3(current_time)
|
||||||
|
await self._forget_stage_4(current_time)
|
||||||
|
|
||||||
|
logger.info("[记忆遗忘] 遗忘检查完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘] 执行遗忘检查时出错: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _forget_stage_1(self, current_time: float):
|
||||||
|
"""
|
||||||
|
第一次遗忘检查:
|
||||||
|
搜集所有:记忆还未被遗忘检查过(forget_times=0),且已经是30分钟之外的记忆
|
||||||
|
取count最高25%和最低25%,删除,然后标记被遗忘检查次数为1
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 30分钟 = 1800秒
|
||||||
|
time_threshold = current_time - 1800
|
||||||
|
|
||||||
|
# 查询符合条件的记忆:forget_times=0 且 end_time < time_threshold
|
||||||
|
candidates = list(
|
||||||
|
ChatHistory.select()
|
||||||
|
.where(
|
||||||
|
(ChatHistory.forget_times == 0) &
|
||||||
|
(ChatHistory.end_time < time_threshold)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
logger.debug("[记忆遗忘-阶段1] 没有符合条件的记忆")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段1] 找到 {len(candidates)} 条符合条件的记忆")
|
||||||
|
|
||||||
|
# 按count排序
|
||||||
|
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||||
|
|
||||||
|
# 计算要删除的数量(最高25%和最低25%)
|
||||||
|
total_count = len(candidates)
|
||||||
|
delete_count = int(total_count * 0.25) # 25%
|
||||||
|
|
||||||
|
if delete_count == 0:
|
||||||
|
logger.debug("[记忆遗忘-阶段1] 删除数量为0,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 选择要删除的记录(处理count相同的情况:随机选择)
|
||||||
|
to_delete = []
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||||
|
|
||||||
|
# 去重(避免重复删除),使用id去重
|
||||||
|
seen_ids = set()
|
||||||
|
unique_to_delete = []
|
||||||
|
for record in to_delete:
|
||||||
|
if record.id not in seen_ids:
|
||||||
|
seen_ids.add(record.id)
|
||||||
|
unique_to_delete.append(record)
|
||||||
|
to_delete = unique_to_delete
|
||||||
|
|
||||||
|
# 删除记录并更新forget_times
|
||||||
|
deleted_count = 0
|
||||||
|
for record in to_delete:
|
||||||
|
try:
|
||||||
|
record.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段1] 删除记录失败: {e}")
|
||||||
|
|
||||||
|
# 更新剩余记录的forget_times为1
|
||||||
|
to_delete_ids = {r.id for r in to_delete}
|
||||||
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
|
if remaining:
|
||||||
|
# 批量更新
|
||||||
|
ids_to_update = [r.id for r in remaining]
|
||||||
|
ChatHistory.update(forget_times=1).where(
|
||||||
|
ChatHistory.id.in_(ids_to_update)
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段1] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为1")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段1] 执行失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _forget_stage_2(self, current_time: float):
|
||||||
|
"""
|
||||||
|
第二次遗忘检查:
|
||||||
|
搜集所有:记忆遗忘检查为1,且已经是8小时之外的记忆
|
||||||
|
取count最高7%和最低7%,删除,然后标记被遗忘检查次数为2
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 8小时 = 28800秒
|
||||||
|
time_threshold = current_time - 28800
|
||||||
|
|
||||||
|
# 查询符合条件的记忆:forget_times=1 且 end_time < time_threshold
|
||||||
|
candidates = list(
|
||||||
|
ChatHistory.select()
|
||||||
|
.where(
|
||||||
|
(ChatHistory.forget_times == 1) &
|
||||||
|
(ChatHistory.end_time < time_threshold)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
logger.debug("[记忆遗忘-阶段2] 没有符合条件的记忆")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段2] 找到 {len(candidates)} 条符合条件的记忆")
|
||||||
|
|
||||||
|
# 按count排序
|
||||||
|
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||||
|
|
||||||
|
# 计算要删除的数量(最高7%和最低7%)
|
||||||
|
total_count = len(candidates)
|
||||||
|
delete_count = int(total_count * 0.07) # 7%
|
||||||
|
|
||||||
|
if delete_count == 0:
|
||||||
|
logger.debug("[记忆遗忘-阶段2] 删除数量为0,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 选择要删除的记录
|
||||||
|
to_delete = []
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
to_delete = list(set(to_delete))
|
||||||
|
|
||||||
|
# 删除记录
|
||||||
|
deleted_count = 0
|
||||||
|
for record in to_delete:
|
||||||
|
try:
|
||||||
|
record.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段2] 删除记录失败: {e}")
|
||||||
|
|
||||||
|
# 更新剩余记录的forget_times为2
|
||||||
|
to_delete_ids = {r.id for r in to_delete}
|
||||||
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
|
if remaining:
|
||||||
|
ids_to_update = [r.id for r in remaining]
|
||||||
|
ChatHistory.update(forget_times=2).where(
|
||||||
|
ChatHistory.id.in_(ids_to_update)
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段2] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为2")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段2] 执行失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _forget_stage_3(self, current_time: float):
|
||||||
|
"""
|
||||||
|
第三次遗忘检查:
|
||||||
|
搜集所有:记忆遗忘检查为2,且已经是48小时之外的记忆
|
||||||
|
取count最高5%和最低5%,删除,然后标记被遗忘检查次数为3
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 48小时 = 172800秒
|
||||||
|
time_threshold = current_time - 172800
|
||||||
|
|
||||||
|
# 查询符合条件的记忆:forget_times=2 且 end_time < time_threshold
|
||||||
|
candidates = list(
|
||||||
|
ChatHistory.select()
|
||||||
|
.where(
|
||||||
|
(ChatHistory.forget_times == 2) &
|
||||||
|
(ChatHistory.end_time < time_threshold)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
logger.debug("[记忆遗忘-阶段3] 没有符合条件的记忆")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段3] 找到 {len(candidates)} 条符合条件的记忆")
|
||||||
|
|
||||||
|
# 按count排序
|
||||||
|
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||||
|
|
||||||
|
# 计算要删除的数量(最高5%和最低5%)
|
||||||
|
total_count = len(candidates)
|
||||||
|
delete_count = int(total_count * 0.05) # 5%
|
||||||
|
|
||||||
|
if delete_count == 0:
|
||||||
|
logger.debug("[记忆遗忘-阶段3] 删除数量为0,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 选择要删除的记录
|
||||||
|
to_delete = []
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
to_delete = list(set(to_delete))
|
||||||
|
|
||||||
|
# 删除记录
|
||||||
|
deleted_count = 0
|
||||||
|
for record in to_delete:
|
||||||
|
try:
|
||||||
|
record.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段3] 删除记录失败: {e}")
|
||||||
|
|
||||||
|
# 更新剩余记录的forget_times为3
|
||||||
|
to_delete_ids = {r.id for r in to_delete}
|
||||||
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
|
if remaining:
|
||||||
|
ids_to_update = [r.id for r in remaining]
|
||||||
|
ChatHistory.update(forget_times=3).where(
|
||||||
|
ChatHistory.id.in_(ids_to_update)
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段3] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为3")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段3] 执行失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _forget_stage_4(self, current_time: float):
|
||||||
|
"""
|
||||||
|
第四次遗忘检查:
|
||||||
|
搜集所有:记忆遗忘检查为3,且已经是7天之外的记忆
|
||||||
|
取count最高2%和最低2%,删除,然后标记被遗忘检查次数为4
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 7天 = 604800秒
|
||||||
|
time_threshold = current_time - 604800
|
||||||
|
|
||||||
|
# 查询符合条件的记忆:forget_times=3 且 end_time < time_threshold
|
||||||
|
candidates = list(
|
||||||
|
ChatHistory.select()
|
||||||
|
.where(
|
||||||
|
(ChatHistory.forget_times == 3) &
|
||||||
|
(ChatHistory.end_time < time_threshold)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
logger.debug("[记忆遗忘-阶段4] 没有符合条件的记忆")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段4] 找到 {len(candidates)} 条符合条件的记忆")
|
||||||
|
|
||||||
|
# 按count排序
|
||||||
|
candidates.sort(key=lambda x: x.count, reverse=True)
|
||||||
|
|
||||||
|
# 计算要删除的数量(最高2%和最低2%)
|
||||||
|
total_count = len(candidates)
|
||||||
|
delete_count = int(total_count * 0.02) # 2%
|
||||||
|
|
||||||
|
if delete_count == 0:
|
||||||
|
logger.debug("[记忆遗忘-阶段4] 删除数量为0,跳过")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 选择要删除的记录
|
||||||
|
to_delete = []
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "high"))
|
||||||
|
to_delete.extend(self._handle_same_count_random(candidates, delete_count, "low"))
|
||||||
|
|
||||||
|
# 去重
|
||||||
|
to_delete = list(set(to_delete))
|
||||||
|
|
||||||
|
# 删除记录
|
||||||
|
deleted_count = 0
|
||||||
|
for record in to_delete:
|
||||||
|
try:
|
||||||
|
record.delete_instance()
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段4] 删除记录失败: {e}")
|
||||||
|
|
||||||
|
# 更新剩余记录的forget_times为4
|
||||||
|
to_delete_ids = {r.id for r in to_delete}
|
||||||
|
remaining = [r for r in candidates if r.id not in to_delete_ids]
|
||||||
|
if remaining:
|
||||||
|
ids_to_update = [r.id for r in remaining]
|
||||||
|
ChatHistory.update(forget_times=4).where(
|
||||||
|
ChatHistory.id.in_(ids_to_update)
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
logger.info(f"[记忆遗忘-阶段4] 完成:删除了 {deleted_count} 条记忆,更新了 {len(remaining)} 条记忆的forget_times为4")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[记忆遗忘-阶段4] 执行失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
def _handle_same_count_random(self, candidates: List[ChatHistory], delete_count: int, mode: str) -> List[ChatHistory]:
|
||||||
|
"""
|
||||||
|
处理count相同的情况,随机选择要删除的记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
candidates: 候选记录列表(已按count排序)
|
||||||
|
delete_count: 要删除的数量
|
||||||
|
mode: "high" 表示选择最高count的记录,"low" 表示选择最低count的记录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
要删除的记录列表
|
||||||
|
"""
|
||||||
|
if not candidates or delete_count == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
to_delete = []
|
||||||
|
|
||||||
|
if mode == "high":
|
||||||
|
# 从最高count开始选择
|
||||||
|
start_idx = 0
|
||||||
|
while start_idx < len(candidates) and len(to_delete) < delete_count:
|
||||||
|
# 找到所有count相同的记录
|
||||||
|
current_count = candidates[start_idx].count
|
||||||
|
same_count_records = []
|
||||||
|
idx = start_idx
|
||||||
|
while idx < len(candidates) and candidates[idx].count == current_count:
|
||||||
|
same_count_records.append(candidates[idx])
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
|
||||||
|
needed = delete_count - len(to_delete)
|
||||||
|
if len(same_count_records) <= needed:
|
||||||
|
to_delete.extend(same_count_records)
|
||||||
|
else:
|
||||||
|
# 随机选择需要的数量
|
||||||
|
to_delete.extend(random.sample(same_count_records, needed))
|
||||||
|
|
||||||
|
start_idx = idx
|
||||||
|
|
||||||
|
else: # mode == "low"
|
||||||
|
# 从最低count开始选择
|
||||||
|
start_idx = len(candidates) - 1
|
||||||
|
while start_idx >= 0 and len(to_delete) < delete_count:
|
||||||
|
# 找到所有count相同的记录
|
||||||
|
current_count = candidates[start_idx].count
|
||||||
|
same_count_records = []
|
||||||
|
idx = start_idx
|
||||||
|
while idx >= 0 and candidates[idx].count == current_count:
|
||||||
|
same_count_records.append(candidates[idx])
|
||||||
|
idx -= 1
|
||||||
|
|
||||||
|
# 如果相同count的记录数量 <= 还需要删除的数量,全部选择
|
||||||
|
needed = delete_count - len(to_delete)
|
||||||
|
if len(same_count_records) <= needed:
|
||||||
|
to_delete.extend(same_count_records)
|
||||||
|
else:
|
||||||
|
# 随机选择需要的数量
|
||||||
|
to_delete.extend(random.sample(same_count_records, needed))
|
||||||
|
|
||||||
|
start_idx = idx
|
||||||
|
|
||||||
|
return to_delete
|
||||||
|
|
||||||
@@ -317,35 +317,6 @@ class Expression(BaseModel):
|
|||||||
class Meta:
|
class Meta:
|
||||||
table_name = "expression"
|
table_name = "expression"
|
||||||
|
|
||||||
class MemoryChest(BaseModel):
|
|
||||||
"""
|
|
||||||
用于存储记忆仓库的模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
title = TextField() # 标题
|
|
||||||
content = TextField() # 内容
|
|
||||||
chat_id = TextField(null=True) # 聊天ID
|
|
||||||
locked = BooleanField(default=False) # 是否锁定
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
table_name = "memory_chest"
|
|
||||||
|
|
||||||
class MemoryConflict(BaseModel):
|
|
||||||
"""
|
|
||||||
用于存储记忆整合过程中冲突内容的模型
|
|
||||||
"""
|
|
||||||
|
|
||||||
conflict_content = TextField() # 冲突内容
|
|
||||||
answer = TextField(null=True) # 回答内容
|
|
||||||
create_time = FloatField() # 创建时间
|
|
||||||
update_time = FloatField() # 更新时间
|
|
||||||
context = TextField(null=True) # 上下文
|
|
||||||
chat_id = TextField(null=True) # 聊天ID
|
|
||||||
raise_time = FloatField(null=True) # 触发次数
|
|
||||||
|
|
||||||
class Meta:
|
|
||||||
table_name = "memory_conflicts"
|
|
||||||
|
|
||||||
class Jargon(BaseModel):
|
class Jargon(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储俚语的模型
|
用于存储俚语的模型
|
||||||
@@ -378,6 +349,7 @@ class ChatHistory(BaseModel):
|
|||||||
keywords = TextField() # 关键词:这段对话的关键词,JSON格式存储
|
keywords = TextField() # 关键词:这段对话的关键词,JSON格式存储
|
||||||
summary = TextField() # 概括:对这段话的平文本概括
|
summary = TextField() # 概括:对这段话的平文本概括
|
||||||
count = IntegerField(default=0) # 被检索次数
|
count = IntegerField(default=0) # 被检索次数
|
||||||
|
forget_times = IntegerField(default=0) # 被遗忘检查的次数
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "chat_history"
|
table_name = "chat_history"
|
||||||
@@ -410,8 +382,6 @@ MODELS = [
|
|||||||
PersonInfo,
|
PersonInfo,
|
||||||
Expression,
|
Expression,
|
||||||
ActionRecords,
|
ActionRecords,
|
||||||
MemoryChest,
|
|
||||||
MemoryConflict,
|
|
||||||
Jargon,
|
Jargon,
|
||||||
ChatHistory,
|
ChatHistory,
|
||||||
ThinkingBack,
|
ThinkingBack,
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ class MainSystem:
|
|||||||
# 添加遥测心跳任务
|
# 添加遥测心跳任务
|
||||||
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
await async_task_manager.add_task(TelemetryHeartBeatTask())
|
||||||
|
|
||||||
|
# 添加记忆遗忘任务
|
||||||
|
from src.chat.utils.memory_forget_task import MemoryForgetTask
|
||||||
|
await async_task_manager.add_task(MemoryForgetTask())
|
||||||
|
|
||||||
# 启动API服务器
|
# 启动API服务器
|
||||||
# start_api_server()
|
# start_api_server()
|
||||||
# logger.info("API服务器启动成功")
|
# logger.info("API服务器启动成功")
|
||||||
|
|||||||
@@ -1,666 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import model_config
|
|
||||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.plugin_system.apis.message_api import build_readable_messages
|
|
||||||
from src.plugin_system.apis.message_api import get_raw_msg_by_timestamp_with_chat
|
|
||||||
from json_repair import repair_json
|
|
||||||
from src.memory_system.questions import global_conflict_tracker
|
|
||||||
|
|
||||||
|
|
||||||
from .memory_utils import (
|
|
||||||
find_best_matching_memory,
|
|
||||||
check_title_exists_fuzzy,
|
|
||||||
get_all_titles,
|
|
||||||
find_most_similar_memory_by_chat_id,
|
|
||||||
compute_merge_similarity_threshold
|
|
||||||
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = get_logger("memory")
|
|
||||||
|
|
||||||
class MemoryChest:
|
|
||||||
def __init__(self):
|
|
||||||
|
|
||||||
self.LLMRequest = LLMRequest(
|
|
||||||
model_set=model_config.model_task_config.utils_small,
|
|
||||||
request_type="memory_chest",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.LLMRequest_build = LLMRequest(
|
|
||||||
model_set=model_config.model_task_config.utils,
|
|
||||||
request_type="memory_chest_build",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
self.running_content_list = {} # {chat_id: {"content": running_content, "last_update_time": timestamp, "create_time": timestamp}}
|
|
||||||
self.fetched_memory_list = [] # [(chat_id, (question, answer, timestamp)), ...]
|
|
||||||
|
|
||||||
def remove_one_memory_by_age_weight(self) -> bool:
|
|
||||||
"""
|
|
||||||
删除一条记忆:按“越老/越新更易被删”的权重随机选择(老=较小id,新=较大id)。
|
|
||||||
|
|
||||||
返回:是否删除成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
memories = list(MemoryChestModel.select())
|
|
||||||
if not memories:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 排除锁定项
|
|
||||||
candidates = [m for m in memories if not getattr(m, "locked", False)]
|
|
||||||
if not candidates:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 按 id 排序,使用 id 近似时间顺序(小 -> 老,大 -> 新)
|
|
||||||
candidates.sort(key=lambda m: m.id)
|
|
||||||
n = len(candidates)
|
|
||||||
if n == 1:
|
|
||||||
MemoryChestModel.delete().where(MemoryChestModel.id == candidates[0].id).execute()
|
|
||||||
logger.info(f"[记忆管理] 已删除一条记忆(权重抽样):{candidates[0].title}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 计算U型权重:中间最低,两端最高
|
|
||||||
# r ∈ [0,1] 为位置归一化,w = 0.1 + 0.9 * (abs(r-0.5)*2)**1.5
|
|
||||||
weights = []
|
|
||||||
for idx, _m in enumerate(candidates):
|
|
||||||
r = idx / (n - 1)
|
|
||||||
w = 0.1 + 0.9 * (abs(r - 0.5) * 2) ** 1.5
|
|
||||||
weights.append(w)
|
|
||||||
|
|
||||||
import random as _random
|
|
||||||
selected = _random.choices(candidates, weights=weights, k=1)[0]
|
|
||||||
|
|
||||||
MemoryChestModel.delete().where(MemoryChestModel.id == selected.id).execute()
|
|
||||||
logger.info(f"[记忆管理] 已删除一条记忆(权重抽样):{selected.title}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[记忆管理] 按年龄权重删除记忆时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def get_answer_by_question(self, chat_id: str = "", question: str = "") -> str:
|
|
||||||
"""
|
|
||||||
根据问题获取答案
|
|
||||||
"""
|
|
||||||
logger.info(f"正在回忆问题答案: {question}")
|
|
||||||
|
|
||||||
title = await self.select_title_by_question(question)
|
|
||||||
|
|
||||||
if not title:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
for memory in MemoryChestModel.select():
|
|
||||||
if memory.title == title:
|
|
||||||
content = memory.content
|
|
||||||
|
|
||||||
if random.random() < 0.5:
|
|
||||||
type = "要求原文能够较为全面的回答问题"
|
|
||||||
else:
|
|
||||||
type = "要求提取简短的内容"
|
|
||||||
|
|
||||||
prompt = f"""
|
|
||||||
目标文段:
|
|
||||||
{content}
|
|
||||||
|
|
||||||
你现在需要从目标文段中找出合适的信息来回答问题:{question}
|
|
||||||
请务必从目标文段中提取相关信息的**原文**并输出,{type}
|
|
||||||
如果没有原文能够回答问题,输出"无有效信息"即可,不要输出其他内容:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
|
||||||
logger.info(f"记忆仓库获取答案 prompt: {prompt}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"记忆仓库获取答案 prompt: {prompt}")
|
|
||||||
|
|
||||||
answer, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
|
||||||
|
|
||||||
if "无有效" in answer or "无有效信息" in answer or "无信息" in answer:
|
|
||||||
logger.info(f"没有能够回答{question}的记忆")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
logger.info(f"记忆仓库对问题 “{question}” 获取答案: {answer}")
|
|
||||||
|
|
||||||
# 将问题和答案存到fetched_memory_list
|
|
||||||
if chat_id and answer:
|
|
||||||
self.fetched_memory_list.append((chat_id, (question, answer, time.time())))
|
|
||||||
|
|
||||||
# 清理fetched_memory_list
|
|
||||||
self._cleanup_fetched_memory_list()
|
|
||||||
|
|
||||||
return answer
|
|
||||||
|
|
||||||
def get_chat_memories_as_string(self, chat_id: str) -> str:
|
|
||||||
"""
|
|
||||||
获取某个chat_id的所有记忆,并构建成字符串
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 格式化的记忆字符串,格式:问题:xxx,答案:xxxxx\n问题:xxx,答案:xxxxx\n...
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
memories = []
|
|
||||||
|
|
||||||
# 从fetched_memory_list中获取该chat_id的所有记忆
|
|
||||||
for cid, (question, answer, timestamp) in self.fetched_memory_list:
|
|
||||||
if cid == chat_id:
|
|
||||||
memories.append(f"问题:{question},答案:{answer}")
|
|
||||||
|
|
||||||
# 按时间戳排序(最新的在后面)
|
|
||||||
memories.sort()
|
|
||||||
|
|
||||||
# 用换行符连接所有记忆
|
|
||||||
result = "\n".join(memories)
|
|
||||||
|
|
||||||
# logger.info(f"chat_id {chat_id} 共有 {len(memories)} 条记忆")
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取chat_id {chat_id} 的记忆时出错: {e}")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
async def select_title_by_question(self, question: str) -> str:
|
|
||||||
"""
|
|
||||||
根据消息内容选择最匹配的标题
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question: 问题
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 选择的标题
|
|
||||||
"""
|
|
||||||
# 获取所有标题并构建格式化字符串(排除锁定的记忆)
|
|
||||||
titles = get_all_titles(exclude_locked=True)
|
|
||||||
formatted_titles = ""
|
|
||||||
for title in titles:
|
|
||||||
formatted_titles += f"{title}\n"
|
|
||||||
|
|
||||||
prompt = f"""
|
|
||||||
所有主题:
|
|
||||||
{formatted_titles}
|
|
||||||
|
|
||||||
请根据以下问题,选择一个能够回答问题的主题:
|
|
||||||
问题:{question}
|
|
||||||
请你输出主题,不要输出其他内容,完整输出主题名:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
|
||||||
logger.info(f"记忆仓库选择标题 prompt: {prompt}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"记忆仓库选择标题 prompt: {prompt}")
|
|
||||||
|
|
||||||
|
|
||||||
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
|
||||||
|
|
||||||
# 根据 title 获取 titles 里的对应项
|
|
||||||
selected_title = None
|
|
||||||
|
|
||||||
# 使用模糊查找匹配标题
|
|
||||||
best_match = find_best_matching_memory(title, similarity_threshold=0.8)
|
|
||||||
if best_match:
|
|
||||||
selected_title = best_match[0] # 获取匹配的标题
|
|
||||||
logger.info(f"记忆仓库选择标题: {selected_title} (相似度: {best_match[2]:.3f})")
|
|
||||||
else:
|
|
||||||
logger.warning(f"未找到相似度 >= 0.7 的标题匹配: {title}")
|
|
||||||
selected_title = None
|
|
||||||
|
|
||||||
return selected_title
|
|
||||||
|
|
||||||
def _cleanup_fetched_memory_list(self):
|
|
||||||
"""
|
|
||||||
清理fetched_memory_list,移除超过10分钟的记忆和超过10条的最旧记忆
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
current_time = time.time()
|
|
||||||
ten_minutes_ago = current_time - 600 # 10分钟 = 600秒
|
|
||||||
|
|
||||||
# 移除超过10分钟的记忆
|
|
||||||
self.fetched_memory_list = [
|
|
||||||
(chat_id, (question, answer, timestamp))
|
|
||||||
for chat_id, (question, answer, timestamp) in self.fetched_memory_list
|
|
||||||
if timestamp > ten_minutes_ago
|
|
||||||
]
|
|
||||||
|
|
||||||
# 如果记忆条数超过10条,移除最旧的5条
|
|
||||||
if len(self.fetched_memory_list) > 10:
|
|
||||||
# 按时间戳排序,移除最旧的5条
|
|
||||||
self.fetched_memory_list.sort(key=lambda x: x[1][2]) # 按timestamp排序
|
|
||||||
self.fetched_memory_list = self.fetched_memory_list[5:] # 保留最新的5条
|
|
||||||
|
|
||||||
logger.debug(f"fetched_memory_list清理后,当前有 {len(self.fetched_memory_list)} 条记忆")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清理fetched_memory_list时出错: {e}")
|
|
||||||
|
|
||||||
async def _save_to_database_and_clear(self, chat_id: str, content: str):
|
|
||||||
"""
|
|
||||||
生成标题,保存到数据库,并清空对应chat_id的running_content
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
content: 要保存的内容
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 生成标题
|
|
||||||
title = ""
|
|
||||||
title_prompt = f"""
|
|
||||||
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
|
|
||||||
{content}
|
|
||||||
|
|
||||||
标题不要分点,不要换行,不要输出其他内容
|
|
||||||
请只输出标题,不要输出其他内容:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
|
||||||
logger.info(f"记忆仓库生成标题 prompt: {title_prompt}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"记忆仓库生成标题 prompt: {title_prompt}")
|
|
||||||
|
|
||||||
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(title_prompt)
|
|
||||||
|
|
||||||
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
if title:
|
|
||||||
# 保存到数据库
|
|
||||||
MemoryChestModel.create(
|
|
||||||
title=title.strip(),
|
|
||||||
content=content,
|
|
||||||
chat_id=chat_id
|
|
||||||
)
|
|
||||||
logger.info(f"已保存记忆仓库内容,标题: {title.strip()}, chat_id: {chat_id}")
|
|
||||||
|
|
||||||
# 清空内容并刷新时间戳,但保留条目用于增量计算
|
|
||||||
if chat_id in self.running_content_list:
|
|
||||||
current_time = time.time()
|
|
||||||
self.running_content_list[chat_id] = {
|
|
||||||
"content": "",
|
|
||||||
"last_update_time": current_time,
|
|
||||||
"create_time": current_time
|
|
||||||
}
|
|
||||||
logger.info(f"已保存并刷新chat_id {chat_id} 的时间戳,准备下一次增量构建")
|
|
||||||
else:
|
|
||||||
logger.warning(f"生成标题失败,chat_id: {chat_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"保存记忆仓库内容时出错: {e}")
|
|
||||||
|
|
||||||
async def choose_merge_target(self, memory_title: str, chat_id: str = None) -> tuple[list[str], list[str]]:
|
|
||||||
"""
|
|
||||||
选择与给定记忆标题相关的记忆目标(基于文本相似度)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
memory_title: 要匹配的记忆标题
|
|
||||||
chat_id: 聊天ID,用于筛选同chat_id的记忆
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[list[str], list[str]]: (选中的记忆标题列表, 选中的记忆内容列表)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not chat_id:
|
|
||||||
logger.warning("未提供chat_id,无法进行记忆匹配")
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
# 动态计算相似度阈值(占比越高阈值越低)
|
|
||||||
dynamic_threshold = compute_merge_similarity_threshold()
|
|
||||||
|
|
||||||
# 使用相似度匹配查找最相似的记忆(基于动态阈值)
|
|
||||||
similar_memory = find_most_similar_memory_by_chat_id(
|
|
||||||
target_title=memory_title,
|
|
||||||
target_chat_id=chat_id,
|
|
||||||
similarity_threshold=dynamic_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
if similar_memory:
|
|
||||||
selected_title, selected_content, similarity = similar_memory
|
|
||||||
logger.info(f"为 '{memory_title}' 找到相似记忆: '{selected_title}' (相似度: {similarity:.3f} 阈值: {dynamic_threshold:.2f})")
|
|
||||||
return [selected_title], [selected_content]
|
|
||||||
else:
|
|
||||||
logger.info(f"为 '{memory_title}' 未找到相似度 >= {dynamic_threshold:.2f} 的记忆")
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"选择合并目标时出错: {e}")
|
|
||||||
return [], []
|
|
||||||
|
|
||||||
def _get_memories_by_titles(self, titles: list[str]) -> list[str]:
|
|
||||||
"""
|
|
||||||
根据标题列表查找对应的记忆内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
titles: 记忆标题列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: 记忆内容列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
contents = []
|
|
||||||
for title in titles:
|
|
||||||
if not title or not title.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 使用模糊查找匹配记忆
|
|
||||||
try:
|
|
||||||
best_match = find_best_matching_memory(title.strip(), similarity_threshold=0.8)
|
|
||||||
if best_match:
|
|
||||||
# 检查记忆是否被锁定
|
|
||||||
memory_title = best_match[0]
|
|
||||||
memory_content = best_match[1]
|
|
||||||
|
|
||||||
# 查询数据库中的锁定状态
|
|
||||||
for memory in MemoryChestModel.select():
|
|
||||||
if memory.title == memory_title and memory.locked:
|
|
||||||
logger.warning(f"记忆 '{memory_title}' 已锁定,跳过合并")
|
|
||||||
continue
|
|
||||||
|
|
||||||
contents.append(memory_content)
|
|
||||||
logger.debug(f"找到记忆: {memory_title} (相似度: {best_match[2]:.3f})")
|
|
||||||
else:
|
|
||||||
logger.warning(f"未找到相似度 >= 0.8 的标题匹配: '{title}'")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"查找标题 '{title}' 的记忆时出错: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# logger.info(f"成功找到 {len(contents)} 条记忆内容")
|
|
||||||
return contents
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"根据标题查找记忆时出错: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _parse_merged_parts(self, merged_response: str) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
解析合并记忆的part1和part2内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
merged_response: LLM返回的合并记忆响应
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[str, str]: (part1_content, part2_content)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 使用正则表达式提取part1和part2内容
|
|
||||||
import re
|
|
||||||
|
|
||||||
# 提取part1内容
|
|
||||||
part1_pattern = r'<part1>(.*?)</part1>'
|
|
||||||
part1_match = re.search(part1_pattern, merged_response, re.DOTALL)
|
|
||||||
part1_content = part1_match.group(1).strip() if part1_match else ""
|
|
||||||
|
|
||||||
# 提取part2内容
|
|
||||||
part2_pattern = r'<part2>(.*?)</part2>'
|
|
||||||
part2_match = re.search(part2_pattern, merged_response, re.DOTALL)
|
|
||||||
part2_content = part2_match.group(1).strip() if part2_match else ""
|
|
||||||
|
|
||||||
# 检查是否包含none或None(不区分大小写)
|
|
||||||
def is_none_content(content: str) -> bool:
|
|
||||||
if not content:
|
|
||||||
return True
|
|
||||||
# 检查是否只包含"none"或"None"(不区分大小写)
|
|
||||||
return re.match(r'^\s*none\s*$', content, re.IGNORECASE) is not None
|
|
||||||
|
|
||||||
# 如果包含none,则设置为空字符串
|
|
||||||
if is_none_content(part1_content):
|
|
||||||
part1_content = ""
|
|
||||||
logger.info("part1内容为none,设置为空")
|
|
||||||
|
|
||||||
if is_none_content(part2_content):
|
|
||||||
part2_content = ""
|
|
||||||
logger.info("part2内容为none,设置为空")
|
|
||||||
|
|
||||||
return part1_content, part2_content
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"解析合并记忆part1/part2时出错: {e}")
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
def _parse_merge_target_json(self, json_text: str) -> list[str]:
|
|
||||||
"""
|
|
||||||
解析choose_merge_target生成的JSON响应
|
|
||||||
|
|
||||||
Args:
|
|
||||||
json_text: LLM返回的JSON文本
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: 解析出的记忆标题列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 清理JSON文本,移除可能的额外内容
|
|
||||||
repaired_content = repair_json(json_text)
|
|
||||||
|
|
||||||
# 尝试直接解析JSON
|
|
||||||
try:
|
|
||||||
parsed_data = json.loads(repaired_content)
|
|
||||||
if isinstance(parsed_data, list):
|
|
||||||
# 如果是列表,提取selected_title字段
|
|
||||||
titles = []
|
|
||||||
for item in parsed_data:
|
|
||||||
if isinstance(item, dict) and "selected_title" in item:
|
|
||||||
value = item.get("selected_title", "")
|
|
||||||
if isinstance(value, str) and value.strip():
|
|
||||||
titles.append(value)
|
|
||||||
return titles
|
|
||||||
elif isinstance(parsed_data, dict) and "selected_title" in parsed_data:
|
|
||||||
# 如果是单个对象
|
|
||||||
value = parsed_data.get("selected_title", "")
|
|
||||||
if isinstance(value, str) and value.strip():
|
|
||||||
return [value]
|
|
||||||
else:
|
|
||||||
# 空字符串表示没有相关记忆
|
|
||||||
return []
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 如果直接解析失败,尝试提取JSON对象
|
|
||||||
# 查找所有包含selected_title的JSON对象
|
|
||||||
pattern = r'\{[^}]*"selected_title"[^}]*\}'
|
|
||||||
matches = re.findall(pattern, repaired_content)
|
|
||||||
|
|
||||||
titles = []
|
|
||||||
for match in matches:
|
|
||||||
try:
|
|
||||||
obj = json.loads(match)
|
|
||||||
if "selected_title" in obj:
|
|
||||||
value = obj.get("selected_title", "")
|
|
||||||
if isinstance(value, str) and value.strip():
|
|
||||||
titles.append(value)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if titles:
|
|
||||||
return titles
|
|
||||||
|
|
||||||
logger.warning(f"无法解析JSON响应: {json_text[:200]}...")
|
|
||||||
return []
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"解析合并目标JSON时出错: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def merge_memory(self,memory_list: list[str], chat_id: str = None) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
合并记忆
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 在记忆整合前先清理空chat_id的记忆
|
|
||||||
cleaned_count = self.cleanup_empty_chat_id_memories()
|
|
||||||
if cleaned_count > 0:
|
|
||||||
logger.info(f"记忆整合前清理了 {cleaned_count} 条空chat_id记忆")
|
|
||||||
|
|
||||||
content = ""
|
|
||||||
for memory in memory_list:
|
|
||||||
content += f"{memory}\n"
|
|
||||||
|
|
||||||
prompt = f"""
|
|
||||||
以下是多段记忆内容,请将它们进行整合和修改:
|
|
||||||
{content}
|
|
||||||
--------------------------------
|
|
||||||
请将上面的多段记忆内容,合并成两部分内容,第一部分是可以整合,不冲突的概念和知识,第二部分是相互有冲突的概念和知识
|
|
||||||
请主要关注概念和知识,而不是聊天的琐事
|
|
||||||
重要!!你要关注的概念和知识必须是较为不常见的信息,或者时效性较强的信息!!
|
|
||||||
不要!!关注常见的只是,或者已经过时的信息!!
|
|
||||||
1.不要关注诸如某个用户做了什么,说了什么,不要关注某个用户的行为,而是关注其中的概念性信息
|
|
||||||
2.概念要求精确,不啰嗦,像科普读物或教育课本那样
|
|
||||||
3.如果有图片,请只关注图片和文本结合的知识和概念性内容
|
|
||||||
4.记忆为一段纯文本,逻辑清晰,指出概念的含义,并说明关系
|
|
||||||
**第一部分**
|
|
||||||
1.如果两个概念在描述同一件事情,且相互之间逻辑不冲突(请你严格判断),且相互之间没有矛盾,请将它们整合成一个概念,并输出到第一部分
|
|
||||||
2.如果某个概念在时间上更新了另一个概念,请用新概念更新就概念来整合,并输出到第一部分
|
|
||||||
3.如果没有可整合的概念,请你输出none
|
|
||||||
**第二部分**
|
|
||||||
1.如果记忆中有无法整合的地方,例如概念不一致,有逻辑上的冲突,请你输出到第二部分
|
|
||||||
2.如果两个概念在描述同一件事情,但相互之间逻辑冲突,请将它们输出到第二部分
|
|
||||||
3.如果没有无法整合的概念,请你输出none
|
|
||||||
|
|
||||||
**输出格式要求**
|
|
||||||
请你按以下格式输出:
|
|
||||||
<part1>
|
|
||||||
第一部分内容,整合后的概念,如果第一部分为none,请输出none
|
|
||||||
</part1>
|
|
||||||
<part2>
|
|
||||||
第二部分内容,无法整合,冲突的概念,如果第二部分为none,请输出none
|
|
||||||
</part2>
|
|
||||||
不要输出其他内容,现在请你输出,不要输出其他内容,注意一定要直白,白话,口语化不要浮夸,修辞。:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
|
||||||
logger.info(f"合并记忆 prompt: {prompt}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"合并记忆 prompt: {prompt}")
|
|
||||||
|
|
||||||
merged_memory, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_build.generate_response_async(prompt)
|
|
||||||
|
|
||||||
# 解析part1和part2
|
|
||||||
part1_content, part2_content = self._parse_merged_parts(merged_memory)
|
|
||||||
|
|
||||||
# 处理part2:独立记录冲突内容(无论part1是否为空)
|
|
||||||
if part2_content and part2_content.strip() != "none":
|
|
||||||
logger.info(f"合并记忆part2记录冲突内容: {len(part2_content)} 字符")
|
|
||||||
# 记录冲突到数据库
|
|
||||||
await global_conflict_tracker.record_memory_merge_conflict(part2_content,chat_id)
|
|
||||||
|
|
||||||
# 处理part1:生成标题并保存
|
|
||||||
if part1_content and part1_content.strip() != "none":
|
|
||||||
merged_title = await self._generate_title_for_merged_memory(part1_content)
|
|
||||||
|
|
||||||
# 保存part1到数据库
|
|
||||||
MemoryChestModel.create(
|
|
||||||
title=merged_title,
|
|
||||||
content=part1_content,
|
|
||||||
chat_id=chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"合并记忆part1已保存: {merged_title}")
|
|
||||||
|
|
||||||
return merged_title, part1_content
|
|
||||||
else:
|
|
||||||
logger.warning("合并记忆part1为空,跳过保存")
|
|
||||||
return "", ""
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"合并记忆时出错: {e}")
|
|
||||||
return "", ""
|
|
||||||
|
|
||||||
async def _generate_title_for_merged_memory(self, merged_content: str) -> str:
|
|
||||||
"""
|
|
||||||
为合并后的记忆生成标题
|
|
||||||
|
|
||||||
Args:
|
|
||||||
merged_content: 合并后的记忆内容
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 生成的标题
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
prompt = f"""
|
|
||||||
请为以下内容生成一个描述全面的标题,要求描述内容的主要概念和事件:
|
|
||||||
例如:
|
|
||||||
<example>
|
|
||||||
标题:达尔文的自然选择理论
|
|
||||||
内容:达尔文的自然选择是生物进化理论的重要组成部分,它解释了生物进化过程中的自然选择机制。
|
|
||||||
</example>
|
|
||||||
<example>
|
|
||||||
标题:麦麦的禁言插件和支持版本
|
|
||||||
内容:
|
|
||||||
麦麦的禁言插件是一款能够实现禁言的插件
|
|
||||||
麦麦的禁言插件可能不支持0.10.2
|
|
||||||
MutePlugin 是禁言插件的名称
|
|
||||||
</example>
|
|
||||||
|
|
||||||
|
|
||||||
需要对以下内容生成标题:
|
|
||||||
{merged_content}
|
|
||||||
|
|
||||||
|
|
||||||
标题不要分点,不要换行,不要输出其他内容,不要浮夸,以白话简洁的风格输出标题
|
|
||||||
请只输出标题,不要输出其他内容:
|
|
||||||
"""
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
|
||||||
logger.info(f"生成合并记忆标题 prompt: {prompt}")
|
|
||||||
else:
|
|
||||||
logger.debug(f"生成合并记忆标题 prompt: {prompt}")
|
|
||||||
|
|
||||||
title_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
|
|
||||||
|
|
||||||
# 清理标题,移除可能的引号或多余字符
|
|
||||||
title = title_response.strip().strip('"').strip("'").strip()
|
|
||||||
|
|
||||||
if title:
|
|
||||||
# 检查是否存在相似标题
|
|
||||||
if check_title_exists_fuzzy(title, similarity_threshold=0.9):
|
|
||||||
logger.warning(f"生成的标题 '{title}' 与现有标题相似,使用时间戳后缀")
|
|
||||||
title = f"{title}_{int(time.time())}"
|
|
||||||
|
|
||||||
logger.info(f"生成合并记忆标题: {title}")
|
|
||||||
return title
|
|
||||||
else:
|
|
||||||
logger.warning("生成合并记忆标题失败,使用默认标题")
|
|
||||||
return f"合并记忆_{int(time.time())}"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"生成合并记忆标题时出错: {e}")
|
|
||||||
return f"合并记忆_{int(time.time())}"
|
|
||||||
|
|
||||||
def cleanup_empty_chat_id_memories(self) -> int:
|
|
||||||
"""
|
|
||||||
清理chat_id为空的记忆记录
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 被清理的记忆数量
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查找所有chat_id为空的记忆
|
|
||||||
empty_chat_id_memories = MemoryChestModel.select().where(
|
|
||||||
(MemoryChestModel.chat_id.is_null()) |
|
|
||||||
(MemoryChestModel.chat_id == "") |
|
|
||||||
(MemoryChestModel.chat_id == "None")
|
|
||||||
)
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for memory in empty_chat_id_memories:
|
|
||||||
logger.info(f"清理空chat_id记忆: 标题='{memory.title}', ID={memory.id}")
|
|
||||||
memory.delete_instance()
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if count > 0:
|
|
||||||
logger.info(f"已清理 {count} 条chat_id为空的记忆记录")
|
|
||||||
else:
|
|
||||||
logger.debug("未发现需要清理的空chat_id记忆记录")
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清理空chat_id记忆时出错: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
global_memory_chest = MemoryChest()
|
|
||||||
@@ -7,7 +7,6 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
)
|
)
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.memory_system.questions import global_conflict_tracker
|
|
||||||
from src.memory_system.memory_utils import parse_md_json
|
from src.memory_system.memory_utils import parse_md_json
|
||||||
|
|
||||||
logger = get_logger("curious")
|
logger = get_logger("curious")
|
||||||
@@ -137,7 +136,7 @@ class CuriousDetector:
|
|||||||
|
|
||||||
async def make_question_from_detection(self, question: str, context: str = "") -> bool:
|
async def make_question_from_detection(self, question: str, context: str = "") -> bool:
|
||||||
"""
|
"""
|
||||||
将检测到的问题记录到冲突追踪器中
|
将检测到的问题记录(已移除冲突追踪器功能)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: 检测到的问题
|
question: 检测到的问题
|
||||||
@@ -150,14 +149,8 @@ class CuriousDetector:
|
|||||||
if not question or not question.strip():
|
if not question or not question.strip():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 记录问题到冲突追踪器
|
# 冲突追踪器功能已移除
|
||||||
await global_conflict_tracker.record_conflict(
|
logger.info(f"检测到问题(冲突追踪器已移除): {question}")
|
||||||
conflict_content=question.strip(),
|
|
||||||
context=context,
|
|
||||||
chat_id=self.chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"已记录问题到冲突追踪器: {question}")
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -67,7 +67,8 @@ def init_memory_retrieval_prompt():
|
|||||||
# 第二步:ReAct Agent prompt(工具描述会在运行时动态生成)
|
# 第二步:ReAct Agent prompt(工具描述会在运行时动态生成)
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你的名字是{bot_name},你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。
|
你的名字是{bot_name}。现在是{time_now}。
|
||||||
|
你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天。
|
||||||
你需要通过思考(Think)、行动(Action)、观察(Observation)的循环来回答问题。
|
你需要通过思考(Think)、行动(Action)、观察(Observation)的循环来回答问题。
|
||||||
|
|
||||||
当前问题:{question}
|
当前问题:{question}
|
||||||
@@ -160,7 +161,7 @@ async def _react_agent_solve_question(
|
|||||||
chat_id: str,
|
chat_id: str,
|
||||||
max_iterations: int = 5,
|
max_iterations: int = 5,
|
||||||
timeout: float = 30.0
|
timeout: float = 30.0
|
||||||
) -> Tuple[bool, str, List[Dict[str, Any]]]:
|
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
||||||
"""使用ReAct架构的Agent来解决问题
|
"""使用ReAct架构的Agent来解决问题
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -170,16 +171,18 @@ async def _react_agent_solve_question(
|
|||||||
timeout: 超时时间(秒)
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str, List[Dict[str, Any]]]: (是否找到答案, 答案内容, 思考步骤列表)
|
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
collected_info = ""
|
collected_info = ""
|
||||||
thinking_steps = []
|
thinking_steps = []
|
||||||
|
is_timeout = False
|
||||||
|
|
||||||
for iteration in range(max_iterations):
|
for iteration in range(max_iterations):
|
||||||
# 检查超时
|
# 检查超时
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
logger.warning(f"ReAct Agent超时,已迭代{iteration}次")
|
logger.warning(f"ReAct Agent超时,已迭代{iteration}次")
|
||||||
|
is_timeout = True
|
||||||
break
|
break
|
||||||
|
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}")
|
||||||
@@ -191,10 +194,14 @@ async def _react_agent_solve_question(
|
|||||||
# 获取bot_name
|
# 获取bot_name
|
||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
|
|
||||||
|
# 获取当前时间
|
||||||
|
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
|
||||||
# 构建prompt(动态生成工具描述)
|
# 构建prompt(动态生成工具描述)
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"memory_retrieval_react_prompt",
|
"memory_retrieval_react_prompt",
|
||||||
bot_name=bot_name,
|
bot_name=bot_name,
|
||||||
|
time_now=time_now,
|
||||||
question=question,
|
question=question,
|
||||||
collected_info=collected_info if collected_info else "暂无信息",
|
collected_info=collected_info if collected_info else "暂无信息",
|
||||||
tools_description=tool_registry.get_tools_description(),
|
tools_description=tool_registry.get_tools_description(),
|
||||||
@@ -247,14 +254,14 @@ async def _react_agent_solve_question(
|
|||||||
step["observations"] = ["找到答案"]
|
step["observations"] = ["找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到最终答案: {answer}")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到最终答案: {answer}")
|
||||||
return True, answer, thinking_steps
|
return True, answer, thinking_steps, False
|
||||||
elif action_type == "no_answer":
|
elif action_type == "no_answer":
|
||||||
# Agent确认无法找到答案
|
# Agent确认无法找到答案
|
||||||
answer = thought # 使用thought说明无法找到答案的原因
|
answer = thought # 使用thought说明无法找到答案的原因
|
||||||
step["observations"] = ["确认无法找到答案"]
|
step["observations"] = ["确认无法找到答案"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}")
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}")
|
||||||
return False, answer, thinking_steps
|
return False, answer, thinking_steps, False
|
||||||
|
|
||||||
# 并行执行所有工具
|
# 并行执行所有工具
|
||||||
tool_registry = get_tool_registry()
|
tool_registry = get_tool_registry()
|
||||||
@@ -316,8 +323,11 @@ async def _react_agent_solve_question(
|
|||||||
# 只有Agent明确返回final_answer时,才认为找到了答案
|
# 只有Agent明确返回final_answer时,才认为找到了答案
|
||||||
if collected_info:
|
if collected_info:
|
||||||
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}...")
|
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}...")
|
||||||
logger.warning("ReAct Agent达到最大迭代次数或超时,直接视为no_answer")
|
if is_timeout:
|
||||||
return False, "未找到相关信息", thinking_steps
|
logger.warning("ReAct Agent超时,直接视为no_answer")
|
||||||
|
else:
|
||||||
|
logger.warning("ReAct Agent达到最大迭代次数,直接视为no_answer")
|
||||||
|
return False, "未找到相关信息", thinking_steps, is_timeout
|
||||||
|
|
||||||
|
|
||||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
|
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
|
||||||
@@ -513,28 +523,10 @@ def _store_thinking_back(
|
|||||||
logger.error(f"存储思考过程失败: {e}")
|
logger.error(f"存储思考过程失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _get_max_iterations_by_question_count(question_count: int) -> int:
|
|
||||||
"""根据问题数量获取最大迭代次数
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question_count: 问题数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 最大迭代次数
|
|
||||||
"""
|
|
||||||
if question_count == 1:
|
|
||||||
return 6
|
|
||||||
elif question_count == 2:
|
|
||||||
return 3
|
|
||||||
else: # 3个或以上
|
|
||||||
return 2
|
|
||||||
|
|
||||||
|
|
||||||
async def _process_single_question(
|
async def _process_single_question(
|
||||||
question: str,
|
question: str,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
context: str,
|
context: str
|
||||||
max_iterations: int
|
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""处理单个问题的查询(包含缓存检查逻辑)
|
"""处理单个问题的查询(包含缓存检查逻辑)
|
||||||
|
|
||||||
@@ -542,7 +534,6 @@ async def _process_single_question(
|
|||||||
question: 要查询的问题
|
question: 要查询的问题
|
||||||
chat_id: 聊天ID
|
chat_id: 聊天ID
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
max_iterations: 最大迭代次数
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None
|
Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None
|
||||||
@@ -584,22 +575,25 @@ async def _process_single_question(
|
|||||||
else:
|
else:
|
||||||
logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...")
|
logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...")
|
||||||
|
|
||||||
found_answer, answer, thinking_steps = await _react_agent_solve_question(
|
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||||
question=question,
|
question=question,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
max_iterations=max_iterations,
|
max_iterations=5,
|
||||||
timeout=30.0
|
timeout=120.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# 存储到数据库
|
# 存储到数据库(超时时不存储)
|
||||||
_store_thinking_back(
|
if not is_timeout:
|
||||||
chat_id=chat_id,
|
_store_thinking_back(
|
||||||
question=question,
|
chat_id=chat_id,
|
||||||
context=context,
|
question=question,
|
||||||
found_answer=found_answer,
|
context=context,
|
||||||
answer=answer,
|
found_answer=found_answer,
|
||||||
thinking_steps=thinking_steps
|
answer=answer,
|
||||||
)
|
thinking_steps=thinking_steps
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
||||||
|
|
||||||
if found_answer and answer:
|
if found_answer and answer:
|
||||||
return f"问题:{question}\n答案:{answer}"
|
return f"问题:{question}\n答案:{answer}"
|
||||||
@@ -659,8 +653,6 @@ async def build_memory_retrieval_prompt(
|
|||||||
|
|
||||||
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
|
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
|
||||||
logger.info(f"记忆检索问题生成响应: {response}")
|
logger.info(f"记忆检索问题生成响应: {response}")
|
||||||
logger.info(f"记忆检索问题生成推理: {reasoning_content}")
|
|
||||||
logger.info(f"记忆检索问题生成模型: {model_name}")
|
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error(f"LLM生成问题失败: {response}")
|
logger.error(f"LLM生成问题失败: {response}")
|
||||||
@@ -679,23 +671,21 @@ async def build_memory_retrieval_prompt(
|
|||||||
retrieved_memory = "\n\n".join(cached_memories)
|
retrieved_memory = "\n\n".join(cached_memories)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.info(f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆")
|
logger.info(f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆")
|
||||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n请在回复时参考这些回忆的信息。\n"
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
|
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
|
||||||
|
|
||||||
# 第二步:根据问题数量确定最大迭代次数
|
# 第二步:并行处理所有问题(固定使用5次迭代/120秒超时)
|
||||||
max_iterations = _get_max_iterations_by_question_count(len(questions))
|
logger.info(f"问题数量: {len(questions)},固定设置最大迭代次数: 5,超时时间: 120秒")
|
||||||
logger.info(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations}")
|
|
||||||
|
|
||||||
# 并行处理所有问题
|
# 并行处理所有问题
|
||||||
question_tasks = [
|
question_tasks = [
|
||||||
_process_single_question(
|
_process_single_question(
|
||||||
question=question,
|
question=question,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
context=message,
|
context=message
|
||||||
max_iterations=max_iterations
|
|
||||||
)
|
)
|
||||||
for question in questions
|
for question in questions
|
||||||
]
|
]
|
||||||
@@ -730,7 +720,7 @@ async def build_memory_retrieval_prompt(
|
|||||||
if all_results:
|
if all_results:
|
||||||
retrieved_memory = "\n\n".join(all_results)
|
retrieved_memory = "\n\n".join(all_results)
|
||||||
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
|
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
|
||||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n请在回复时参考这些回忆的信息。\n"
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||||
else:
|
else:
|
||||||
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -6,40 +6,12 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
from typing import List, Tuple, Optional
|
|
||||||
|
|
||||||
from src.common.database.database_model import MemoryChest as MemoryChestModel
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from json_repair import repair_json
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("memory_utils")
|
logger = get_logger("memory_utils")
|
||||||
|
|
||||||
def get_all_titles(exclude_locked: bool = False) -> list[str]:
|
|
||||||
"""
|
|
||||||
获取记忆仓库中的所有标题
|
|
||||||
|
|
||||||
Args:
|
|
||||||
exclude_locked: 是否排除锁定的记忆,默认为 False
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 包含所有标题的列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查询所有记忆记录的标题
|
|
||||||
titles = []
|
|
||||||
for memory in MemoryChestModel.select():
|
|
||||||
if memory.title:
|
|
||||||
# 如果 exclude_locked 为 True 且记忆已锁定,则跳过
|
|
||||||
if exclude_locked and memory.locked:
|
|
||||||
continue
|
|
||||||
titles.append(memory.title)
|
|
||||||
return titles
|
|
||||||
except Exception as e:
|
|
||||||
print(f"获取记忆标题时出错: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def parse_md_json(json_text: str) -> list[str]:
|
def parse_md_json(json_text: str) -> list[str]:
|
||||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||||
json_objects = []
|
json_objects = []
|
||||||
@@ -134,259 +106,3 @@ def preprocess_text(text: str) -> str:
|
|||||||
logger.error(f"预处理文本时出错: {e}")
|
logger.error(f"预处理文本时出错: {e}")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
def fuzzy_find_memory_by_title(target_title: str, similarity_threshold: float = 0.9) -> List[Tuple[str, str, float]]:
|
|
||||||
"""
|
|
||||||
根据标题模糊查找记忆
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_title: 目标标题
|
|
||||||
similarity_threshold: 相似度阈值,默认0.9
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Tuple[str, str, float]]: 匹配的记忆列表,每个元素为(title, content, similarity_score)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取所有记忆
|
|
||||||
all_memories = MemoryChestModel.select()
|
|
||||||
|
|
||||||
matches = []
|
|
||||||
for memory in all_memories:
|
|
||||||
similarity = calculate_similarity(target_title, memory.title)
|
|
||||||
if similarity >= similarity_threshold:
|
|
||||||
matches.append((memory.title, memory.content, similarity))
|
|
||||||
|
|
||||||
# 按相似度降序排序
|
|
||||||
matches.sort(key=lambda x: x[2], reverse=True)
|
|
||||||
|
|
||||||
# logger.info(f"模糊查找标题 '{target_title}' 找到 {len(matches)} 个匹配项")
|
|
||||||
return matches
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"模糊查找记忆时出错: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def find_best_matching_memory(target_title: str, similarity_threshold: float = 0.9) -> Optional[Tuple[str, str, float]]:
|
|
||||||
"""
|
|
||||||
查找最佳匹配的记忆
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_title: 目标标题
|
|
||||||
similarity_threshold: 相似度阈值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Tuple[str, str, float]]: 最佳匹配的记忆(title, content, similarity)或None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
|
|
||||||
|
|
||||||
if matches:
|
|
||||||
best_match = matches[0] # 已经按相似度排序,第一个是最佳匹配
|
|
||||||
# logger.info(f"找到最佳匹配: '{best_match[0]}' (相似度: {best_match[2]:.3f})")
|
|
||||||
return best_match
|
|
||||||
else:
|
|
||||||
logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"查找最佳匹配记忆时出错: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def check_title_exists_fuzzy(target_title: str, similarity_threshold: float = 0.9) -> bool:
|
|
||||||
"""
|
|
||||||
检查标题是否已存在(模糊匹配)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_title: 目标标题
|
|
||||||
similarity_threshold: 相似度阈值,默认0.9(较高阈值避免误判)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否存在相似标题
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
matches = fuzzy_find_memory_by_title(target_title, similarity_threshold)
|
|
||||||
exists = len(matches) > 0
|
|
||||||
|
|
||||||
if exists:
|
|
||||||
logger.info(f"发现相似标题: '{matches[0][0]}' (相似度: {matches[0][2]:.3f})")
|
|
||||||
else:
|
|
||||||
logger.debug("未发现相似标题")
|
|
||||||
|
|
||||||
return exists
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"检查标题是否存在时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def get_memories_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[Tuple[str, str, str]]:
|
|
||||||
"""
|
|
||||||
根据chat_id进行加权抽样获取记忆列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_chat_id: 目标聊天ID
|
|
||||||
same_chat_weight: 同chat_id记忆的权重,默认0.95(95%概率)
|
|
||||||
other_chat_weight: 其他chat_id记忆的权重,默认0.05(5%概率)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[Tuple[str, str, str]]: 选中的记忆列表,每个元素为(title, content, chat_id)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取所有记忆
|
|
||||||
all_memories = MemoryChestModel.select()
|
|
||||||
|
|
||||||
# 按chat_id分组
|
|
||||||
same_chat_memories = []
|
|
||||||
other_chat_memories = []
|
|
||||||
|
|
||||||
for memory in all_memories:
|
|
||||||
if memory.title and not memory.locked: # 排除锁定的记忆
|
|
||||||
if memory.chat_id == target_chat_id:
|
|
||||||
same_chat_memories.append((memory.title, memory.content, memory.chat_id))
|
|
||||||
else:
|
|
||||||
other_chat_memories.append((memory.title, memory.content, memory.chat_id))
|
|
||||||
|
|
||||||
# 如果没有同chat_id的记忆,返回空列表
|
|
||||||
if not same_chat_memories:
|
|
||||||
logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 计算抽样数量
|
|
||||||
total_same = len(same_chat_memories)
|
|
||||||
total_other = len(other_chat_memories)
|
|
||||||
|
|
||||||
# 根据权重计算抽样数量
|
|
||||||
if total_other > 0:
|
|
||||||
# 计算其他chat_id记忆的抽样数量(至少1个,最多不超过总数的10%)
|
|
||||||
other_sample_count = max(1, min(total_other, int(total_same * other_chat_weight / same_chat_weight)))
|
|
||||||
else:
|
|
||||||
other_sample_count = 0
|
|
||||||
|
|
||||||
# 随机抽样
|
|
||||||
selected_memories = []
|
|
||||||
|
|
||||||
# 选择同chat_id的记忆(全部选择,因为权重很高)
|
|
||||||
selected_memories.extend(same_chat_memories)
|
|
||||||
|
|
||||||
# 随机选择其他chat_id的记忆
|
|
||||||
if other_sample_count > 0 and total_other > 0:
|
|
||||||
import random
|
|
||||||
other_selected = random.sample(other_chat_memories, min(other_sample_count, total_other))
|
|
||||||
selected_memories.extend(other_selected)
|
|
||||||
|
|
||||||
logger.info(f"加权抽样结果: 同chat_id记忆 {len(same_chat_memories)} 条,其他chat_id记忆 {min(other_sample_count, total_other)} 条")
|
|
||||||
|
|
||||||
return selected_memories
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"按chat_id加权抽样记忆时出错: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_memory_titles_by_chat_id_weighted(target_chat_id: str, same_chat_weight: float = 0.95, other_chat_weight: float = 0.05) -> List[str]:
|
|
||||||
"""
|
|
||||||
根据chat_id进行加权抽样获取记忆标题列表(用于合并选择)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_chat_id: 目标聊天ID
|
|
||||||
same_chat_weight: 同chat_id记忆的权重,默认0.95(95%概率)
|
|
||||||
other_chat_weight: 其他chat_id记忆的权重,默认0.05(5%概率)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[str]: 选中的记忆标题列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
memories = get_memories_by_chat_id_weighted(target_chat_id, same_chat_weight, other_chat_weight)
|
|
||||||
titles = [memory[0] for memory in memories] # 提取标题
|
|
||||||
return titles
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"按chat_id加权抽样记忆标题时出错: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def find_most_similar_memory_by_chat_id(target_title: str, target_chat_id: str, similarity_threshold: float = 0.5) -> Optional[Tuple[str, str, float]]:
|
|
||||||
"""
|
|
||||||
在指定chat_id的记忆中查找最相似的记忆
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_title: 目标标题
|
|
||||||
target_chat_id: 目标聊天ID
|
|
||||||
similarity_threshold: 相似度阈值,默认0.7
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[Tuple[str, str, float]]: 最相似的记忆(title, content, similarity)或None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取指定chat_id的所有记忆
|
|
||||||
same_chat_memories = []
|
|
||||||
for memory in MemoryChestModel.select():
|
|
||||||
if memory.title and not memory.locked and memory.chat_id == target_chat_id:
|
|
||||||
same_chat_memories.append((memory.title, memory.content))
|
|
||||||
|
|
||||||
if not same_chat_memories:
|
|
||||||
logger.warning(f"未找到chat_id为 '{target_chat_id}' 的记忆")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 计算相似度并找到最佳匹配
|
|
||||||
best_match = None
|
|
||||||
best_similarity = 0.0
|
|
||||||
|
|
||||||
for title, content in same_chat_memories:
|
|
||||||
# 跳过目标标题本身
|
|
||||||
if title.strip() == target_title.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
similarity = calculate_similarity(target_title, title)
|
|
||||||
|
|
||||||
if similarity > best_similarity:
|
|
||||||
best_similarity = similarity
|
|
||||||
best_match = (title, content, similarity)
|
|
||||||
|
|
||||||
# 检查是否超过阈值
|
|
||||||
if best_match and best_similarity >= similarity_threshold:
|
|
||||||
logger.info(f"找到最相似记忆: '{best_match[0]}' (相似度: {best_similarity:.3f})")
|
|
||||||
return best_match
|
|
||||||
else:
|
|
||||||
logger.info(f"未找到相似度 >= {similarity_threshold} 的记忆,最高相似度: {best_similarity:.3f}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"查找最相似记忆时出错: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compute_merge_similarity_threshold() -> float:
|
|
||||||
"""
|
|
||||||
根据当前记忆数量占比动态计算合并相似度阈值。
|
|
||||||
|
|
||||||
规则:占比越高,阈值越低。
|
|
||||||
- < 60%: 0.80(更严格,避免早期误合并)
|
|
||||||
- < 80%: 0.70
|
|
||||||
- < 100%: 0.60
|
|
||||||
- < 120%: 0.50
|
|
||||||
- >= 120%: 0.45(最宽松,加速收敛)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
current_count = MemoryChestModel.select().count()
|
|
||||||
max_count = max(1, int(global_config.memory.max_memory_number))
|
|
||||||
percentage = current_count / max_count
|
|
||||||
|
|
||||||
if percentage < 0.6:
|
|
||||||
return 0.70
|
|
||||||
elif percentage < 0.8:
|
|
||||||
return 0.60
|
|
||||||
elif percentage < 1.0:
|
|
||||||
return 0.50
|
|
||||||
elif percentage < 1.5:
|
|
||||||
return 0.40
|
|
||||||
elif percentage < 2:
|
|
||||||
return 0.30
|
|
||||||
else:
|
|
||||||
return 0.25
|
|
||||||
except Exception:
|
|
||||||
# 发生异常时使用保守阈值
|
|
||||||
return 0.70
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
import time
|
|
||||||
import random
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
|
||||||
from src.common.database.database_model import MemoryConflict
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
|
||||||
class QuestionMaker:
|
|
||||||
def __init__(self, chat_id: str, context: str = "") -> None:
|
|
||||||
"""问题生成器。
|
|
||||||
|
|
||||||
- chat_id: 会话 ID,用于筛选该会话下的冲突记录。
|
|
||||||
- context: 额外上下文,可用于后续扩展。
|
|
||||||
|
|
||||||
用法示例:
|
|
||||||
>>> qm = QuestionMaker(chat_id="some_chat")
|
|
||||||
>>> question, chat_ctx, conflict_ctx = await qm.make_question()
|
|
||||||
"""
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.context = context
|
|
||||||
|
|
||||||
def get_context(self, timestamp: float = time.time()) -> str:
|
|
||||||
"""获取指定时间点之前的对话上下文字符串。"""
|
|
||||||
latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=self.chat_id,
|
|
||||||
timestamp=timestamp,
|
|
||||||
limit=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
all_dialogue_prompt_str = build_readable_messages(
|
|
||||||
latest_30_msgs,
|
|
||||||
replace_bot_name=True,
|
|
||||||
timestamp_mode="normal_no_YMD",
|
|
||||||
)
|
|
||||||
return all_dialogue_prompt_str
|
|
||||||
|
|
||||||
|
|
||||||
async def get_all_conflicts(self) -> List[MemoryConflict]:
|
|
||||||
"""获取当前会话下的所有记忆冲突记录。"""
|
|
||||||
conflicts: List[MemoryConflict] = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id))
|
|
||||||
return conflicts
|
|
||||||
|
|
||||||
async def get_un_answered_conflict(self) -> List[MemoryConflict]:
|
|
||||||
"""获取未回答的记忆冲突记录(answer 为空)。"""
|
|
||||||
conflicts = await self.get_all_conflicts()
|
|
||||||
return [conflict for conflict in conflicts if not conflict.answer]
|
|
||||||
|
|
||||||
async def get_random_unanswered_conflict(self) -> Optional[MemoryConflict]:
|
|
||||||
"""按权重随机选取一个未回答的冲突并自增 raise_time。
|
|
||||||
|
|
||||||
选择规则:
|
|
||||||
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.01)。
|
|
||||||
- 若不存在,返回 None。
|
|
||||||
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
|
|
||||||
"""
|
|
||||||
conflicts = await self.get_un_answered_conflict()
|
|
||||||
if not conflicts:
|
|
||||||
return None
|
|
||||||
|
|
||||||
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
|
||||||
if conflicts_with_zero:
|
|
||||||
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.01
|
|
||||||
weights = []
|
|
||||||
for conflict in conflicts:
|
|
||||||
current_raise_time = getattr(conflict, "raise_time", 0) or 0
|
|
||||||
weight = 1.0 if current_raise_time == 0 else 0.01
|
|
||||||
weights.append(weight)
|
|
||||||
|
|
||||||
# 按权重随机选择
|
|
||||||
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
|
||||||
|
|
||||||
# 选中后,自增 raise_time 并保存
|
|
||||||
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
|
||||||
chosen_conflict.save()
|
|
||||||
|
|
||||||
return chosen_conflict
|
|
||||||
else:
|
|
||||||
# 如果没有 raise_time == 0 的冲突,返回 None
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
|
||||||
"""生成一条用于询问用户的冲突问题与上下文。
|
|
||||||
|
|
||||||
返回三元组 (question, chat_context, conflict_context):
|
|
||||||
- question: 冲突文本;若本次未选中任何冲突则为 None。
|
|
||||||
- chat_context: 该冲突创建时间点前的会话上下文字符串;若无则为 None。
|
|
||||||
- conflict_context: 冲突在 DB 中存储的上下文;若无则为 None。
|
|
||||||
"""
|
|
||||||
conflict = await self.get_random_unanswered_conflict()
|
|
||||||
if not conflict:
|
|
||||||
return None, None, None
|
|
||||||
question = conflict.conflict_content
|
|
||||||
conflict_context = conflict.context
|
|
||||||
create_time = conflict.create_time
|
|
||||||
chat_context = self.get_context(create_time)
|
|
||||||
|
|
||||||
return question, chat_context, conflict_context
|
|
||||||
@@ -1,192 +0,0 @@
|
|||||||
import time
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.database.database_model import MemoryConflict
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import model_config
|
|
||||||
from src.memory_system.memory_utils import parse_md_json
|
|
||||||
|
|
||||||
logger = get_logger("conflict_tracker")
|
|
||||||
|
|
||||||
class ConflictTracker:
|
|
||||||
"""
|
|
||||||
记忆整合冲突追踪器
|
|
||||||
|
|
||||||
用于记录和存储记忆整合过程中的冲突内容
|
|
||||||
"""
|
|
||||||
def __init__(self):
|
|
||||||
self.LLMRequest_tracker = LLMRequest(
|
|
||||||
model_set=model_config.model_task_config.utils,
|
|
||||||
request_type="conflict_tracker",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def record_conflict(self, conflict_content: str, context: str = "", chat_id: str = "") -> bool:
|
|
||||||
"""
|
|
||||||
记录冲突内容
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conflict_content: 冲突内容
|
|
||||||
context: 上下文
|
|
||||||
chat_id: 聊天ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功记录
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not conflict_content or conflict_content.strip() == "":
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 直接记录,不进行跟踪
|
|
||||||
MemoryConflict.create(
|
|
||||||
conflict_content=conflict_content,
|
|
||||||
create_time=time.time(),
|
|
||||||
update_time=time.time(),
|
|
||||||
answer="",
|
|
||||||
chat_id=chat_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"记录冲突内容: {len(conflict_content)} 字符")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"记录冲突内容时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def add_or_update_conflict(
|
|
||||||
self,
|
|
||||||
conflict_content: str,
|
|
||||||
create_time: float,
|
|
||||||
update_time: float,
|
|
||||||
answer: str = "",
|
|
||||||
context: str = "",
|
|
||||||
chat_id: str = None
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
根据conflict_content匹配数据库内容,如果找到相同的就更新update_time和answer,
|
|
||||||
如果没有相同的,就新建一条保存全部内容
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 尝试根据conflict_content查找现有记录
|
|
||||||
existing_conflict = MemoryConflict.get_or_none(
|
|
||||||
MemoryConflict.conflict_content == conflict_content,
|
|
||||||
MemoryConflict.chat_id == chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if existing_conflict:
|
|
||||||
# 如果找到相同的conflict_content,更新update_time和answer
|
|
||||||
existing_conflict.update_time = update_time
|
|
||||||
existing_conflict.answer = answer
|
|
||||||
existing_conflict.save()
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
# 如果没有找到相同的,创建新记录
|
|
||||||
MemoryConflict.create(
|
|
||||||
conflict_content=conflict_content,
|
|
||||||
create_time=create_time,
|
|
||||||
update_time=update_time,
|
|
||||||
answer=answer,
|
|
||||||
context=context,
|
|
||||||
chat_id=chat_id,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
# 记录错误并返回False
|
|
||||||
logger.error(f"添加或更新冲突记录时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def record_memory_merge_conflict(self, part2_content: str, chat_id: str = None) -> bool:
|
|
||||||
"""
|
|
||||||
记录记忆整合过程中的冲突内容(part2)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
part2_content: 冲突内容(part2)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功记录
|
|
||||||
"""
|
|
||||||
if not part2_content or part2_content.strip() == "":
|
|
||||||
return False
|
|
||||||
|
|
||||||
prompt = f"""以下是一段有冲突的信息,请你根据这些信息总结出几个具体的提问:
|
|
||||||
冲突信息:
|
|
||||||
{part2_content}
|
|
||||||
|
|
||||||
要求:
|
|
||||||
1.提问必须具体,明确
|
|
||||||
2.提问最好涉及指向明确的事物,而不是代称
|
|
||||||
3.如果缺少上下文,不要强行提问,可以忽略
|
|
||||||
4.请忽略涉及违法,暴力,色情,政治等敏感话题的内容
|
|
||||||
|
|
||||||
请用json格式输出,不要输出其他内容,仅输出提问理由和具体提的提问:
|
|
||||||
**示例**
|
|
||||||
// 理由文本
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"question":"提问",
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
```json
|
|
||||||
{{
|
|
||||||
"question":"提问"
|
|
||||||
}}
|
|
||||||
```
|
|
||||||
...提问数量在1-3个之间,不要重复,现在请输出:"""
|
|
||||||
|
|
||||||
question_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_tracker.generate_response_async(prompt)
|
|
||||||
|
|
||||||
# 解析JSON响应
|
|
||||||
questions, reasoning_content = parse_md_json(question_response)
|
|
||||||
|
|
||||||
print(prompt)
|
|
||||||
print(question_response)
|
|
||||||
|
|
||||||
for question in questions:
|
|
||||||
await self.record_conflict(
|
|
||||||
conflict_content=question["question"],
|
|
||||||
context=reasoning_content,
|
|
||||||
chat_id=chat_id,
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def get_conflict_count(self) -> int:
|
|
||||||
"""
|
|
||||||
获取冲突记录数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 记录数量
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return MemoryConflict.select().count()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取冲突记录数量时出错: {e}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
async def delete_conflict(self, conflict_content: str, chat_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
删除指定的冲突记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
conflict_content: 冲突内容
|
|
||||||
chat_id: 聊天ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功删除
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
conflict = MemoryConflict.get_or_none(
|
|
||||||
MemoryConflict.conflict_content == conflict_content,
|
|
||||||
MemoryConflict.chat_id == chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if conflict:
|
|
||||||
conflict.delete_instance()
|
|
||||||
logger.info(f"已删除冲突记录: {conflict_content}")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
logger.warning(f"未找到要删除的冲突记录: {conflict_content}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"删除冲突记录时出错: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 全局冲突追踪器实例
|
|
||||||
global_conflict_tracker = ConflictTracker()
|
|
||||||
@@ -11,15 +11,13 @@ logger = get_logger("memory_retrieval_tools")
|
|||||||
|
|
||||||
async def query_jargon(
|
async def query_jargon(
|
||||||
keyword: str,
|
keyword: str,
|
||||||
chat_id: str,
|
chat_id: str
|
||||||
fuzzy: bool = False
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""根据关键词在jargon库中查询
|
"""根据关键词在jargon库中查询
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
keyword: 关键词(黑话/俚语/缩写)
|
keyword: 关键词(黑话/俚语/缩写)
|
||||||
chat_id: 聊天ID
|
chat_id: 聊天ID
|
||||||
fuzzy: 是否使用模糊搜索,默认False(精确匹配)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 查询结果
|
str: 查询结果
|
||||||
@@ -29,27 +27,53 @@ async def query_jargon(
|
|||||||
if not content:
|
if not content:
|
||||||
return "关键词为空"
|
return "关键词为空"
|
||||||
|
|
||||||
# 执行搜索(仅搜索当前会话或全局)
|
# 先尝试精确匹配
|
||||||
results = search_jargon(
|
results = search_jargon(
|
||||||
keyword=content,
|
keyword=content,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
limit=1,
|
limit=10,
|
||||||
case_sensitive=False,
|
case_sensitive=False,
|
||||||
fuzzy=fuzzy
|
fuzzy=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
is_fuzzy_match = False
|
||||||
|
|
||||||
|
# 如果精确匹配未找到,尝试模糊搜索
|
||||||
|
if not results:
|
||||||
|
results = search_jargon(
|
||||||
|
keyword=content,
|
||||||
|
chat_id=chat_id,
|
||||||
|
limit=10,
|
||||||
|
case_sensitive=False,
|
||||||
|
fuzzy=True
|
||||||
|
)
|
||||||
|
is_fuzzy_match = True
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
result = results[0]
|
# 如果是模糊匹配,显示找到的实际jargon内容
|
||||||
translation = result.get("translation", "").strip()
|
if is_fuzzy_match:
|
||||||
meaning = result.get("meaning", "").strip()
|
# 处理多个结果
|
||||||
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
output_parts = [f"未精确匹配到'{content}'"]
|
||||||
output = f'"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"'
|
for result in results:
|
||||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,{search_type}): {content}")
|
found_content = result.get("content", "").strip()
|
||||||
|
meaning = result.get("meaning", "").strip()
|
||||||
|
if found_content and meaning:
|
||||||
|
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
|
||||||
|
output = ",".join(output_parts)
|
||||||
|
logger.info(f"在jargon库中找到匹配(当前会话或全局,模糊搜索): {content},找到{len(results)}条结果")
|
||||||
|
else:
|
||||||
|
# 精确匹配,可能有多条(相同content但不同chat_id的情况)
|
||||||
|
output_parts = []
|
||||||
|
for result in results:
|
||||||
|
meaning = result.get("meaning", "").strip()
|
||||||
|
if meaning:
|
||||||
|
output_parts.append(f"'{content}' 为黑话或者网络简写,含义为:{meaning}")
|
||||||
|
output = ";".join(output_parts) if len(output_parts) > 1 else output_parts[0]
|
||||||
|
logger.info(f"在jargon库中找到匹配(当前会话或全局,精确匹配): {content},找到{len(results)}条结果")
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# 未命中
|
# 未命中
|
||||||
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
logger.info(f"在jargon库中未找到匹配(当前会话或全局,精确匹配和模糊搜索都未找到): {content}")
|
||||||
logger.info(f"在jargon库中未找到匹配(当前会话或全局,{search_type}): {content}")
|
|
||||||
return f"未在jargon库中找到'{content}'的解释"
|
return f"未在jargon库中找到'{content}'的解释"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -61,19 +85,13 @@ def register_tool():
|
|||||||
"""注册工具"""
|
"""注册工具"""
|
||||||
register_memory_retrieval_tool(
|
register_memory_retrieval_tool(
|
||||||
name="query_jargon",
|
name="query_jargon",
|
||||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。仅搜索当前会话或全局jargon。",
|
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
|
||||||
parameters=[
|
parameters=[
|
||||||
{
|
{
|
||||||
"name": "keyword",
|
"name": "keyword",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "关键词(黑话/俚语/缩写),支持模糊搜索",
|
"description": "关键词(黑话/俚语/缩写)",
|
||||||
"required": True
|
"required": True
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "fuzzy",
|
|
||||||
"type": "boolean",
|
|
||||||
"description": "是否使用模糊搜索(部分匹配),默认False(精确匹配)。当精确匹配找不到时,可以尝试使用模糊搜索。",
|
|
||||||
"required": False
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
execute_func=query_jargon
|
execute_func=query_jargon
|
||||||
|
|||||||
Reference in New Issue
Block a user