feat:添加ReAct记忆提取系统

This commit is contained in:
SengokuCola
2025-11-09 14:02:29 +08:00
parent d761d42dd7
commit 7a3f260cc3
12 changed files with 1463 additions and 45 deletions

View File

@@ -0,0 +1,155 @@
# 记忆检索工具模块
这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。
## 目录结构
```
retrieval_tools/
├── __init__.py # 模块导出
├── tool_registry.py # 工具注册系统
├── tool_utils.py # 工具函数库(共用函数)
├── query_jargon.py # 查询jargon工具
├── query_chat_history.py # 查询聊天历史工具
└── README.md # 本文件
```
## 模块说明
### `tool_registry.py`
包含工具注册系统的核心类:
- `MemoryRetrievalTool`: 工具基类
- `MemoryRetrievalToolRegistry`: 工具注册器
- `register_memory_retrieval_tool()`: 便捷注册函数
- `get_tool_registry()`: 获取注册器实例
### `tool_utils.py`
包含所有工具共用的工具函数:
- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳
- `parse_time_range()`: 解析时间范围字符串
### 工具文件
每个工具都有独立的文件:
- `query_jargon.py`: 根据关键词在jargon库中查询
- `query_chat_history.py`: 根据时间或关键词在chat_history中查询支持查询时间点事件、时间范围事件、关键词搜索
## 如何添加新工具
1. 创建新的工具文件,例如 `query_new_tool.py`
```python
"""
新工具 - 工具实现
"""
from src.common.logger import get_logger
from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp # 如果需要使用工具函数
logger = get_logger("memory_retrieval_tools")
async def query_new_tool(param1: str, param2: str, chat_id: str) -> str:
"""新工具的实现
Args:
param1: 参数1
param2: 参数2
chat_id: 聊天ID
Returns:
str: 查询结果
"""
try:
# 实现逻辑
return "结果"
except Exception as e:
logger.error(f"新工具执行失败: {e}")
return f"查询失败: {str(e)}"
def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_new_tool",
description="新工具的描述",
parameters=[
{
"name": "param1",
"type": "string",
"description": "参数1的描述",
"required": True
},
{
"name": "param2",
"type": "string",
"description": "参数2的描述",
"required": True
}
],
execute_func=query_new_tool
)
```
2.`__init__.py` 中导入并注册新工具:
```python
from .query_new_tool import register_tool as register_query_new_tool
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_jargon()
register_query_chat_history()
register_query_new_tool() # 添加新工具
```
3. 工具会自动:
- 出现在 ReAct Agent 的 prompt 中
- 在动作类型列表中可用
- 被 ReAct Agent 自动调用
## 使用示例
```python
from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry
# 初始化所有工具
init_all_tools()
# 获取工具注册器
registry = get_tool_registry()
# 获取特定工具
tool = registry.get_tool("query_chat_history")
# 执行工具(查询时间点事件)
result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123")
# 或者查询关键词
result = await tool.execute(keyword="小丑AI", chat_id="chat123")
# 或者查询时间范围
result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123")
```
## 现有工具说明
### query_jargon
根据关键词在jargon库中查询黑话/俚语/缩写的含义
- 参数:`keyword` (必填) - 关键词
### query_chat_history
根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件或根据关键词搜索消息
- 参数:
- `keyword` (可选) - 关键词,用于搜索消息内容
- `time_point` (可选) - 时间点格式YYYY-MM-DD HH:MM:SS用于查询某个时间点附近发生了什么与time_range二选一
- `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'与time_point二选一
## 注意事项
- 所有工具函数必须是异步函数(`async def`
- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测)
- 工具参数定义中的 `required` 字段用于生成 prompt 描述
- 工具执行失败时应返回错误信息字符串,而不是抛出异常
- 共用函数放在 `tool_utils.py` 中,避免代码重复

View File

@@ -0,0 +1,30 @@
"""
记忆检索工具模块
提供统一的工具注册和管理系统
"""
from .tool_registry import (
MemoryRetrievalTool,
MemoryRetrievalToolRegistry,
register_memory_retrieval_tool,
get_tool_registry,
)
# 导入所有工具的注册函数
from .query_jargon import register_tool as register_query_jargon
from .query_chat_history import register_tool as register_query_chat_history
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_jargon()
register_query_chat_history()
__all__ = [
"MemoryRetrievalTool",
"MemoryRetrievalToolRegistry",
"register_memory_retrieval_tool",
"get_tool_registry",
"init_all_tools",
]

View File

@@ -0,0 +1,221 @@
"""
根据时间或关键词在chat_history中查询 - 工具实现
从ChatHistory表的聊天记录概述库中查询
"""
import json
from typing import Optional
from src.common.logger import get_logger
from src.config.config import model_config
from src.common.database.database_model import ChatHistory
from src.llm_models.utils_model import LLMRequest
from .tool_registry import register_memory_retrieval_tool
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_point: Optional[str] = None,
time_range: Optional[str] = None
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词(可选)
time_point: 时间点格式YYYY-MM-DD HH:MM:SS可选
time_range: 时间范围,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选)
Returns:
str: 查询结果
"""
try:
# 检查参数
if not keyword and not time_point and not time_range:
return "未指定查询参数需要提供keyword、time_point或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
time_conditions = []
if time_point:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_point)
time_conditions.append(
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
elif time_range:
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_conditions.append(
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
if time_conditions:
# 合并所有时间条件OR关系
time_filter = time_conditions[0]
for condition in time_conditions[1:]:
time_filter = time_filter | condition
query = query.where(time_filter)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
if not records:
return "未找到相关聊天记录概述"
# 如果有关键词,进一步过滤
if keyword:
keyword_lower = keyword.lower()
filtered_records = []
for record in records:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析keywords JSON
keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
if isinstance(keywords_data, list):
keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 检查是否包含关键词
if (keyword_lower in theme or
keyword_lower in summary or
keyword_lower in original_text or
any(keyword_lower in k for k in keywords_list)):
filtered_records.append(record)
if not filtered_records:
return f"未找到包含关键词'{keyword}'的聊天记录概述"
records = filtered_records
# 构建结果文本
results = []
for record in records[:10]: # 最多返回10条记录
result_parts = []
# 添加主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
# 添加概括优先使用summary如果没有则使用original_text的前200字符
if record.summary:
result_parts.append(f"概括:{record.summary}")
elif record.original_text:
text_preview = record.original_text[:200]
if len(record.original_text) > 200:
text_preview += "..."
result_parts.append(f"内容:{text_preview}")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录概述"
# 如果只有一条记录,直接返回
if len(results) == 1:
return results[0]
# 多条记录使用LLM总结
try:
llm_request = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="chat_history_analysis"
)
query_desc = []
if keyword:
query_desc.append(f"关键词:{keyword}")
if time_point:
query_desc.append(f"时间点:{time_point}")
if time_range:
query_desc.append(f"时间范围:{time_range}")
query_info = "".join(query_desc) if query_desc else "聊天记录概述"
combined_results = "\n\n---\n\n".join(results)
analysis_prompt = f"""请根据以下聊天记录概述,总结与查询条件相关的信息。请输出一段平文本,不要有特殊格式。
查询条件:{query_info}
聊天记录概述:
{combined_results}
请仔细分析聊天记录概述,提取与查询条件相关的信息并给出总结。如果概述中没有相关信息,输出"无有效信息"即可,不要输出其他内容。
总结:"""
response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
prompt=analysis_prompt,
temperature=0.3,
max_tokens=512
)
logger.info(f"查询聊天历史概述提示词: {analysis_prompt}")
logger.info(f"查询聊天历史概述响应: {response}")
logger.info(f"查询聊天历史概述推理: {reasoning}")
logger.info(f"查询聊天历史概述模型: {model_name}")
if "无有效信息" in response:
return "无有效信息"
return response
except Exception as llm_error:
logger.error(f"LLM分析聊天记录概述失败: {llm_error}")
# 如果LLM分析失败返回前3条记录的摘要
return "\n\n---\n\n".join(results[:3])
except Exception as e:
logger.error(f"查询聊天历史概述失败: {e}")
return f"查询失败: {str(e)}"
def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_chat_history",
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件或根据关键词搜索消息概述",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)",
"required": False
},
{
"name": "time_point",
"type": "string",
"description": "时间点格式YYYY-MM-DD HH:MM:SS可选与time_range二选一。用于查询包含该时间点的聊天记录概述",
"required": False
},
{
"name": "time_range",
"type": "string",
"description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'可选与time_point二选一。用于查询与时间范围有交集的聊天记录概述",
"required": False
}
],
execute_func=query_chat_history
)

View File

@@ -0,0 +1,92 @@
"""
根据关键词在jargon库中查询 - 工具实现
"""
from src.common.logger import get_logger
from src.jargon.jargon_miner import search_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str,
fuzzy: bool = False,
search_all: bool = False
) -> str:
"""根据关键词在jargon库中查询
Args:
keyword: 关键词(黑话/俚语/缩写)
chat_id: 聊天ID
fuzzy: 是否使用模糊搜索默认False精确匹配
search_all: 是否搜索全库不限chat_id默认False仅搜索当前会话或global
Returns:
str: 查询结果
"""
try:
content = str(keyword).strip()
if not content:
return "关键词为空"
# 根据参数执行搜索
search_chat_id = None if search_all else chat_id
results = search_jargon(
keyword=content,
chat_id=search_chat_id,
limit=1,
case_sensitive=False,
fuzzy=fuzzy
)
if results:
result = results[0]
translation = result.get("translation", "").strip()
meaning = result.get("meaning", "").strip()
search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局"
output = f"{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}"
logger.info(f"在jargon库中找到匹配{search_scope}{search_type}: {content}")
return output
# 未命中
search_type = "模糊搜索" if fuzzy else "精确匹配"
search_scope = "全库" if search_all else "当前会话或全局"
logger.info(f"在jargon库中未找到匹配{search_scope}{search_type}: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
logger.error(f"查询jargon失败: {e}")
return f"查询失败: {str(e)}"
def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon可以设置为搜索全库。",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写),支持模糊搜索",
"required": True
},
{
"name": "fuzzy",
"type": "boolean",
"description": "是否使用模糊搜索部分匹配默认False精确匹配。当精确匹配找不到时可以尝试使用模糊搜索。",
"required": False
},
{
"name": "search_all",
"type": "boolean",
"description": "是否搜索全库不限chat_id默认False仅搜索当前会话或global的jargon。当在当前会话中找不到时可以尝试搜索全库。",
"required": False
}
],
execute_func=query_jargon
)

View File

@@ -0,0 +1,114 @@
"""
工具注册系统
提供统一的工具注册和管理接口
"""
from typing import List, Dict, Any, Optional, Callable, Awaitable
from src.common.logger import get_logger
logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
Args:
name: 工具名称
description: 工具描述
parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}]
execute_func: 执行函数,必须是异步函数
"""
self.name = name
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
for param in self.parameters:
param_name = param.get("name", "")
param_type = param.get("type", "string")
param_desc = param.get("description", "")
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt"""
action_types = [tool.name for tool in self.tools.values()]
action_types.append("final_answer")
action_types.append("no_answer")
return "".join([f'"{at}"' for at in action_types])
# 全局工具注册器实例
_tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
) -> None:
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
parameters: 参数定义列表
execute_func: 执行函数
"""
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
_tool_registry.register_tool(tool)
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry

View File

@@ -0,0 +1,64 @@
"""
工具函数库
包含所有工具共用的工具函数
"""
from datetime import datetime
from typing import Tuple
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)
支持示例:
- 2025-09-29
- 2025-09-29 00:00:00
- 2025/09/29 00:00
- 2025-09-29T00:00:00
"""
value = value.strip()
fmts = [
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d",
"%Y/%m/%d",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
]
last_err = None
for fmt in fmts:
try:
dt = datetime.strptime(value, fmt)
return dt.timestamp()
except Exception as e:
last_err = e
raise ValueError(f"无法解析时间: {value} ({last_err})")
def parse_time_range(time_range: str) -> Tuple[float, float]:
"""
解析时间范围字符串,返回开始和结束时间戳
Args:
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
Returns:
Tuple[float, float]: (开始时间戳, 结束时间戳)
"""
if " - " not in time_range:
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
parts = time_range.split(" - ", 1)
if len(parts) != 2:
raise ValueError(f"时间范围格式错误: {time_range}")
start_str = parts[0].strip()
end_str = parts[1].strip()
start_timestamp = parse_datetime_to_timestamp(start_str)
end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp