Ruff fix
This commit is contained in:
@@ -16,7 +16,7 @@ class CuriousDetector:
|
||||
"""
|
||||
好奇心检测器 - 检测聊天记录中的矛盾、冲突或需要提问的内容
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.llm_request = LLMRequest(
|
||||
@@ -27,7 +27,7 @@ class CuriousDetector:
|
||||
self.last_detection_time: float = time.time()
|
||||
self.min_interval_seconds: float = 60.0
|
||||
self.min_messages: int = 20
|
||||
|
||||
|
||||
def should_trigger(self) -> bool:
|
||||
if time.time() - self.last_detection_time < self.min_interval_seconds:
|
||||
return False
|
||||
@@ -41,17 +41,17 @@ class CuriousDetector:
|
||||
async def detect_questions(self, recent_messages: List) -> Optional[str]:
|
||||
"""
|
||||
检测最近消息中是否有需要提问的内容
|
||||
|
||||
|
||||
Args:
|
||||
recent_messages: 最近的消息列表
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str]: 如果检测到需要提问的内容,返回问题文本;否则返回None
|
||||
"""
|
||||
try:
|
||||
if not recent_messages or len(recent_messages) < 2:
|
||||
return None
|
||||
|
||||
|
||||
# 构建聊天内容
|
||||
chat_content_block, _ = build_readable_messages_with_id(
|
||||
messages=recent_messages,
|
||||
@@ -60,9 +60,9 @@ class CuriousDetector:
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
|
||||
# 问题跟踪功能已移除,不再检查已有问题
|
||||
|
||||
|
||||
# 构建检测提示词
|
||||
prompt = f"""你是一个严谨的聊天内容分析器。请分析以下聊天记录,检测是否存在需要提问的内容。
|
||||
|
||||
@@ -98,20 +98,20 @@ class CuriousDetector:
|
||||
logger.debug("已发送好奇心检测提示词")
|
||||
|
||||
result_text, _ = await self.llm_request.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
|
||||
logger.info(f"好奇心检测提示词: {prompt}")
|
||||
logger.info(f"好奇心检测结果: {result_text}")
|
||||
|
||||
|
||||
if not result_text:
|
||||
return None
|
||||
|
||||
|
||||
result_text = result_text.strip()
|
||||
|
||||
|
||||
# 检查是否输出NO
|
||||
if result_text.upper() == "NO":
|
||||
logger.debug("未检测到需要提问的内容")
|
||||
return None
|
||||
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
questions, reasoning = parse_md_json(result_text)
|
||||
@@ -119,7 +119,7 @@ class CuriousDetector:
|
||||
question_data = questions[0]
|
||||
question = question_data.get("question", "")
|
||||
reason = question_data.get("reason", "")
|
||||
|
||||
|
||||
if question and question.strip():
|
||||
logger.info(f"检测到需要提问的内容: {question}")
|
||||
logger.info(f"提问理由: {reason}")
|
||||
@@ -127,32 +127,32 @@ class CuriousDetector:
|
||||
except Exception as e:
|
||||
logger.warning(f"解析问题JSON失败: {e}")
|
||||
logger.debug(f"原始响应: {result_text}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"好奇心检测失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def make_question_from_detection(self, question: str, context: str = "") -> bool:
|
||||
"""
|
||||
将检测到的问题记录(已移除冲突追踪器功能)
|
||||
|
||||
|
||||
Args:
|
||||
question: 检测到的问题
|
||||
context: 问题上下文
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否成功记录
|
||||
"""
|
||||
try:
|
||||
if not question or not question.strip():
|
||||
return False
|
||||
|
||||
|
||||
# 冲突追踪器功能已移除
|
||||
logger.info(f"检测到问题(冲突追踪器已移除): {question}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记录问题失败: {e}")
|
||||
return False
|
||||
@@ -174,11 +174,11 @@ curious_manager = CuriousManager()
|
||||
async def check_and_make_question(chat_id: str) -> bool:
|
||||
"""
|
||||
检查聊天记录并生成问题(如果检测到需要提问的内容)
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
recent_messages: 最近的消息列表
|
||||
|
||||
|
||||
Returns:
|
||||
bool: 是否检测到并记录了问题
|
||||
"""
|
||||
@@ -199,7 +199,7 @@ async def check_and_make_question(chat_id: str) -> bool:
|
||||
|
||||
# 检测是否需要提问
|
||||
question = await detector.detect_questions(recent_messages)
|
||||
|
||||
|
||||
if question:
|
||||
# 记录问题
|
||||
success = await detector.make_question_from_detection(question)
|
||||
@@ -207,9 +207,9 @@ async def check_and_make_question(chat_id: str) -> bool:
|
||||
logger.info(f"成功检测并记录问题: {question}")
|
||||
detector.last_detection_time = time.time()
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查并生成问题失败: {e}")
|
||||
return False
|
||||
|
||||
@@ -19,7 +19,7 @@ def init_memory_retrieval_prompt():
|
||||
"""初始化记忆检索相关的 prompt 模板和工具"""
|
||||
# 首先注册所有工具
|
||||
init_all_tools()
|
||||
|
||||
|
||||
# 第一步:问题生成prompt
|
||||
Prompt(
|
||||
"""
|
||||
@@ -63,7 +63,7 @@ def init_memory_retrieval_prompt():
|
||||
""",
|
||||
name="memory_retrieval_question_prompt",
|
||||
)
|
||||
|
||||
|
||||
# 第二步:ReAct Agent prompt(工具描述会在运行时动态生成)
|
||||
Prompt(
|
||||
"""
|
||||
@@ -105,10 +105,10 @@ def init_memory_retrieval_prompt():
|
||||
|
||||
def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析ReAct Agent的响应
|
||||
|
||||
|
||||
Args:
|
||||
response: LLM返回的响应
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 解析后的动作信息,如果解析失败返回None
|
||||
格式: {"thought": str, "actions": List[Dict[str, Any]]}
|
||||
@@ -118,58 +118,55 @@ def _parse_react_response(response: str) -> Optional[Dict[str, Any]]:
|
||||
# 尝试提取JSON(可能包含在```json代码块中)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
|
||||
if matches:
|
||||
json_str = matches[0]
|
||||
else:
|
||||
# 尝试直接解析整个响应
|
||||
json_str = response.strip()
|
||||
|
||||
|
||||
# 修复可能的JSON错误
|
||||
repaired_json = repair_json(json_str)
|
||||
|
||||
|
||||
# 解析JSON
|
||||
action_info = json.loads(repaired_json)
|
||||
|
||||
|
||||
if not isinstance(action_info, dict):
|
||||
logger.warning(f"解析的JSON不是对象格式: {action_info}")
|
||||
return None
|
||||
|
||||
|
||||
# 确保actions字段存在且为列表
|
||||
if "actions" not in action_info:
|
||||
logger.warning(f"响应中缺少actions字段: {action_info}")
|
||||
return None
|
||||
|
||||
|
||||
if not isinstance(action_info["actions"], list):
|
||||
logger.warning(f"actions字段不是数组格式: {action_info['actions']}")
|
||||
return None
|
||||
|
||||
|
||||
# 确保actions不为空
|
||||
if len(action_info["actions"]) == 0:
|
||||
logger.warning("actions数组为空")
|
||||
return None
|
||||
|
||||
|
||||
return action_info
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析ReAct响应失败: {e}, 响应内容: {response[:200]}...")
|
||||
return None
|
||||
|
||||
|
||||
async def _react_agent_solve_question(
|
||||
question: str,
|
||||
chat_id: str,
|
||||
max_iterations: int = 5,
|
||||
timeout: float = 30.0
|
||||
question: str, chat_id: str, max_iterations: int = 5, timeout: float = 30.0
|
||||
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
|
||||
"""使用ReAct架构的Agent来解决问题
|
||||
|
||||
|
||||
Args:
|
||||
question: 要回答的问题
|
||||
chat_id: 聊天ID
|
||||
max_iterations: 最大迭代次数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
|
||||
"""
|
||||
@@ -177,26 +174,26 @@ async def _react_agent_solve_question(
|
||||
collected_info = ""
|
||||
thinking_steps = []
|
||||
is_timeout = False
|
||||
|
||||
|
||||
for iteration in range(max_iterations):
|
||||
# 检查超时
|
||||
if time.time() - start_time > timeout:
|
||||
logger.warning(f"ReAct Agent超时,已迭代{iteration}次")
|
||||
is_timeout = True
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}")
|
||||
logger.info(f"ReAct Agent 已收集信息: {collected_info if collected_info else '暂无信息'}")
|
||||
|
||||
|
||||
# 获取工具注册器
|
||||
tool_registry = get_tool_registry()
|
||||
|
||||
|
||||
# 获取bot_name
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
|
||||
# 构建prompt(动态生成工具描述)
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_retrieval_react_prompt",
|
||||
@@ -207,44 +204,39 @@ async def _react_agent_solve_question(
|
||||
tools_description=tool_registry.get_tools_description(),
|
||||
action_types_list=tool_registry.get_action_types_list(),
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 Prompt: {prompt}")
|
||||
|
||||
|
||||
# 调用LLM
|
||||
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||
prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
request_type="memory.react",
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM响应: {response}")
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM推理: {reasoning_content}")
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 LLM模型: {model_name}")
|
||||
|
||||
|
||||
if not success:
|
||||
logger.error(f"ReAct Agent LLM调用失败: {response}")
|
||||
break
|
||||
|
||||
|
||||
# 解析响应
|
||||
action_info = _parse_react_response(response)
|
||||
if not action_info:
|
||||
logger.warning(f"无法解析ReAct响应,迭代{iteration + 1}")
|
||||
break
|
||||
|
||||
|
||||
thought = action_info.get("thought", "")
|
||||
actions = action_info.get("actions", [])
|
||||
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考: {thought}")
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作数量: {len(actions)}")
|
||||
|
||||
|
||||
# 记录思考步骤(包含所有actions)
|
||||
step = {
|
||||
"iteration": iteration + 1,
|
||||
"thought": thought,
|
||||
"actions": actions,
|
||||
"observations": []
|
||||
}
|
||||
|
||||
step = {"iteration": iteration + 1, "thought": thought, "actions": actions, "observations": []}
|
||||
|
||||
# 检查是否有final_answer或no_answer
|
||||
for action in actions:
|
||||
action_type = action.get("action_type", "")
|
||||
@@ -265,29 +257,32 @@ async def _react_agent_solve_question(
|
||||
thinking_steps.append(step)
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 确认无法找到答案: {answer}")
|
||||
return False, answer, thinking_steps, False
|
||||
|
||||
|
||||
# 并行执行所有工具
|
||||
tool_registry = get_tool_registry()
|
||||
tool_tasks = []
|
||||
|
||||
|
||||
for i, action in enumerate(actions):
|
||||
action_type = action.get("action_type", "")
|
||||
action_params = action.get("action_params", {})
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1}/{len(actions)}: {action_type}({action_params})")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1}/{len(actions)}: {action_type}({action_params})"
|
||||
)
|
||||
|
||||
tool = tool_registry.get_tool(action_type)
|
||||
|
||||
|
||||
if tool:
|
||||
# 准备工具参数(需要添加chat_id如果工具需要)
|
||||
tool_params = action_params.copy()
|
||||
|
||||
|
||||
# 如果工具函数签名需要chat_id,添加它
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(tool.execute_func)
|
||||
if "chat_id" in sig.parameters:
|
||||
tool_params["chat_id"] = chat_id
|
||||
|
||||
|
||||
# 创建异步任务
|
||||
async def execute_single_tool(tool_instance, params, act_type, act_params, iter_num):
|
||||
try:
|
||||
@@ -298,34 +293,36 @@ async def _react_agent_solve_question(
|
||||
error_msg = f"工具执行失败: {str(e)}"
|
||||
logger.error(f"ReAct Agent 第 {iter_num + 1} 次迭代 动作 {act_type} {error_msg}")
|
||||
return f"查询{act_type}失败: {error_msg}"
|
||||
|
||||
|
||||
tool_tasks.append(execute_single_tool(tool, tool_params, action_type, action_params, iteration))
|
||||
else:
|
||||
error_msg = f"未知的工具类型: {action_type}"
|
||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1}/{len(actions)} {error_msg}")
|
||||
logger.warning(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1}/{len(actions)} {error_msg}")
|
||||
tool_tasks.append(asyncio.create_task(asyncio.sleep(0, result=f"查询{action_type}失败: {error_msg}")))
|
||||
|
||||
|
||||
# 并行执行所有工具
|
||||
if tool_tasks:
|
||||
observations = await asyncio.gather(*tool_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理执行结果
|
||||
for i, observation in enumerate(observations):
|
||||
if isinstance(observation, Exception):
|
||||
observation = f"工具执行异常: {str(observation)}"
|
||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1} 执行异常: {observation}")
|
||||
|
||||
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1} 执行异常: {observation}")
|
||||
|
||||
step["observations"].append(observation)
|
||||
collected_info += f"\n{observation}\n"
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i+1} 执行结果: {observation}")
|
||||
|
||||
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 动作 {i + 1} 执行结果: {observation}")
|
||||
|
||||
thinking_steps.append(step)
|
||||
|
||||
|
||||
# 达到最大迭代次数或超时,但Agent没有明确返回final_answer
|
||||
# 迭代超时应该直接视为no_answer,而不是使用已有信息
|
||||
# 只有Agent明确返回final_answer时,才认为找到了答案
|
||||
if collected_info:
|
||||
logger.warning(f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}...")
|
||||
logger.warning(
|
||||
f"ReAct Agent达到最大迭代次数或超时,但未明确返回final_answer。已收集信息: {collected_info[:100]}..."
|
||||
)
|
||||
if is_timeout:
|
||||
logger.warning("ReAct Agent超时,直接视为no_answer")
|
||||
else:
|
||||
@@ -335,35 +332,32 @@ async def _react_agent_solve_question(
|
||||
|
||||
def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0) -> str:
|
||||
"""获取最近一段时间内的查询历史
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_window_seconds: 时间窗口(秒),默认10分钟
|
||||
|
||||
|
||||
Returns:
|
||||
str: 格式化的查询历史字符串
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
start_time = current_time - time_window_seconds
|
||||
|
||||
|
||||
# 查询最近时间窗口内的记录,按更新时间倒序
|
||||
records = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.update_time >= start_time)
|
||||
)
|
||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(5) # 最多返回5条最近的记录
|
||||
)
|
||||
|
||||
|
||||
if not records.exists():
|
||||
return ""
|
||||
|
||||
|
||||
history_lines = []
|
||||
history_lines.append("最近已查询的问题和结果:")
|
||||
|
||||
|
||||
for record in records:
|
||||
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
|
||||
answer_preview = ""
|
||||
@@ -373,15 +367,15 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
|
||||
answer_preview = record.answer[:100]
|
||||
if len(record.answer) > 100:
|
||||
answer_preview += "..."
|
||||
|
||||
|
||||
history_lines.append(f"- 问题:{record.question}")
|
||||
history_lines.append(f" 状态:{status}")
|
||||
if answer_preview:
|
||||
history_lines.append(f" 答案:{answer_preview}")
|
||||
history_lines.append("") # 空行分隔
|
||||
|
||||
|
||||
return "\n".join(history_lines)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取查询历史失败: {e}")
|
||||
return ""
|
||||
@@ -389,40 +383,40 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 300.0)
|
||||
|
||||
def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> List[str]:
|
||||
"""获取最近一段时间内缓存的记忆(只返回找到答案的记录)
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
time_window_seconds: 时间窗口(秒),默认300秒(5分钟)
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 格式化的记忆列表,每个元素格式为 "问题:xxx\n答案:xxx"
|
||||
"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
start_time = current_time - time_window_seconds
|
||||
|
||||
|
||||
# 查询最近时间窗口内找到答案的记录,按更新时间倒序
|
||||
records = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.update_time >= start_time) &
|
||||
(ThinkingBack.found_answer == 1)
|
||||
(ThinkingBack.chat_id == chat_id)
|
||||
& (ThinkingBack.update_time >= start_time)
|
||||
& (ThinkingBack.found_answer == 1)
|
||||
)
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(5) # 最多返回5条最近的记录
|
||||
)
|
||||
|
||||
|
||||
if not records.exists():
|
||||
return []
|
||||
|
||||
|
||||
cached_memories = []
|
||||
for record in records:
|
||||
if record.answer:
|
||||
cached_memories.append(f"问题:{record.question}\n答案:{record.answer}")
|
||||
|
||||
|
||||
return cached_memories
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取缓存记忆失败: {e}")
|
||||
return []
|
||||
@@ -430,11 +424,11 @@ def _get_cached_memories(chat_id: str, time_window_seconds: float = 300.0) -> Li
|
||||
|
||||
def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, str]]:
|
||||
"""从thinking_back数据库中查询是否有现成的答案
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
question: 问题
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[bool, str]]: 如果找到记录,返回(found_answer, answer),否则返回None
|
||||
found_answer: 是否找到答案(True表示found_answer=1,False表示found_answer=0)
|
||||
@@ -445,38 +439,30 @@ def _query_thinking_back(chat_id: str, question: str) -> Optional[Tuple[bool, st
|
||||
# 按更新时间倒序,获取最新的记录
|
||||
records = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.question == question)
|
||||
)
|
||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
if records.exists():
|
||||
record = records.get()
|
||||
found_answer = bool(record.found_answer)
|
||||
answer = record.answer or ""
|
||||
logger.info(f"在thinking_back中找到记录,问题: {question[:50]}...,found_answer: {found_answer}")
|
||||
return found_answer, answer
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询thinking_back失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _store_thinking_back(
|
||||
chat_id: str,
|
||||
question: str,
|
||||
context: str,
|
||||
found_answer: bool,
|
||||
answer: str,
|
||||
thinking_steps: List[Dict[str, Any]]
|
||||
chat_id: str, question: str, context: str, found_answer: bool, answer: str, thinking_steps: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""存储或更新思考过程到数据库(如果已存在则更新,否则创建)
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
question: 问题
|
||||
@@ -487,18 +473,15 @@ def _store_thinking_back(
|
||||
"""
|
||||
try:
|
||||
now = time.time()
|
||||
|
||||
|
||||
# 先查询是否已存在相同chat_id和问题的记录
|
||||
existing = (
|
||||
ThinkingBack.select()
|
||||
.where(
|
||||
(ThinkingBack.chat_id == chat_id) &
|
||||
(ThinkingBack.question == question)
|
||||
)
|
||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
||||
.order_by(ThinkingBack.update_time.desc())
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
if existing.exists():
|
||||
# 更新现有记录
|
||||
record = existing.get()
|
||||
@@ -519,37 +502,33 @@ def _store_thinking_back(
|
||||
answer=answer,
|
||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||
create_time=now,
|
||||
update_time=now
|
||||
update_time=now,
|
||||
)
|
||||
logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"存储思考过程失败: {e}")
|
||||
|
||||
|
||||
async def _process_single_question(
|
||||
question: str,
|
||||
chat_id: str,
|
||||
context: str
|
||||
) -> Optional[str]:
|
||||
async def _process_single_question(question: str, chat_id: str, context: str) -> Optional[str]:
|
||||
"""处理单个问题的查询(包含缓存检查逻辑)
|
||||
|
||||
|
||||
Args:
|
||||
question: 要查询的问题
|
||||
chat_id: 聊天ID
|
||||
context: 上下文信息
|
||||
|
||||
|
||||
Returns:
|
||||
Optional[str]: 如果找到答案,返回格式化的结果字符串,否则返回None
|
||||
"""
|
||||
logger.info(f"开始处理问题: {question}")
|
||||
|
||||
|
||||
# 先检查thinking_back数据库中是否有现成答案
|
||||
cached_result = _query_thinking_back(chat_id, question)
|
||||
should_requery = False
|
||||
|
||||
|
||||
if cached_result:
|
||||
cached_found_answer, cached_answer = cached_result
|
||||
|
||||
|
||||
# 根据found_answer的值决定是否重新查询
|
||||
if cached_found_answer: # found_answer == 1 (True)
|
||||
# found_answer == 1:20%概率重新查询
|
||||
@@ -561,7 +540,7 @@ async def _process_single_question(
|
||||
if random.random() < 0.4:
|
||||
should_requery = True
|
||||
logger.info(f"found_answer=0,触发40%概率重新查询,问题: {question[:50]}...")
|
||||
|
||||
|
||||
# 如果不需要重新查询,使用缓存答案
|
||||
if not should_requery:
|
||||
if cached_answer:
|
||||
@@ -570,21 +549,18 @@ async def _process_single_question(
|
||||
else:
|
||||
# 缓存中没有答案,需要查询
|
||||
should_requery = True
|
||||
|
||||
|
||||
# 如果没有缓存答案或需要重新查询,使用ReAct Agent查询
|
||||
if not cached_result or should_requery:
|
||||
if should_requery:
|
||||
logger.info(f"概率触发重新查询,使用ReAct Agent查询,问题: {question[:50]}...")
|
||||
else:
|
||||
logger.info(f"未找到缓存答案,使用ReAct Agent查询,问题: {question[:50]}...")
|
||||
|
||||
|
||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
max_iterations=5,
|
||||
timeout=120.0
|
||||
question=question, chat_id=chat_id, max_iterations=5, timeout=120.0
|
||||
)
|
||||
|
||||
|
||||
# 存储到数据库(超时时不存储)
|
||||
if not is_timeout:
|
||||
_store_thinking_back(
|
||||
@@ -593,14 +569,14 @@ async def _process_single_question(
|
||||
context=context,
|
||||
found_answer=found_answer,
|
||||
answer=answer,
|
||||
thinking_steps=thinking_steps
|
||||
thinking_steps=thinking_steps,
|
||||
)
|
||||
else:
|
||||
logger.info(f"ReAct Agent超时,不存储到数据库,问题: {question[:50]}...")
|
||||
|
||||
|
||||
if found_answer and answer:
|
||||
return f"问题:{question}\n答案:{answer}"
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -613,30 +589,30 @@ async def build_memory_retrieval_prompt(
|
||||
) -> str:
|
||||
"""构建记忆检索提示
|
||||
使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案
|
||||
|
||||
|
||||
Args:
|
||||
message: 聊天历史记录
|
||||
sender: 发送者名称
|
||||
target: 目标消息内容
|
||||
chat_stream: 聊天流对象
|
||||
tool_executor: 工具执行器(保留参数以兼容接口)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 记忆检索结果字符串
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
logger.info(f"检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
try:
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
bot_name = global_config.bot.nickname
|
||||
chat_id = chat_stream.stream_id
|
||||
|
||||
|
||||
# 获取最近查询历史(最近1小时内的查询)
|
||||
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=300.0)
|
||||
if not recent_query_history:
|
||||
recent_query_history = "最近没有查询记录。"
|
||||
|
||||
|
||||
# 第一步:生成问题
|
||||
question_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_retrieval_question_prompt",
|
||||
@@ -647,55 +623,52 @@ async def build_memory_retrieval_prompt(
|
||||
sender=sender,
|
||||
target_message=target,
|
||||
)
|
||||
|
||||
|
||||
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
|
||||
question_prompt,
|
||||
model_config=model_config.model_task_config.tool_use,
|
||||
request_type="memory.question",
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"记忆检索问题生成提示词: {question_prompt}")
|
||||
logger.info(f"记忆检索问题生成响应: {response}")
|
||||
|
||||
|
||||
if not success:
|
||||
logger.error(f"LLM生成问题失败: {response}")
|
||||
return ""
|
||||
|
||||
|
||||
# 解析问题列表
|
||||
questions = _parse_questions_json(response)
|
||||
|
||||
|
||||
# 获取缓存的记忆(与question时使用相同的时间窗口和数量限制)
|
||||
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
|
||||
|
||||
|
||||
if not questions:
|
||||
logger.debug("模型认为不需要检索记忆或解析失败")
|
||||
# 即使没有当次查询,也返回缓存的记忆
|
||||
if cached_memories:
|
||||
retrieved_memory = "\n\n".join(cached_memories)
|
||||
end_time = time.time()
|
||||
logger.info(f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆")
|
||||
logger.info(
|
||||
f"无当次查询,返回缓存记忆,耗时: {(end_time - start_time):.3f}秒,包含 {len(cached_memories)} 条缓存记忆"
|
||||
)
|
||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
|
||||
|
||||
|
||||
# 第二步:并行处理所有问题(固定使用5次迭代/120秒超时)
|
||||
logger.info(f"问题数量: {len(questions)},固定设置最大迭代次数: 5,超时时间: 120秒")
|
||||
|
||||
|
||||
# 并行处理所有问题
|
||||
question_tasks = [
|
||||
_process_single_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
context=message
|
||||
)
|
||||
for question in questions
|
||||
_process_single_question(question=question, chat_id=chat_id, context=message) for question in questions
|
||||
]
|
||||
|
||||
|
||||
# 并行执行所有查询任务
|
||||
results = await asyncio.gather(*question_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 收集所有有效结果
|
||||
all_results = []
|
||||
current_questions = set() # 用于去重,避免缓存和当次查询重复
|
||||
@@ -708,7 +681,7 @@ async def build_memory_retrieval_prompt(
|
||||
if result.startswith("问题:"):
|
||||
question = result.split("\n")[0].replace("问题:", "").strip()
|
||||
current_questions.add(question)
|
||||
|
||||
|
||||
# 将缓存的记忆添加到结果中(排除当次查询已包含的问题,避免重复)
|
||||
for cached_memory in cached_memories:
|
||||
if cached_memory.startswith("问题:"):
|
||||
@@ -717,17 +690,19 @@ async def build_memory_retrieval_prompt(
|
||||
if question not in current_questions:
|
||||
all_results.append(cached_memory)
|
||||
logger.debug(f"添加缓存记忆: {question[:50]}...")
|
||||
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
||||
if all_results:
|
||||
retrieved_memory = "\n\n".join(all_results)
|
||||
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)")
|
||||
logger.info(
|
||||
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(all_results)} 条记忆(含缓存)"
|
||||
)
|
||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||
else:
|
||||
logger.debug("所有问题均未找到答案,且无缓存记忆")
|
||||
return ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记忆检索时发生异常: {str(e)}")
|
||||
return ""
|
||||
@@ -735,10 +710,10 @@ async def build_memory_retrieval_prompt(
|
||||
|
||||
def _parse_questions_json(response: str) -> List[str]:
|
||||
"""解析问题JSON
|
||||
|
||||
|
||||
Args:
|
||||
response: LLM返回的响应
|
||||
|
||||
|
||||
Returns:
|
||||
List[str]: 问题列表
|
||||
"""
|
||||
@@ -746,28 +721,28 @@ def _parse_questions_json(response: str) -> List[str]:
|
||||
# 尝试提取JSON(可能包含在```json代码块中)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
|
||||
if matches:
|
||||
json_str = matches[0]
|
||||
else:
|
||||
# 尝试直接解析整个响应
|
||||
json_str = response.strip()
|
||||
|
||||
|
||||
# 修复可能的JSON错误
|
||||
repaired_json = repair_json(json_str)
|
||||
|
||||
|
||||
# 解析JSON
|
||||
questions = json.loads(repaired_json)
|
||||
|
||||
|
||||
if not isinstance(questions, list):
|
||||
logger.warning(f"解析的JSON不是数组格式: {questions}")
|
||||
return []
|
||||
|
||||
|
||||
# 确保所有元素都是字符串
|
||||
questions = [q for q in questions if isinstance(q, str) and q.strip()]
|
||||
|
||||
|
||||
return questions
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||
return []
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
记忆系统工具函数
|
||||
包含模糊查找、相似度计算等工具函数
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
@@ -12,6 +13,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
def parse_md_json(json_text: str) -> list[str]:
|
||||
"""从Markdown格式的内容中提取JSON对象和推理内容"""
|
||||
json_objects = []
|
||||
@@ -50,14 +52,15 @@ def parse_md_json(json_text: str) -> list[str]:
|
||||
|
||||
return json_objects, reasoning_content
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度
|
||||
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
|
||||
Returns:
|
||||
float: 相似度分数 (0-1)
|
||||
"""
|
||||
@@ -65,16 +68,16 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
# 预处理文本
|
||||
text1 = preprocess_text(text1)
|
||||
text2 = preprocess_text(text2)
|
||||
|
||||
|
||||
# 使用SequenceMatcher计算相似度
|
||||
similarity = SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
# 如果其中一个文本包含另一个,提高相似度
|
||||
if text1 in text2 or text2 in text1:
|
||||
similarity = max(similarity, 0.8)
|
||||
|
||||
|
||||
return similarity
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"计算相似度时出错: {e}")
|
||||
return 0.0
|
||||
@@ -83,26 +86,25 @@ def calculate_similarity(text1: str, text2: str) -> float:
|
||||
def preprocess_text(text: str) -> str:
|
||||
"""
|
||||
预处理文本,提高匹配准确性
|
||||
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
|
||||
Returns:
|
||||
str: 预处理后的文本
|
||||
"""
|
||||
try:
|
||||
# 转换为小写
|
||||
text = text.lower()
|
||||
|
||||
|
||||
# 移除标点符号和特殊字符
|
||||
text = re.sub(r'[^\w\s]', '', text)
|
||||
|
||||
text = re.sub(r"[^\w\s]", "", text)
|
||||
|
||||
# 移除多余空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预处理文本时出错: {e}")
|
||||
return text
|
||||
|
||||
|
||||
@@ -14,20 +14,16 @@ from .tool_utils import parse_datetime_to_timestamp, parse_time_range
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
time_range: Optional[str] = None
|
||||
) -> str:
|
||||
async def query_chat_history(chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
|
||||
time_range: 时间范围或时间点,格式:
|
||||
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
@@ -35,10 +31,10 @@ async def query_chat_history(
|
||||
# 检查参数
|
||||
if not keyword and not time_range:
|
||||
return "未指定查询参数(需要提供keyword或time_range之一)"
|
||||
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
|
||||
# 时间过滤条件
|
||||
if time_range:
|
||||
# 判断是时间点还是时间范围
|
||||
@@ -46,73 +42,71 @@ async def query_chat_history(
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_filter = (
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
|
||||
else:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_range)
|
||||
time_filter = (
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
|
||||
query = query.where(time_filter)
|
||||
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
|
||||
if not records:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
|
||||
# 如果有关键词,进一步过滤
|
||||
if keyword:
|
||||
# 解析多个关键词(支持空格、逗号等分隔符)
|
||||
keywords_list = parse_keywords_string(keyword)
|
||||
if not keywords_list:
|
||||
keywords_list = [keyword.strip()] if keyword.strip() else []
|
||||
|
||||
|
||||
# 转换为小写以便匹配
|
||||
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
|
||||
|
||||
|
||||
if not keywords_lower:
|
||||
return "关键词为空"
|
||||
|
||||
|
||||
filtered_records = []
|
||||
|
||||
|
||||
for record in records:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
|
||||
# 解析record中的keywords JSON
|
||||
record_keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
keywords_data = (
|
||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
)
|
||||
if isinstance(keywords_data, list):
|
||||
record_keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
# 检查是否包含任意一个关键词(OR关系)
|
||||
matched = False
|
||||
for kw in keywords_lower:
|
||||
if (kw in theme or
|
||||
kw in summary or
|
||||
kw in original_text or
|
||||
any(kw in k for k in record_keywords_list)):
|
||||
if (
|
||||
kw in theme
|
||||
or kw in summary
|
||||
or kw in original_text
|
||||
or any(kw in k for k in record_keywords_list)
|
||||
):
|
||||
matched = True
|
||||
break
|
||||
|
||||
|
||||
if matched:
|
||||
filtered_records.append(record)
|
||||
|
||||
|
||||
if not filtered_records:
|
||||
keywords_str = "、".join(keywords_list)
|
||||
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
|
||||
|
||||
|
||||
records = filtered_records
|
||||
|
||||
# 对即将返回的记录增加使用计数
|
||||
@@ -123,22 +117,23 @@ async def query_chat_history(
|
||||
record.count = (record.count or 0) + 1
|
||||
except Exception as update_error:
|
||||
logger.error(f"更新聊天记录概述计数失败: {update_error}")
|
||||
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
for record in records_to_use: # 最多返回3条记录
|
||||
result_parts = []
|
||||
|
||||
|
||||
# 添加主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
|
||||
|
||||
# 添加时间范围
|
||||
from datetime import datetime
|
||||
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
||||
|
||||
# 添加概括(优先使用summary,如果没有则使用original_text的前200字符)
|
||||
if record.summary:
|
||||
result_parts.append(f"概括:{record.summary}")
|
||||
@@ -147,18 +142,18 @@ async def query_chat_history(
|
||||
if len(record.original_text) > 200:
|
||||
text_preview += "..."
|
||||
result_parts.append(f"内容:{text_preview}")
|
||||
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
|
||||
if not results:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
if len(records) > len(records_to_use):
|
||||
omitted_count = len(records) - len(records_to_use)
|
||||
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
|
||||
return response_text
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询聊天历史概述失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
@@ -174,14 +169,14 @@ def register_tool():
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘' 或 '麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)",
|
||||
"required": False
|
||||
"required": False,
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"type": "string",
|
||||
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
|
||||
"required": False
|
||||
}
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_chat_history
|
||||
execute_func=query_chat_history,
|
||||
)
|
||||
|
||||
@@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_jargon(
|
||||
keyword: str,
|
||||
chat_id: str
|
||||
) -> str:
|
||||
async def query_jargon(keyword: str, chat_id: str) -> str:
|
||||
"""根据关键词在jargon库中查询
|
||||
|
||||
|
||||
Args:
|
||||
keyword: 关键词(黑话/俚语/缩写)
|
||||
chat_id: 聊天ID
|
||||
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
@@ -26,29 +23,17 @@ async def query_jargon(
|
||||
content = str(keyword).strip()
|
||||
if not content:
|
||||
return "关键词为空"
|
||||
|
||||
|
||||
# 先尝试精确匹配
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
chat_id=chat_id,
|
||||
limit=10,
|
||||
case_sensitive=False,
|
||||
fuzzy=False
|
||||
)
|
||||
|
||||
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not results:
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
chat_id=chat_id,
|
||||
limit=10,
|
||||
case_sensitive=False,
|
||||
fuzzy=True
|
||||
)
|
||||
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
|
||||
if results:
|
||||
# 如果是模糊匹配,显示找到的实际jargon内容
|
||||
if is_fuzzy_match:
|
||||
@@ -71,11 +56,11 @@ async def query_jargon(
|
||||
output = ";".join(output_parts) if len(output_parts) > 1 else output_parts[0]
|
||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,精确匹配): {content},找到{len(results)}条结果")
|
||||
return output
|
||||
|
||||
|
||||
# 未命中
|
||||
logger.info(f"在jargon库中未找到匹配(当前会话或全局,精确匹配和模糊搜索都未找到): {content}")
|
||||
return f"未在jargon库中找到'{content}'的解释"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询jargon失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
@@ -86,14 +71,6 @@ def register_tool():
|
||||
register_memory_retrieval_tool(
|
||||
name="query_jargon",
|
||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(黑话/俚语/缩写)",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
execute_func=query_jargon
|
||||
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
|
||||
execute_func=query_jargon,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,17 +11,13 @@ logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
class MemoryRetrievalTool:
|
||||
"""记忆检索工具基类"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]]
|
||||
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||
):
|
||||
"""
|
||||
初始化工具
|
||||
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
@@ -32,7 +28,7 @@ class MemoryRetrievalTool:
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
|
||||
def get_tool_description(self) -> str:
|
||||
"""获取工具的文本描述,用于prompt"""
|
||||
param_descriptions = []
|
||||
@@ -43,10 +39,10 @@ class MemoryRetrievalTool:
|
||||
required = param.get("required", True)
|
||||
required_str = "必填" if required else "可选"
|
||||
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
|
||||
|
||||
|
||||
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||
|
||||
|
||||
async def execute(self, **kwargs) -> str:
|
||||
"""执行工具"""
|
||||
return await self.execute_func(**kwargs)
|
||||
@@ -54,30 +50,30 @@ class MemoryRetrievalTool:
|
||||
|
||||
class MemoryRetrievalToolRegistry:
|
||||
"""工具注册器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||
|
||||
|
||||
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||
"""注册工具"""
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册记忆检索工具: {tool.name}")
|
||||
|
||||
|
||||
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
|
||||
"""获取工具"""
|
||||
return self.tools.get(name)
|
||||
|
||||
|
||||
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
|
||||
"""获取所有工具"""
|
||||
return self.tools.copy()
|
||||
|
||||
|
||||
def get_tools_description(self) -> str:
|
||||
"""获取所有工具的描述,用于prompt"""
|
||||
descriptions = []
|
||||
for i, tool in enumerate(self.tools.values(), 1):
|
||||
descriptions.append(f"{i}. {tool.get_tool_description()}")
|
||||
return "\n".join(descriptions)
|
||||
|
||||
|
||||
def get_action_types_list(self) -> str:
|
||||
"""获取所有动作类型的列表,用于prompt"""
|
||||
action_types = [tool.name for tool in self.tools.values()]
|
||||
@@ -91,13 +87,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
|
||||
|
||||
|
||||
def register_memory_retrieval_tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]]
|
||||
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
|
||||
) -> None:
|
||||
"""注册记忆检索工具的便捷函数
|
||||
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
@@ -111,4 +104,3 @@ def register_memory_retrieval_tool(
|
||||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||
"""获取工具注册器实例"""
|
||||
return _tool_registry
|
||||
|
||||
|
||||
@@ -40,25 +40,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
|
||||
def parse_time_range(time_range: str) -> Tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
|
||||
|
||||
Args:
|
||||
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: (开始时间戳, 结束时间戳)
|
||||
"""
|
||||
if " - " not in time_range:
|
||||
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
|
||||
|
||||
|
||||
parts = time_range.split(" - ", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"时间范围格式错误: {time_range}")
|
||||
|
||||
|
||||
start_str = parts[0].strip()
|
||||
end_str = parts[1].strip()
|
||||
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str)
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
Reference in New Issue
Block a user