245 lines
8.6 KiB
Python
245 lines
8.6 KiB
Python
import time
|
||
import asyncio
|
||
from rich.traceback import install
|
||
from src.common.logger import get_logger
|
||
from src.common.database.database_model import MemoryConflict
|
||
from src.chat.utils.chat_message_builder import (
|
||
get_raw_msg_by_timestamp_with_chat,
|
||
build_readable_messages,
|
||
)
|
||
from src.llm_models.utils_model import LLMRequest
|
||
from src.config.config import model_config, global_config
|
||
|
||
logger = get_logger("conflict_tracker")
|
||
|
||
logger = get_logger("conflict_tracker")
|
||
|
||
install(extra_lines=3)
|
||
|
||
class QuestionTracker:
|
||
"""
|
||
用于跟踪一个问题在后续聊天中的解答情况
|
||
"""
|
||
|
||
def __init__(self, question: str, chat_id: str) -> None:
|
||
self.question = question
|
||
self.chat_id = chat_id
|
||
now = time.time()
|
||
self.start_time = now
|
||
self.last_read_time = now
|
||
self.active = True
|
||
# 将 LLM 实例作为类属性,使用 utils 模型
|
||
self.llm_request = LLMRequest(model_set=model_config.model_task_config.utils, request_type="conflict.judge")
|
||
|
||
def stop(self) -> None:
|
||
self.active = False
|
||
|
||
async def judge_answer(self, conversation_text: str) -> tuple[bool, str]:
|
||
"""
|
||
使用小模型判定问题是否已得到解答。
|
||
返回 (已解答, 答案)
|
||
"""
|
||
prompt = (
|
||
"你是一个严谨的判定器。下面给出聊天记录以及一个问题。\n"
|
||
"任务:判断在这段聊天中,该问题是否已经得到明确解答。或从聊天内容中可以整理出答案\n"
|
||
"如果已解答,请只输出:YES: <简短答案>\n"
|
||
"如果没有,请只输出:NO\n\n"
|
||
f"问题:{self.question}\n"
|
||
"聊天记录如下:\n"
|
||
f"{conversation_text}"
|
||
)
|
||
|
||
|
||
|
||
if global_config.debug.show_prompt:
|
||
logger.info(f"判定提示词: {prompt}")
|
||
else:
|
||
logger.debug("已发送判定提示词")
|
||
|
||
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.2)
|
||
if not result_text:
|
||
return False, ""
|
||
|
||
logger.info(f"判定提示词: {prompt},问题: {self.question},result: {result_text}")
|
||
|
||
text = result_text.strip()
|
||
if text.upper().startswith("YES:"):
|
||
answer = text[4:].strip()
|
||
return True, answer
|
||
if text.upper().startswith("YES"):
|
||
# 兼容仅输出 YES 或 YES <answer>
|
||
answer = text[3:].strip().lstrip(":").strip()
|
||
return True, answer
|
||
return False, ""
|
||
|
||
class ConflictTracker:
|
||
"""
|
||
记忆整合冲突追踪器
|
||
|
||
用于记录和存储记忆整合过程中的冲突内容
|
||
"""
|
||
|
||
async def record_conflict(self, conflict_content: str, start_following: bool = False,chat_id: str = "") -> bool:
|
||
"""
|
||
记录冲突内容
|
||
|
||
Args:k
|
||
conflict_content: 冲突内容
|
||
|
||
Returns:
|
||
bool: 是否成功记录
|
||
"""
|
||
try:
|
||
if not conflict_content or conflict_content.strip() == "":
|
||
return False
|
||
|
||
# 若需要跟随后续消息以判断是否得到解答,则进入跟踪流程
|
||
if start_following and chat_id:
|
||
tracker = QuestionTracker(conflict_content.strip(), chat_id)
|
||
# 后台启动跟踪任务,避免阻塞
|
||
asyncio.create_task(self._follow_and_record(tracker, conflict_content.strip()))
|
||
return True
|
||
|
||
# 默认:直接记录,不进行跟踪
|
||
MemoryConflict.create(
|
||
conflict_content=conflict_content,
|
||
create_time=time.time(),
|
||
update_time=time.time(),
|
||
answer="",
|
||
)
|
||
|
||
logger.info(f"记录冲突内容: {len(conflict_content)} 字符")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"记录冲突内容时出错: {e}")
|
||
return False
|
||
|
||
async def _follow_and_record(self, tracker: QuestionTracker, original_question: str) -> None:
|
||
"""
|
||
后台任务:跟踪问题是否被解答,并写入数据库。
|
||
"""
|
||
try:
|
||
max_duration = 30 * 60 # 30 分钟
|
||
max_messages = 100 # 最多 100 条消息
|
||
poll_interval = 2.0 # 秒
|
||
|
||
while tracker.active:
|
||
now_ts = time.time()
|
||
# 终止条件:时长达到上限
|
||
if now_ts - tracker.start_time >= max_duration:
|
||
logger.info("问题跟踪达到30分钟上限,判定为未解答")
|
||
break
|
||
|
||
# 统计最近一段是否有新消息(不过滤机器人,过滤命令)
|
||
recent_msgs = get_raw_msg_by_timestamp_with_chat(
|
||
chat_id=tracker.chat_id,
|
||
timestamp_start=tracker.last_read_time,
|
||
timestamp_end=now_ts,
|
||
limit=0,
|
||
limit_mode="latest",
|
||
filter_bot=False,
|
||
filter_command=True,
|
||
)
|
||
|
||
if len(recent_msgs) > 0:
|
||
tracker.last_read_time = now_ts
|
||
|
||
# 统计从开始到现在的总消息数(用于触发100条上限)
|
||
all_msgs = get_raw_msg_by_timestamp_with_chat(
|
||
chat_id=tracker.chat_id,
|
||
timestamp_start=tracker.start_time,
|
||
timestamp_end=now_ts,
|
||
limit=0,
|
||
limit_mode="latest",
|
||
filter_bot=False,
|
||
filter_command=True,
|
||
)
|
||
|
||
# 构建可读聊天文本
|
||
chat_text = build_readable_messages(
|
||
all_msgs,
|
||
replace_bot_name=True,
|
||
timestamp_mode="relative",
|
||
read_mark=0.0,
|
||
truncate=False,
|
||
show_actions=False,
|
||
show_pic=False,
|
||
remove_emoji_stickers=True,
|
||
)
|
||
|
||
# 让小模型判断是否有答案
|
||
answered, answer_text = await tracker.judge_answer(chat_text)
|
||
if answered:
|
||
logger.info("问题已得到解答,结束跟踪并写入答案")
|
||
tracker.stop()
|
||
MemoryConflict.create(
|
||
conflict_content=tracker.question,
|
||
create_time=tracker.start_time,
|
||
update_time=time.time(),
|
||
answer=answer_text or "",
|
||
)
|
||
return
|
||
|
||
if len(all_msgs) >= max_messages:
|
||
logger.info("问题跟踪达到100条消息上限,判定为未解答")
|
||
break
|
||
|
||
# 无新消息时稍作等待
|
||
await asyncio.sleep(poll_interval)
|
||
|
||
# 未获取到答案,仅存储问题
|
||
MemoryConflict.create(
|
||
conflict_content=original_question,
|
||
create_time=time.time(),
|
||
update_time=time.time(),
|
||
answer="",
|
||
)
|
||
logger.info(f"记录冲突内容(未解答): {len(original_question)} 字符")
|
||
except Exception as e:
|
||
logger.error(f"后台问题跟踪任务异常: {e}")
|
||
|
||
async def record_memory_merge_conflict(self, part2_content: str) -> bool:
|
||
"""
|
||
记录记忆整合过程中的冲突内容(part2)
|
||
|
||
Args:
|
||
part2_content: 冲突内容(part2)
|
||
|
||
Returns:
|
||
bool: 是否成功记录
|
||
"""
|
||
if not part2_content or part2_content.strip() == "":
|
||
return False
|
||
|
||
return await self.record_conflict(part2_content)
|
||
|
||
async def get_all_conflicts(self) -> list:
|
||
"""
|
||
获取所有冲突记录
|
||
|
||
Returns:
|
||
list: 冲突记录列表
|
||
"""
|
||
try:
|
||
conflicts = list(MemoryConflict.select())
|
||
return conflicts
|
||
except Exception as e:
|
||
logger.error(f"获取冲突记录时出错: {e}")
|
||
return []
|
||
|
||
async def get_conflict_count(self) -> int:
|
||
"""
|
||
获取冲突记录数量
|
||
|
||
Returns:
|
||
int: 记录数量
|
||
"""
|
||
try:
|
||
return MemoryConflict.select().count()
|
||
except Exception as e:
|
||
logger.error(f"获取冲突记录数量时出错: {e}")
|
||
return 0
|
||
|
||
# 全局冲突追踪器实例
|
||
global_conflict_tracker = ConflictTracker() |