This commit is contained in:
墨梓柒
2025-11-13 13:24:55 +08:00
parent e78a070fbd
commit 7839acd25d
52 changed files with 1322 additions and 1408 deletions

View File

@@ -16,7 +16,7 @@ class CuriousDetector:
"""
好奇心检测器 - 检测聊天记录中的矛盾、冲突或需要提问的内容
"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.llm_request = LLMRequest(
@@ -27,7 +27,7 @@ class CuriousDetector:
self.last_detection_time: float = time.time()
self.min_interval_seconds: float = 60.0
self.min_messages: int = 20
def should_trigger(self) -> bool:
if time.time() - self.last_detection_time < self.min_interval_seconds:
return False
@@ -41,17 +41,17 @@ class CuriousDetector:
async def detect_questions(self, recent_messages: List) -> Optional[str]:
"""
检测最近消息中是否有需要提问的内容
Args:
recent_messages: 最近的消息列表
Returns:
Optional[str]: 如果检测到需要提问的内容返回问题文本否则返回None
"""
try:
if not recent_messages or len(recent_messages) < 2:
return None
# 构建聊天内容
chat_content_block, _ = build_readable_messages_with_id(
messages=recent_messages,
@@ -60,9 +60,9 @@ class CuriousDetector:
truncate=True,
show_actions=True,
)
# 问题跟踪功能已移除,不再检查已有问题
# 构建检测提示词
prompt = f"""你是一个严谨的聊天内容分析器。请分析以下聊天记录,检测是否存在需要提问的内容。
@@ -98,20 +98,20 @@ class CuriousDetector:
logger.debug("已发送好奇心检测提示词")
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.3)
logger.info(f"好奇心检测提示词: {prompt}")
logger.info(f"好奇心检测结果: {result_text}")
if not result_text:
return None
result_text = result_text.strip()
# 检查是否输出NO
if result_text.upper() == "NO":
logger.debug("未检测到需要提问的内容")
return None
# 尝试解析JSON
try:
questions, reasoning = parse_md_json(result_text)
@@ -119,7 +119,7 @@ class CuriousDetector:
question_data = questions[0]
question = question_data.get("question", "")
reason = question_data.get("reason", "")
if question and question.strip():
logger.info(f"检测到需要提问的内容: {question}")
logger.info(f"提问理由: {reason}")
@@ -127,32 +127,32 @@ class CuriousDetector:
except Exception as e:
logger.warning(f"解析问题JSON失败: {e}")
logger.debug(f"原始响应: {result_text}")
return None
except Exception as e:
logger.error(f"好奇心检测失败: {e}")
return None
async def make_question_from_detection(self, question: str, context: str = "") -> bool:
"""
将检测到的问题记录(已移除冲突追踪器功能)
Args:
question: 检测到的问题
context: 问题上下文
Returns:
bool: 是否成功记录
"""
try:
if not question or not question.strip():
return False
# 冲突追踪器功能已移除
logger.info(f"检测到问题(冲突追踪器已移除): {question}")
return True
except Exception as e:
logger.error(f"记录问题失败: {e}")
return False
@@ -174,11 +174,11 @@ curious_manager = CuriousManager()
async def check_and_make_question(chat_id: str) -> bool:
"""
检查聊天记录并生成问题(如果检测到需要提问的内容)
Args:
chat_id: 聊天ID
recent_messages: 最近的消息列表
Returns:
bool: 是否检测到并记录了问题
"""
@@ -199,7 +199,7 @@ async def check_and_make_question(chat_id: str) -> bool:
# 检测是否需要提问
question = await detector.detect_questions(recent_messages)
if question:
# 记录问题
success = await detector.make_question_from_detection(question)
@@ -207,9 +207,9 @@ async def check_and_make_question(chat_id: str) -> bool:
logger.info(f"成功检测并记录问题: {question}")
detector.last_detection_time = time.time()
return True
return False
except Exception as e:
logger.error(f"检查并生成问题失败: {e}")
return False

View File

@@ -19,7 +19,7 @@ def init_memory_retrieval_prompt():
"""初始化记忆检索相关的 prompt 模板和工具"""
# 首先注册所有工具
init_all_tools()
# 第一步问题生成prompt
Prompt(
"""
@@ -63,7 +63,7 @@ def init_memory_retrieval_prompt():
""",
name="memory_retrieval_question_prompt",
)
# 第二步ReAct Agent prompt工具描述会在运行时动态生成
Prompt(
"""
@@ -105,10 +105,10 @@ def init_memory_retrieval_prompt():
def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
"""解析ReAct Agent的响应
Args:
response: LLM返回的响应
Returns:
Dict[str, Any]: 解析后的动作信息如果解析失败返回None
格式: {"thought": str, "actions": List[Dict[str, Any]]}
@@ -118,58 +118,55 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
# 尝试提取JSON可能包含在```json代码块中
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
json_str = matches[0]
else:
# 尝试直接解析整个响应
json_str = response.strip()
# 修复可能的JSON错误
repaired_json = repair_json(json_str)
# 解析JSON
action_info = json.loads(repaired_json)
if not isinstance(action_info, dict):
logger.warning(f"解析的JSON不是对象格式: {action_info}")
return None
# 确保actions字段存在且为列表
if "actions" not in action_info:
logger.warning(f"响应中缺少actions字段: {action_info}")
return None
if not isinstance(action_info["actions"], list):
logger.warning(f"actions字段不是数组格式: {action_info['actions']}")
return None
# 确保actions不为空
if len(action_info["actions"]) == 0:
logger.warning("actions数组为空")
return None
return action_info
except Exception as e:
logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...")
return None
async def _react_agent_solve_question(
question: str,
chat_id: str,
max_iterations: int = 5,
timeout: float = 30.0
question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
"""使用ReAct架构的Agent来解决问题
Args:
question: 要回答的问题
chat_id: 聊天ID
max_iterations: 最大迭代次数
timeout: 超时时间(秒)
Returns:
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
"""
@@ -177,26 +174,26 @@ async def _react_agent_solve_question(
collected_info = ""
thinking_steps = []
is_timeout = False
for iteration in range(max_iterations):
# 检查超时
if time.time() - start_time > timeout:
logger.warning(f"ReAct Agent超时已迭代{iteration}")
is_timeout = True
break
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}")
logger.info(f"ReAct Agent 已收集信息: {collected_info if collected_info else '暂无信息'}")
# 获取工具注册器
tool_registry = get_tool_registry()
# 获取bot_name
bot_name = global_config.bot.nickname
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
# 构建prompt动态生成工具描述
prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt",
@@ -207,44 +204,39 @@ async def _react_agent_solve_question(
tools_description=tool_registry.get_tools_description(),
action_types_list=tool_registry.get_action_types_list(),
)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 Prompt: {prompt}")
# 调用LLM
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
prompt,
model_config=model_config.model_task_config.tool_use,
request_type="memory.react",
)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM响应: {response}")
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM推理: {reasoning_content}")
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM模型: {model_name}")
if not success:
logger.error(f"ReAct Agent LLM调用失败: {response}")
break
# 解析响应
action_info = _parse_react_response(response)
if not action_info:
logger.warning(f"无法解析ReAct响应迭代{iteration + 1}")
break
thought = action_info.get("thought", "")
actions = action_info.get("actions", [])
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考: {thought}")
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作数量: {len(actions)}")
# 记录思考步骤包含所有actions
step = {
"iteration": iteration + 1,
"thought": thought,
"actions": actions,
"observations": []
}
step = {"iteration": iteration + 1, "thought": thought, "actions": actions, "observations": []}
# 检查是否有final_answer或no_answer
for action in actions:
action_type = action.get("action_type", "")
@@ -265,29 +257,32 @@ async def _react_agent_solve_question(
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}")
return False, answer, thinking_steps, False
# 并行执行所有工具
tool_registry = get_tool_registry()
tool_tasks = []
for i, action in enumerate(actions):
action_type = action.get("action_type", "")
action_params = action.get("action_params", {})
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1}/{len(actions)}: {action_type}({action_params})")
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1}/{len(actions)}: {action_type}({action_params})"
)
tool = tool_registry.get_tool(action_type)
if tool:
# 准备工具参数需要添加chat_id如果工具需要
tool_params = action_params.copy()
# 如果工具函数签名需要chat_id添加它
import inspect
sig = inspect.signature(tool.execute_func)
if "chat_id" in sig.parameters:
tool_params["chat_id"] = chat_id
# 创建异步任务
async def execute_single_tool(tool_instance, params, act_type, act_params, iter_num):
try:
@@ -298,34 +293,36 @@ async def _react_agent_solve_question(
error_msg = f"工具执行失败: {str(e)}"
logger.error(f"ReAct Agent 第 {iter_num + 1} 次迭代 动作 {act_type} {error_msg}")
return f"查询{act_type}失败: {error_msg}"
tool_tasks.append(execute_single_tool(tool, tool_params, action_type, action_params, iteration))
else:
error_msg = f"未知的工具类型: {action_type}"
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1}/{len(actions)} {error_msg}")
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1}/{len(actions)} {error_msg}")
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{action_type}失败: {error_msg}")))
# 并行执行所有工具
if tool_tasks:
observations = await asyncio.gather(*tool_tasks, return_exceptions=True)
# 处理执行结果
for i, observation in enumerate(observations):
if isinstance(observation, Exception):
observation = f"工具执行异常: {str(observation)}"
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1} 执行异常: {observation}")
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1} 执行异常: {observation}")
step["observations"].append(observation)
collected_info += f"\n{observation}\n"
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1} 执行结果: {observation}")
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1} 执行结果: {observation}")
thinking_steps.append(step)
# 达到最大迭代次数或超时但Agent没有明确返回final_answer
# 迭代超时应该直接视为no_answer而不是使用已有信息
# 只有Agent明确返回final_answer时才认为找到了答案
if collected_info:
logger.warning(f"ReAct Agent达到最大迭代次数或超时但未明确返回final_answer。已收集信息: {collected_info[:100]}...")
logger.warning(
f"ReAct Agent达到最大迭代次数或超时但未明确返回final_answer。已收集信息: {collected_info[:100]}..."
)
if is_timeout:
logger.warning("ReAct Agent超时直接视为no_answer")
else:
@@ -335,35 +332,32 @@ async def _react_agent_solve_question(
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
"""获取最近一段时间内的查询历史
Args:
chat_id: 聊天ID
time_window_seconds: 时间窗口默认10分钟
Returns:
str: 格式化的查询历史字符串
"""
try:
current_time = time.time()
start_time = current_time - time_window_seconds
# 查询最近时间窗口内的记录,按更新时间倒序
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.update_time >= start_time)
)
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
.order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录
)
if not records.exists():
return ""
history_lines = []
history_lines.append("最近已查询的问题和结果:")
for record in records:
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
answer_preview = ""
@@ -373,15 +367,15 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
answer_preview = record.answer[:100]
if len(record.answer) > 100:
answer_preview += "..."
history_lines.append(f"- 问题:{record.question}")
history_lines.append(f" 状态:{status}")
if answer_preview:
history_lines.append(f" 答案:{answer_preview}")
history_lines.append("") # 空行分隔
return "\n".join(history_lines)
except Exception as e:
logger.error(f"获取查询历史失败: {e}")
return ""
@@ -389,40 +383,40 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> List[str]:
"""获取最近一段时间内缓存的记忆(只返回找到答案的记录)
Args:
chat_id: 聊天ID
time_window_seconds: 时间窗口默认300秒5分钟
Returns:
List[str]: 格式化的记忆列表,每个元素格式为 "问题xxx\n答案xxx"
"""
try:
current_time = time.time()
start_time = current_time - time_window_seconds
# 查询最近时间窗口内找到答案的记录,按更新时间倒序
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.update_time >= start_time) &
(ThinkingBack.found_answer == 1)
(ThinkingBack.chat_id == chat_id)
& (ThinkingBack.update_time >= start_time)
& (ThinkingBack.found_answer == 1)
)
.order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录
)
if not records.exists():
return []
cached_memories = []
for record in records:
if record.answer:
cached_memories.append(f"问题:{record.question}\n答案:{record.answer}")
return cached_memories
except Exception as e:
logger.error(f"获取缓存记忆失败: {e}")
return []
@@ -430,11 +424,11 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]:
"""从thinking_back数据库中查询是否有现成的答案
Args:
chat_id: 聊天ID
question: 问题
Returns:
Optional[Tuple[bool, str]]: 如果找到记录,返回(found_answer, answer)否则返回None
found_answer: 是否找到答案True表示found_answer=1False表示found_answer=0
@@ -445,38 +439,30 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
# 按更新时间倒序,获取最新的记录
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.question == question)
)
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
.order_by(ThinkingBack.update_time.desc())
.limit(1)
)
if records.exists():
record = records.get()
found_answer = bool(record.found_answer)
answer = record.answer or ""
logger.info(f"在thinking_back中找到记录问题: {question[:50]}...found_answer: {found_answer}")
return found_answer, answer
return None
except Exception as e:
logger.error(f"查询thinking_back失败: {e}")
return None
def _store_thinking_back(
chat_id: str,
question: str,
context: str,
found_answer: bool,
answer: str,
thinking_steps: List[Dict[str, Any]]
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
) -> None:
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
Args:
chat_id: 聊天ID
question: 问题
@@ -487,18 +473,15 @@ def _store_thinking_back(
"""
try:
now = time.time()
# 先查询是否已存在相同chat_id和问题的记录
existing = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id) &
(ThinkingBack.question == question)
)
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
.order_by(ThinkingBack.update_time.desc())
.limit(1)
)
if existing.exists():
# 更新现有记录
record = existing.get()
@@ -519,37 +502,33 @@ def _store_thinking_back(
answer=answer,
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
create_time=now,
update_time=now
update_time=now,
)
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
except Exception as e:
logger.error(f"存储思考过程失败: {e}")
async def _process_single_question(
question: str,
chat_id: str,
context: str
) -> Optional[str]:
async def _process_single_question(question: str, chat_id: str, context: str) -> Optional[str]:
"""处理单个问题的查询(包含缓存检查逻辑)
Args:
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
Returns:
Optional[str]: 如果找到答案返回格式化的结果字符串否则返回None
"""
logger.info(f"开始处理问题: {question}")
# 先检查thinking_back数据库中是否有现成答案
cached_result = _query_thinking_back(chat_id, question)
should_requery = False
if cached_result:
cached_found_answer, cached_answer = cached_result
# 根据found_answer的值决定是否重新查询
if cached_found_answer: # found_answer == 1 (True)
# found_answer == 120%概率重新查询
@@ -561,7 +540,7 @@ async def _process_single_question(
if random.random() < 0.4:
should_requery = True
logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")
# 如果不需要重新查询,使用缓存答案
if not should_requery:
if cached_answer:
@@ -570,21 +549,18 @@ async def _process_single_question(
else:
# 缓存中没有答案,需要查询
should_requery = True
# 如果没有缓存答案或需要重新查询使用ReAct Agent查询
if not cached_result or should_requery:
if should_requery:
logger.info(f"概率触发重新查询使用ReAct Agent查询问题: {question[:50]}...")
else:
logger.info(f"未找到缓存答案使用ReAct Agent查询问题: {question[:50]}...")
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=5,
timeout=120.0
question=question, chat_id=chat_id, max_iterations=5, timeout=120.0
)
# 存储到数据库(超时时不存储)
if not is_timeout:
_store_thinking_back(
@@ -593,14 +569,14 @@ async def _process_single_question(
context=context,
found_answer=found_answer,
answer=answer,
thinking_steps=thinking_steps
thinking_steps=thinking_steps,
)
else:
logger.info(f"ReAct Agent超时不存储到数据库问题: {question[:50]}...")
if found_answer and answer:
return f"问题:{question}\n答案:{answer}"
return None
@@ -613,30 +589,30 @@ async def build_memory_retrieval_prompt(
) -> str:
"""构建记忆检索提示
使用两段式查询第一步生成问题第二步使用ReAct Agent查询答案
Args:
message: 聊天历史记录
sender: 发送者名称
target: 目标消息内容
chat_stream: 聊天流对象
tool_executor: 工具执行器(保留参数以兼容接口)
Returns:
str: 记忆检索结果字符串
"""
start_time = time.time()
logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
try:
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
chat_id = chat_stream.stream_id
# 获取最近查询历史最近1小时内的查询
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0)
if not recent_query_history:
recent_query_history = "最近没有查询记录。"
# 第一步:生成问题
question_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_question_prompt",
@@ -647,55 +623,52 @@ async def build_memory_retrieval_prompt(
sender=sender,
target_message=target,
)
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
question_prompt,
model_config=model_config.model_task_config.tool_use,
request_type="memory.question",
)
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
logger.info(f"记忆检索问题生成响应: {response}")
if not success:
logger.error(f"LLM生成问题失败: {response}")
return ""
# 解析问题列表
questions = _parse_questions_json(response)
# 获取缓存的记忆与question时使用相同的时间窗口和数量限制
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
if not questions:
logger.debug("模型认为不需要检索记忆或解析失败")
# 即使没有当次查询,也返回缓存的记忆
if cached_memories:
retrieved_memory = "\n\n".join(cached_memories)
end_time = time.time()
logger.info(f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆")
logger.info(
f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else:
return ""
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
# 第二步并行处理所有问题固定使用5次迭代/120秒超时
logger.info(f"问题数量: {len(questions)},固定设置最大迭代次数: 5超时时间: 120秒")
# 并行处理所有问题
question_tasks = [
_process_single_question(
question=question,
chat_id=chat_id,
context=message
)
for question in questions
_process_single_question(question=question, chat_id=chat_id, context=message) for question in questions
]
# 并行执行所有查询任务
results = await asyncio.gather(*question_tasks, return_exceptions=True)
# 收集所有有效结果
all_results = []
current_questions = set() # 用于去重,避免缓存和当次查询重复
@@ -708,7 +681,7 @@ async def build_memory_retrieval_prompt(
if result.startswith("问题:"):
question = result.split("\n")[0].replace("问题:", "").strip()
current_questions.add(question)
# 将缓存的记忆添加到结果中(排除当次查询已包含的问题,避免重复)
for cached_memory in cached_memories:
if cached_memory.startswith("问题:"):
@@ -717,17 +690,19 @@ async def build_memory_retrieval_prompt(
if question not in current_questions:
all_results.append(cached_memory)
logger.debug(f"添加缓存记忆: {question[:50]}...")
end_time = time.time()
if all_results:
retrieved_memory = "\n\n".join(all_results)
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
logger.info(
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else:
logger.debug("所有问题均未找到答案,且无缓存记忆")
return ""
except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}")
return ""
@@ -735,10 +710,10 @@ async def build_memory_retrieval_prompt(
def _parse_questions_json(response: str) -> List[str]:
"""解析问题JSON
Args:
response: LLM返回的响应
Returns:
List[str]: 问题列表
"""
@@ -746,28 +721,28 @@ def _parse_questions_json(response: str) -> List[str]:
# 尝试提取JSON可能包含在```json代码块中
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
json_str = matches[0]
else:
# 尝试直接解析整个响应
json_str = response.strip()
# 修复可能的JSON错误
repaired_json = repair_json(json_str)
# 解析JSON
questions = json.loads(repaired_json)
if not isinstance(questions, list):
logger.warning(f"解析的JSON不是数组格式: {questions}")
return []
# 确保所有元素都是字符串
questions = [q for q in questions if isinstance(q, str) and q.strip()]
return questions
except Exception as e:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return []

View File

@@ -3,6 +3,7 @@
记忆系统工具函数
包含模糊查找、相似度计算等工具函数
"""
import json
import re
from difflib import SequenceMatcher
@@ -12,6 +13,7 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_md_json(json_text: str) -> list[str]:
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
@@ -50,14 +52,15 @@ def parse_md_json(json_text: str) -> list[str]:
return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度分数 (0-1)
"""
@@ -65,16 +68,16 @@ def calculate_similarity(text1: str, text2: str) -> float:
# 预处理文本
text1 = preprocess_text(text1)
text2 = preprocess_text(text2)
# 使用SequenceMatcher计算相似度
similarity = SequenceMatcher(None, text1, text2).ratio()
# 如果其中一个文本包含另一个,提高相似度
if text1 in text2 or text2 in text1:
similarity = max(similarity, 0.8)
return similarity
except Exception as e:
logger.error(f"计算相似度时出错: {e}")
return 0.0
@@ -83,26 +86,25 @@ def calculate_similarity(text1: str, text2: str) -> float:
def preprocess_text(text: str) -> str:
"""
预处理文本,提高匹配准确性
Args:
text: 原始文本
Returns:
str: 预处理后的文本
"""
try:
# 转换为小写
text = text.lower()
# 移除标点符号和特殊字符
text = re.sub(r'[^\w\s]', '', text)
text = re.sub(r"[^\w\s]", "", text)
# 移除多余空格
text = re.sub(r'\s+', ' ', text).strip()
text = re.sub(r"\s+", " ", text).strip()
return text
except Exception as e:
logger.error(f"预处理文本时出错: {e}")
return text

View File

@@ -14,20 +14,16 @@ from .tool_utils import parse_datetime_to_timestamp, parse_time_range
logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_range: Optional[str] = None
) -> str:
async def query_chat_history(chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
time_range: 时间范围或时间点,格式:
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
Returns:
str: 查询结果
"""
@@ -35,10 +31,10 @@ async def query_chat_history(
# 检查参数
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
if time_range:
# 判断是时间点还是时间范围
@@ -46,73 +42,71 @@ async def query_chat_history(
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
query = query.where(time_filter)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
if not records:
return "未找到相关聊天记录概述"
# 如果有关键词,进一步过滤
if keyword:
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 检查是否包含任意一个关键词OR关系
matched = False
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True
break
if matched:
filtered_records.append(record)
if not filtered_records:
keywords_str = "".join(keywords_list)
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
records = filtered_records
# 对即将返回的记录增加使用计数
@@ -123,22 +117,23 @@ async def query_chat_history(
record.count = (record.count or 0) + 1
except Exception as update_error:
logger.error(f"更新聊天记录概述计数失败: {update_error}")
# 构建结果文本
results = []
for record in records_to_use: # 最多返回3条记录
result_parts = []
# 添加主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
# 添加概括优先使用summary如果没有则使用original_text的前200字符
if record.summary:
result_parts.append(f"概括:{record.summary}")
@@ -147,18 +142,18 @@ async def query_chat_history(
if len(record.original_text) > 200:
text_preview += "..."
result_parts.append(f"内容:{text_preview}")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录概述"
response_text = "\n\n---\n\n".join(results)
if len(records) > len(records_to_use):
omitted_count = len(records) - len(records_to_use)
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
return response_text
except Exception as e:
logger.error(f"查询聊天历史概述失败: {e}")
return f"查询失败: {str(e)}"
@@ -174,14 +169,14 @@ def register_tool():
"name": "keyword",
"type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)",
"required": False
"required": False,
},
{
"name": "time_range",
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
}
"required": False,
},
],
execute_func=query_chat_history
execute_func=query_chat_history,
)

View File

@@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str
) -> str:
async def query_jargon(keyword: str, chat_id: str) -> str:
"""根据关键词在jargon库中查询
Args:
keyword: 关键词(黑话/俚语/缩写)
chat_id: 聊天ID
Returns:
str: 查询结果
"""
@@ -26,29 +23,17 @@ async def query_jargon(
content = str(keyword).strip()
if not content:
return "关键词为空"
# 先尝试精确匹配
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索
if not results:
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
is_fuzzy_match = True
if results:
# 如果是模糊匹配显示找到的实际jargon内容
if is_fuzzy_match:
@@ -71,11 +56,11 @@ async def query_jargon(
output = "".join(output_parts) if len(output_parts) > 1 else output_parts[0]
logger.info(f"在jargon库中找到匹配当前会话或全局精确匹配: {content},找到{len(results)}条结果")
return output
# 未命中
logger.info(f"在jargon库中未找到匹配当前会话或全局精确匹配和模糊搜索都未找到: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
logger.error(f"查询jargon失败: {e}")
return f"查询失败: {str(e)}"
@@ -86,14 +71,6 @@ def register_tool():
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写)",
"required": True
}
],
execute_func=query_jargon
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
execute_func=query_jargon,
)

View File

@@ -11,17 +11,13 @@ logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
Args:
name: 工具名称
description: 工具描述
@@ -32,7 +28,7 @@ class MemoryRetrievalTool:
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
@@ -43,10 +39,10 @@ class MemoryRetrievalTool:
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
@@ -54,30 +50,30 @@ class MemoryRetrievalTool:
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt"""
action_types = [tool.name for tool in self.tools.values()]
@@ -91,13 +87,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
@@ -111,4 +104,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry

View File

@@ -40,25 +40,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
def parse_time_range(time_range: str) -> Tuple[float, float]:
"""
解析时间范围字符串,返回开始和结束时间戳
Args:
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
Returns:
Tuple[float, float]: (开始时间戳, 结束时间戳)
"""
if " - " not in time_range:
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
parts = time_range.split(" - ", 1)
if len(parts) != 2:
raise ValueError(f"时间范围格式错误: {time_range}")
start_str = parts[0].strip()
end_str = parts[1].strip()
start_timestamp = parse_datetime_to_timestamp(start_str)
end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp
return start_timestamp, end_timestamp