feat:添加来自困惑内容的主动提问发言

This commit is contained in:
SengokuCola
2025-10-04 17:18:43 +08:00
parent eaf218ff15
commit 35c3fe5b10
16 changed files with 666 additions and 334 deletions

View File

@@ -12,9 +12,12 @@ from src.config.config import global_config
from src.plugin_system.apis.message_api import build_readable_messages
from src.plugin_system.apis.message_api import get_raw_msg_by_timestamp_with_chat
from json_repair import repair_json
from src.memory_system.questions import global_conflict_tracker
from .memory_utils import (
find_best_matching_memory,
check_title_exists_fuzzy
check_title_exists_fuzzy,
get_all_titles,
)
logger = get_logger("memory")
@@ -186,32 +189,6 @@ class MemoryChest:
return running_content
def get_all_titles(self, exclude_locked: bool = False) -> list[str]:
"""
获取记忆仓库中的所有标题
Args:
exclude_locked: 是否排除锁定的记忆,默认为 False
Returns:
list: 包含所有标题的列表
"""
try:
# 查询所有记忆记录的标题
titles = []
for memory in MemoryChestModel.select():
if memory.title:
# 如果 exclude_locked 为 True 且记忆已锁定,则跳过
if exclude_locked and memory.locked:
continue
titles.append(memory.title)
return titles
except Exception as e:
print(f"获取记忆标题时出错: {e}")
return []
async def get_answer_by_question(self, chat_id: str = "", question: str = "") -> str:
"""
根据问题获取答案
@@ -306,7 +283,7 @@ class MemoryChest:
str: 选择的标题
"""
# 获取所有标题并构建格式化字符串(排除锁定的记忆)
titles = self.get_all_titles(exclude_locked=True)
titles = get_all_titles(exclude_locked=True)
formatted_titles = ""
for title in titles:
formatted_titles += f"{title}\n"
@@ -329,7 +306,7 @@ class MemoryChest:
title, (reasoning_content, model_name, tool_calls) = await self.LLMRequest.generate_response_async(prompt)
# 根据 title 获取 titles 里的对应项
titles = self.get_all_titles()
titles = get_all_titles()
selected_title = None
# 使用模糊查找匹配标题
@@ -427,7 +404,7 @@ class MemoryChest:
list[str]: 选中的记忆内容列表
"""
try:
all_titles = self.get_all_titles(exclude_locked=True)
all_titles = get_all_titles(exclude_locked=True)
content = ""
display_index = 1
for title in all_titles:
@@ -462,7 +439,7 @@ class MemoryChest:
请输出JSON格式不要输出其他内容
"""
logger.info(f"选择合并目标 prompt: {prompt}")
# logger.info(f"选择合并目标 prompt: {prompt}")
if global_config.debug.show_prompt:
logger.info(f"选择合并目标 prompt: {prompt}")
@@ -692,8 +669,6 @@ class MemoryChest:
# 处理part2独立记录冲突内容无论part1是否为空
if part2_content and part2_content.strip() != "none":
logger.info(f"合并记忆part2记录冲突内容: {len(part2_content)} 字符")
# 导入冲突追踪器
from src.curiousity.questions import global_conflict_tracker
# 记录冲突到数据库
await global_conflict_tracker.record_memory_merge_conflict(part2_content)

View File

@@ -8,6 +8,7 @@ from src.memory_system.Memory_chest import global_memory_chest
from src.common.logger import get_logger
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.config.config import global_config
from src.memory_system.memory_utils import get_all_titles
logger = get_logger("memory")
@@ -130,7 +131,7 @@ class MemoryManagementTask(AsyncTask):
"""随机获取一个记忆标题"""
try:
# 获取所有记忆标题
all_titles = global_memory_chest.get_all_titles()
all_titles = get_all_titles()
if not all_titles:
return ""

View File

@@ -3,15 +3,79 @@
记忆系统工具函数
包含模糊查找、相似度计算等工具函数
"""
import json
import re
from difflib import SequenceMatcher
from typing import List, Tuple, Optional
from src.common.database.database_model import MemoryChest as MemoryChestModel
from src.common.logger import get_logger
from json_repair import repair_json
logger = get_logger("memory_utils")
def get_all_titles(exclude_locked: bool = False) -> list[str]:
"""
获取记忆仓库中的所有标题
Args:
exclude_locked: 是否排除锁定的记忆,默认为 False
Returns:
list: 包含所有标题的列表
"""
try:
# 查询所有记忆记录的标题
titles = []
for memory in MemoryChestModel.select():
if memory.title:
# 如果 exclude_locked 为 True 且记忆已锁定,则跳过
if exclude_locked and memory.locked:
continue
titles.append(memory.title)
return titles
except Exception as e:
print(f"获取记忆标题时出错: {e}")
return []
def parse_md_json(json_text: str) -> list[str]:
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
reasoning_content = ""
# 使用正则表达式查找```json包裹的JSON内容
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, json_text, re.DOTALL)
# 提取JSON之前的内容作为推理文本
if matches:
# 找到第一个```json的位置
first_json_pos = json_text.find("```json")
if first_json_pos > 0:
reasoning_content = json_text[:first_json_pos].strip()
# 清理推理内容中的注释标记
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
reasoning_content = reasoning_content.strip()
for match in matches:
try:
# 清理可能的注释和格式问题
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
if json_str := json_str.strip():
json_obj = json.loads(json_str)
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
except Exception as e:
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
continue
return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float:
"""

View File

@@ -0,0 +1,46 @@
import time
import random
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.common.database.database_model import MemoryConflict
from src.config.config import global_config
class QuestionMaker:
def __init__(self, chat_id: str,context: str = ""):
self.chat_id = chat_id
self.context = context
def get_context(self):
latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
chat_id=self.chat_id,
timestamp=time.time(),
limit=30,
)
all_dialogue_prompt_str = build_readable_messages(
latest_30_msgs,
replace_bot_name=True,
timestamp_mode="normal_no_YMD",
)
return all_dialogue_prompt_str
async def get_all_conflicts(self):
conflicts = list(MemoryConflict.select())
return conflicts
async def get_un_answered_conflict(self):
conflicts = await self.get_all_conflicts()
return [conflict for conflict in conflicts if not conflict.answer]
async def get_random_unanswered_conflict(self):
conflicts = await self.get_un_answered_conflict()
return random.choice(conflicts)
async def make_question(self):
conflict = await self.get_random_unanswered_conflict()
question = conflict.conflict_content
conflict_context = conflict.context
chat_context = self.get_context()
return question, conflict_context

View File

@@ -0,0 +1,406 @@
import time
import asyncio
from src.common.logger import get_logger
from src.common.database.database_model import MemoryConflict
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
build_readable_messages,
)
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from typing import List
from src.memory_system.memory_utils import parse_md_json
logger = get_logger("conflict_tracker")
class QuestionTracker:
"""
用于跟踪一个问题在后续聊天中的解答情况
"""
def __init__(self, question: str, chat_id: str, context: str = "") -> None:
self.question = question
self.chat_id = chat_id
now = time.time()
self.context = context
self.start_time = now
self.last_read_time = now
self.last_judge_time = now # 上次判定的时间
self.judge_debounce_interval = 10.0 # 判定防抖间隔10秒
self.consecutive_end_count = 0 # 连续END计数
self.active = True
# 将 LLM 实例作为类属性,使用 utils 模型
self.llm_request = LLMRequest(model_set=model_config.model_task_config.utils, request_type="conflict.judge")
def stop(self) -> None:
self.active = False
def should_judge_now(self) -> bool:
"""
检查是否应该进行判定(防抖检查)
Returns:
bool: 是否可以判定
"""
now = time.time()
# 检查是否已经过了10秒的防抖间隔
return (now - self.last_judge_time) >= self.judge_debounce_interval
def __eq__(self, other) -> bool:
"""比较两个追踪器是否相等基于问题内容和聊天ID"""
if not isinstance(other, QuestionTracker):
return False
return self.question == other.question and self.chat_id == other.chat_id
def __hash__(self) -> int:
"""为对象提供哈希值,支持集合操作"""
return hash((self.question, self.chat_id))
async def judge_answer(self, conversation_text: str) -> tuple[bool, str, str]:
"""
使用模型判定问题是否已得到解答。
Returns:
tuple[bool, str, str]: (是否结束跟踪, 结束原因或答案, 判定类型)
- True: 结束跟踪(已解答、话题转向等)
- False: 继续跟踪
判定类型: "ANSWERED", "END", "CONTINUE"
"""
prompt = f"""你是一个严谨的判定器。下面给出聊天记录以及一个问题。
任务:判断在这段聊天中,该问题是否已经得到明确解答。
**你必须严格按照聊天记录的内容,不要添加额外的信息**
输出规则:
- 如果聊天记录内容的信息已解答问题请只输出YES: <简短答案>
- 如果聊天记录内容与问题无关话题已转向其他方向请只输出END
- 如果问题尚未解答但聊天仍在相关话题上请只输出NO
**问题**
{self.question}
**聊天记录**
{conversation_text}
"""
if global_config.debug.show_prompt:
logger.info(f"判定提示词: {prompt}")
else:
logger.debug("已发送判定提示词")
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.5)
logger.info(f"判定结果: {prompt}\n{result_text}")
# 更新上次判定时间
self.last_judge_time = time.time()
if not result_text:
return False, "", "CONTINUE"
text = result_text.strip()
if text.upper().startswith("YES:"):
answer = text[4:].strip()
return True, answer, "ANSWERED"
if text.upper().startswith("YES"):
# 兼容仅输出 YES 或 YES <answer>
answer = text[3:].strip().lstrip(":").strip()
return True, answer, "ANSWERED"
if text.upper().startswith("END"):
# 聊天内容与问题无关,放弃该问题思考
return True, "话题已转向其他方向,放弃该问题思考", "END"
return False, "", "CONTINUE"
class ConflictTracker:
"""
记忆整合冲突追踪器
用于记录和存储记忆整合过程中的冲突内容
"""
def __init__(self):
self.question_tracker_list:List[QuestionTracker] = []
self.LLMRequest_tracker = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="conflict_tracker",
)
def get_questions_by_chat_id(self, chat_id: str) -> List[QuestionTracker]:
return [tracker for tracker in self.question_tracker_list if tracker.chat_id == chat_id]
async def track_conflict(self, question: str, context: str = "",start_following: bool = False,chat_id: str = "") -> bool:
"""
跟踪冲突内容
"""
tracker = QuestionTracker(question.strip(), chat_id, context)
self.question_tracker_list.append(tracker)
asyncio.create_task(self._follow_and_record(tracker, question.strip()))
return True
async def record_conflict(self, conflict_content: str, context: str = "",start_following: bool = False,chat_id: str = "") -> bool:
"""
记录冲突内容
Args:k
conflict_content: 冲突内容
Returns:
bool: 是否成功记录
"""
try:
if not conflict_content or conflict_content.strip() == "":
return False
# 若需要跟随后续消息以判断是否得到解答,则进入跟踪流程
if start_following and chat_id:
tracker = QuestionTracker(conflict_content.strip(), chat_id, context)
self.question_tracker_list.append(tracker)
# 后台启动跟踪任务,避免阻塞
asyncio.create_task(self._follow_and_record(tracker, conflict_content.strip()))
return True
# 默认:直接记录,不进行跟踪
MemoryConflict.create(
conflict_content=conflict_content,
create_time=time.time(),
update_time=time.time(),
answer="",
)
logger.info(f"记录冲突内容: {len(conflict_content)} 字符")
return True
except Exception as e:
logger.error(f"记录冲突内容时出错: {e}")
return False
async def _follow_and_record(self, tracker: QuestionTracker, original_question: str) -> None:
"""
后台任务:跟踪问题是否被解答,并写入数据库。
"""
try:
max_duration = 10 * 60 # 30 分钟
max_messages = 50 # 最多 100 条消息
poll_interval = 2.0 # 秒
logger.info(f"开始跟踪问题: {original_question}")
while tracker.active:
now_ts = time.time()
# 终止条件:时长达到上限
if now_ts - tracker.start_time >= max_duration:
logger.info("问题跟踪达到10分钟上限判定为未解答")
break
# 统计最近一段是否有新消息(不过滤机器人,过滤命令)
recent_msgs = get_raw_msg_by_timestamp_with_chat(
chat_id=tracker.chat_id,
timestamp_start=tracker.last_read_time,
timestamp_end=now_ts,
limit=30,
limit_mode="latest",
filter_bot=False,
filter_command=True,
)
if len(recent_msgs) > 0:
tracker.last_read_time = now_ts
# 统计从开始到现在的总消息数用于触发100条上限
all_msgs = get_raw_msg_by_timestamp_with_chat(
chat_id=tracker.chat_id,
timestamp_start=tracker.start_time,
timestamp_end=now_ts,
limit=0,
limit_mode="latest",
filter_bot=False,
filter_command=True,
)
# 检查是否应该进行判定(防抖检查)
if not tracker.should_judge_now():
logger.debug(f"判定防抖中,跳过本次判定: {tracker.question}")
await asyncio.sleep(poll_interval)
continue
# 构建可读聊天文本
chat_text = build_readable_messages(
all_msgs,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
truncate=False,
show_actions=False,
show_pic=False,
remove_emoji_stickers=True,
)
# 让小模型判断是否有答案
answered, answer_text, judge_type = await tracker.judge_answer(chat_text)
if judge_type == "ANSWERED":
# 问题已解答,直接结束跟踪
logger.info("问题已得到解答,结束跟踪并写入答案")
await self.add_or_update_conflict(
conflict_content=tracker.question,
create_time=tracker.start_time,
update_time=time.time(),
answer=answer_text or "",
)
return
elif judge_type == "END":
# 话题转向增加END计数
tracker.consecutive_end_count += 1
logger.info(f"话题已转向连续END次数: {tracker.consecutive_end_count}")
if tracker.consecutive_end_count >= 2:
# 连续两次END结束跟踪
logger.info("连续两次END结束跟踪")
break
else:
# 第一次END重置计数器并继续跟踪
logger.info("第一次END继续跟踪")
continue
elif judge_type == "CONTINUE":
# 继续跟踪重置END计数器
tracker.consecutive_end_count = 0
continue
if len(all_msgs) >= max_messages:
logger.info("问题跟踪达到100条消息上限判定为未解答")
logger.info(f"追踪结束:{tracker.question}")
break
# 无新消息时稍作等待
await asyncio.sleep(poll_interval)
# 未获取到答案,仅存储问题
await self.add_or_update_conflict(
conflict_content=original_question,
create_time=time.time(),
update_time=time.time(),
answer="",
)
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
logger.info(f"问题跟踪结束:{original_question}")
except Exception as e:
logger.error(f"后台问题跟踪任务异常: {e}")
finally:
# 无论任务成功还是失败,都要从追踪列表中移除
tracker.stop()
self.remove_tracker(tracker)
def remove_tracker(self, tracker: QuestionTracker) -> None:
"""
从追踪列表中移除指定的追踪器
Args:
tracker: 要移除的追踪器对象
"""
try:
if tracker in self.question_tracker_list:
self.question_tracker_list.remove(tracker)
logger.info(f"已从追踪列表中移除追踪器: {tracker.question}")
else:
logger.warning(f"尝试移除不存在的追踪器: {tracker.question}")
except Exception as e:
logger.error(f"移除追踪器时出错: {e}")
async def add_or_update_conflict(self,conflict_content: str,create_time: float,update_time: float,answer: str = "",context: str = "") -> bool:
"""
根据conflict_content匹配数据库内容如果找到相同的就更新update_time和answer
如果没有相同的,就新建一条保存全部内容
"""
try:
# 尝试根据conflict_content查找现有记录
existing_conflict = MemoryConflict.get_or_none(
MemoryConflict.conflict_content == conflict_content
)
if existing_conflict:
# 如果找到相同的conflict_content更新update_time和answer
existing_conflict.update_time = update_time
existing_conflict.answer = answer
existing_conflict.save()
return True
else:
# 如果没有找到相同的,创建新记录
MemoryConflict.create(
conflict_content=conflict_content,
create_time=create_time,
update_time=update_time,
answer=answer,
context=context,
)
return True
except Exception as e:
# 记录错误并返回False
logger.error(f"添加或更新冲突记录时出错: {e}")
return False
async def record_memory_merge_conflict(self, part2_content: str) -> bool:
"""
记录记忆整合过程中的冲突内容part2
Args:
part2_content: 冲突内容part2
Returns:
bool: 是否成功记录
"""
if not part2_content or part2_content.strip() == "":
return False
prompt = f"""以下是一段有冲突的信息,请你根据这些信息总结出几个具体的提问:
冲突信息:
{part2_content}
要求:
1.提问必须具体,明确
2.提问最好涉及指向明确的事物,而不是代称
3.如果缺少上下文,不要强行提问,可以忽略
请用json格式输出不要输出其他内容仅输出提问理由和具体提的提问:
**示例**
// 理由文本
```json
{{
"question":"提问",
}}
```
```json
{{
"question":"提问"
}}
```
...提问数量在1-3个之间不要重复现在请输出"""
question_response, (reasoning_content, model_name, tool_calls) = await self.LLMRequest_tracker.generate_response_async(prompt)
# 解析JSON响应
questions, reasoning_content = parse_md_json(question_response)
print(prompt)
print(question_response)
for question in questions:
await self.record_conflict(
conflict_content=question["question"],
context=reasoning_content,
start_following=False,
)
return True
async def get_conflict_count(self) -> int:
"""
获取冲突记录数量
Returns:
int: 记录数量
"""
try:
return MemoryConflict.select().count()
except Exception as e:
logger.error(f"获取冲突记录数量时出错: {e}")
return 0
# 全局冲突追踪器实例
global_conflict_tracker = ConflictTracker()