feat:添加ReAct记忆提取系统
This commit is contained in:
155
src/memory_system/retrieval_tools/README.md
Normal file
155
src/memory_system/retrieval_tools/README.md
Normal file
@@ -0,0 +1,155 @@
|
||||
# 记忆检索工具模块
|
||||
|
||||
这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
retrieval_tools/
|
||||
├── __init__.py # 模块导出
|
||||
├── tool_registry.py # 工具注册系统
|
||||
├── tool_utils.py # 工具函数库(共用函数)
|
||||
├── query_jargon.py # 查询jargon工具
|
||||
├── query_chat_history.py # 查询聊天历史工具
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 模块说明
|
||||
|
||||
### `tool_registry.py`
|
||||
包含工具注册系统的核心类:
|
||||
- `MemoryRetrievalTool`: 工具基类
|
||||
- `MemoryRetrievalToolRegistry`: 工具注册器
|
||||
- `register_memory_retrieval_tool()`: 便捷注册函数
|
||||
- `get_tool_registry()`: 获取注册器实例
|
||||
|
||||
### `tool_utils.py`
|
||||
包含所有工具共用的工具函数:
|
||||
- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳
|
||||
- `parse_time_range()`: 解析时间范围字符串
|
||||
|
||||
### 工具文件
|
||||
每个工具都有独立的文件:
|
||||
- `query_jargon.py`: 根据关键词在jargon库中查询
|
||||
- `query_chat_history.py`: 根据时间或关键词在chat_history中查询(支持查询时间点事件、时间范围事件、关键词搜索)
|
||||
|
||||
## 如何添加新工具
|
||||
|
||||
1. 创建新的工具文件,例如 `query_new_tool.py`:
|
||||
|
||||
```python
|
||||
"""
|
||||
新工具 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from .tool_utils import parse_datetime_to_timestamp # 如果需要使用工具函数
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_new_tool(param1: str, param2: str, chat_id: str) -> str:
|
||||
"""新工具的实现
|
||||
|
||||
Args:
|
||||
param1: 参数1
|
||||
param2: 参数2
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 实现逻辑
|
||||
return "结果"
|
||||
except Exception as e:
|
||||
logger.error(f"新工具执行失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_new_tool",
|
||||
description="新工具的描述",
|
||||
parameters=[
|
||||
{
|
||||
"name": "param1",
|
||||
"type": "string",
|
||||
"description": "参数1的描述",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "param2",
|
||||
"type": "string",
|
||||
"description": "参数2的描述",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
execute_func=query_new_tool
|
||||
)
|
||||
```
|
||||
|
||||
2. 在 `__init__.py` 中导入并注册新工具:
|
||||
|
||||
```python
|
||||
from .query_new_tool import register_tool as register_query_new_tool
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
register_query_chat_history()
|
||||
register_query_new_tool() # 添加新工具
|
||||
```
|
||||
|
||||
3. 工具会自动:
|
||||
- 出现在 ReAct Agent 的 prompt 中
|
||||
- 在动作类型列表中可用
|
||||
- 被 ReAct Agent 自动调用
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry
|
||||
|
||||
# 初始化所有工具
|
||||
init_all_tools()
|
||||
|
||||
# 获取工具注册器
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 获取特定工具
|
||||
tool = registry.get_tool("query_chat_history")
|
||||
|
||||
# 执行工具(查询时间点事件)
|
||||
result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123")
|
||||
|
||||
# 或者查询关键词
|
||||
result = await tool.execute(keyword="小丑AI", chat_id="chat123")
|
||||
|
||||
# 或者查询时间范围
|
||||
result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123")
|
||||
```
|
||||
|
||||
## 现有工具说明
|
||||
|
||||
### query_jargon
|
||||
根据关键词在jargon库中查询黑话/俚语/缩写的含义
|
||||
- 参数:`keyword` (必填) - 关键词
|
||||
|
||||
### query_chat_history
|
||||
根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息
|
||||
- 参数:
|
||||
- `keyword` (可选) - 关键词,用于搜索消息内容
|
||||
- `time_point` (可选) - 时间点,格式:YYYY-MM-DD HH:MM:SS,用于查询某个时间点附近发生了什么(与time_range二选一)
|
||||
- `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(与time_point二选一)
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 所有工具函数必须是异步函数(`async def`)
|
||||
- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测)
|
||||
- 工具参数定义中的 `required` 字段用于生成 prompt 描述
|
||||
- 工具执行失败时应返回错误信息字符串,而不是抛出异常
|
||||
- 共用函数放在 `tool_utils.py` 中,避免代码重复
|
||||
|
||||
30
src/memory_system/retrieval_tools/__init__.py
Normal file
30
src/memory_system/retrieval_tools/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""
|
||||
记忆检索工具模块
|
||||
提供统一的工具注册和管理系统
|
||||
"""
|
||||
|
||||
from .tool_registry import (
|
||||
MemoryRetrievalTool,
|
||||
MemoryRetrievalToolRegistry,
|
||||
register_memory_retrieval_tool,
|
||||
get_tool_registry,
|
||||
)
|
||||
|
||||
# 导入所有工具的注册函数
|
||||
from .query_jargon import register_tool as register_query_jargon
|
||||
from .query_chat_history import register_tool as register_query_chat_history
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
register_query_chat_history()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryRetrievalTool",
|
||||
"MemoryRetrievalToolRegistry",
|
||||
"register_memory_retrieval_tool",
|
||||
"get_tool_registry",
|
||||
"init_all_tools",
|
||||
]
|
||||
221
src/memory_system/retrieval_tools/query_chat_history.py
Normal file
221
src/memory_system/retrieval_tools/query_chat_history.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
根据时间或关键词在chat_history中查询 - 工具实现
|
||||
从ChatHistory表的聊天记录概述库中查询
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.common.database.database_model import ChatHistory
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from .tool_utils import parse_datetime_to_timestamp, parse_time_range
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_chat_history(
|
||||
chat_id: str,
|
||||
keyword: Optional[str] = None,
|
||||
time_point: Optional[str] = None,
|
||||
time_range: Optional[str] = None
|
||||
) -> str:
|
||||
"""根据时间或关键词在chat_history表中查询聊天记录概述
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
keyword: 关键词(可选)
|
||||
time_point: 时间点,格式:YYYY-MM-DD HH:MM:SS(可选)
|
||||
time_range: 时间范围,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"(可选)
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 检查参数
|
||||
if not keyword and not time_point and not time_range:
|
||||
return "未指定查询参数(需要提供keyword、time_point或time_range之一)"
|
||||
|
||||
# 构建查询条件
|
||||
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
|
||||
|
||||
# 时间过滤条件
|
||||
time_conditions = []
|
||||
if time_point:
|
||||
# 时间点:查询包含该时间点的记录(start_time <= time_point <= end_time)
|
||||
target_timestamp = parse_datetime_to_timestamp(time_point)
|
||||
time_conditions.append(
|
||||
(ChatHistory.start_time <= target_timestamp) &
|
||||
(ChatHistory.end_time >= target_timestamp)
|
||||
)
|
||||
elif time_range:
|
||||
# 时间范围:查询与时间范围有交集的记录
|
||||
start_timestamp, end_timestamp = parse_time_range(time_range)
|
||||
# 交集条件:start_time < end_timestamp AND end_time > start_timestamp
|
||||
time_conditions.append(
|
||||
(ChatHistory.start_time < end_timestamp) &
|
||||
(ChatHistory.end_time > start_timestamp)
|
||||
)
|
||||
|
||||
if time_conditions:
|
||||
# 合并所有时间条件(OR关系)
|
||||
time_filter = time_conditions[0]
|
||||
for condition in time_conditions[1:]:
|
||||
time_filter = time_filter | condition
|
||||
query = query.where(time_filter)
|
||||
|
||||
# 执行查询
|
||||
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
|
||||
|
||||
if not records:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
# 如果有关键词,进一步过滤
|
||||
if keyword:
|
||||
keyword_lower = keyword.lower()
|
||||
filtered_records = []
|
||||
|
||||
for record in records:
|
||||
# 在theme、keywords、summary、original_text中搜索
|
||||
theme = (record.theme or "").lower()
|
||||
summary = (record.summary or "").lower()
|
||||
original_text = (record.original_text or "").lower()
|
||||
|
||||
# 解析keywords JSON
|
||||
keywords_list = []
|
||||
if record.keywords:
|
||||
try:
|
||||
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||
if isinstance(keywords_data, list):
|
||||
keywords_list = [str(k).lower() for k in keywords_data]
|
||||
except (json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
|
||||
# 检查是否包含关键词
|
||||
if (keyword_lower in theme or
|
||||
keyword_lower in summary or
|
||||
keyword_lower in original_text or
|
||||
any(keyword_lower in k for k in keywords_list)):
|
||||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
return f"未找到包含关键词'{keyword}'的聊天记录概述"
|
||||
|
||||
records = filtered_records
|
||||
|
||||
# 构建结果文本
|
||||
results = []
|
||||
for record in records[:10]: # 最多返回10条记录
|
||||
result_parts = []
|
||||
|
||||
# 添加主题
|
||||
if record.theme:
|
||||
result_parts.append(f"主题:{record.theme}")
|
||||
|
||||
# 添加时间范围
|
||||
from datetime import datetime
|
||||
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
result_parts.append(f"时间:{start_str} - {end_str}")
|
||||
|
||||
# 添加概括(优先使用summary,如果没有则使用original_text的前200字符)
|
||||
if record.summary:
|
||||
result_parts.append(f"概括:{record.summary}")
|
||||
elif record.original_text:
|
||||
text_preview = record.original_text[:200]
|
||||
if len(record.original_text) > 200:
|
||||
text_preview += "..."
|
||||
result_parts.append(f"内容:{text_preview}")
|
||||
|
||||
results.append("\n".join(result_parts))
|
||||
|
||||
if not results:
|
||||
return "未找到相关聊天记录概述"
|
||||
|
||||
# 如果只有一条记录,直接返回
|
||||
if len(results) == 1:
|
||||
return results[0]
|
||||
|
||||
# 多条记录,使用LLM总结
|
||||
try:
|
||||
llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small,
|
||||
request_type="chat_history_analysis"
|
||||
)
|
||||
|
||||
query_desc = []
|
||||
if keyword:
|
||||
query_desc.append(f"关键词:{keyword}")
|
||||
if time_point:
|
||||
query_desc.append(f"时间点:{time_point}")
|
||||
if time_range:
|
||||
query_desc.append(f"时间范围:{time_range}")
|
||||
|
||||
query_info = ",".join(query_desc) if query_desc else "聊天记录概述"
|
||||
|
||||
combined_results = "\n\n---\n\n".join(results)
|
||||
|
||||
analysis_prompt = f"""请根据以下聊天记录概述,总结与查询条件相关的信息。请输出一段平文本,不要有特殊格式。
|
||||
查询条件:{query_info}
|
||||
|
||||
聊天记录概述:
|
||||
{combined_results}
|
||||
|
||||
请仔细分析聊天记录概述,提取与查询条件相关的信息并给出总结。如果概述中没有相关信息,输出"无有效信息"即可,不要输出其他内容。
|
||||
|
||||
总结:"""
|
||||
|
||||
response, (reasoning, model_name, tool_calls) = await llm_request.generate_response_async(
|
||||
prompt=analysis_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=512
|
||||
)
|
||||
|
||||
logger.info(f"查询聊天历史概述提示词: {analysis_prompt}")
|
||||
logger.info(f"查询聊天历史概述响应: {response}")
|
||||
logger.info(f"查询聊天历史概述推理: {reasoning}")
|
||||
logger.info(f"查询聊天历史概述模型: {model_name}")
|
||||
|
||||
if "无有效信息" in response:
|
||||
return "无有效信息"
|
||||
|
||||
return response
|
||||
|
||||
except Exception as llm_error:
|
||||
logger.error(f"LLM分析聊天记录概述失败: {llm_error}")
|
||||
# 如果LLM分析失败,返回前3条记录的摘要
|
||||
return "\n\n---\n\n".join(results[:3])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询聊天历史概述失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_chat_history",
|
||||
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(可选,用于在主题、关键词、概括、原文中搜索)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "time_point",
|
||||
"type": "string",
|
||||
"description": "时间点,格式:YYYY-MM-DD HH:MM:SS(可选,与time_range二选一)。用于查询包含该时间点的聊天记录概述",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "time_range",
|
||||
"type": "string",
|
||||
"description": "时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(可选,与time_point二选一)。用于查询与时间范围有交集的聊天记录概述",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
execute_func=query_chat_history
|
||||
)
|
||||
92
src/memory_system/retrieval_tools/query_jargon.py
Normal file
92
src/memory_system/retrieval_tools/query_jargon.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
根据关键词在jargon库中查询 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.jargon.jargon_miner import search_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_jargon(
|
||||
keyword: str,
|
||||
chat_id: str,
|
||||
fuzzy: bool = False,
|
||||
search_all: bool = False
|
||||
) -> str:
|
||||
"""根据关键词在jargon库中查询
|
||||
|
||||
Args:
|
||||
keyword: 关键词(黑话/俚语/缩写)
|
||||
chat_id: 聊天ID
|
||||
fuzzy: 是否使用模糊搜索,默认False(精确匹配)
|
||||
search_all: 是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global)
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
content = str(keyword).strip()
|
||||
if not content:
|
||||
return "关键词为空"
|
||||
|
||||
# 根据参数执行搜索
|
||||
search_chat_id = None if search_all else chat_id
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
chat_id=search_chat_id,
|
||||
limit=1,
|
||||
case_sensitive=False,
|
||||
fuzzy=fuzzy
|
||||
)
|
||||
|
||||
if results:
|
||||
result = results[0]
|
||||
translation = result.get("translation", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
||||
search_scope = "全库" if search_all else "当前会话或全局"
|
||||
output = f"“{content}可能为黑话或者网络简写,翻译为:{translation},含义为:{meaning}”"
|
||||
logger.info(f"在jargon库中找到匹配({search_scope},{search_type}): {content}")
|
||||
return output
|
||||
|
||||
# 未命中
|
||||
search_type = "模糊搜索" if fuzzy else "精确匹配"
|
||||
search_scope = "全库" if search_all else "当前会话或全局"
|
||||
logger.info(f"在jargon库中未找到匹配({search_scope},{search_type}): {content}")
|
||||
return f"未在jargon库中找到'{content}'的解释"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询jargon失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_jargon",
|
||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索和模糊搜索。默认优先搜索当前会话或全局jargon,可以设置为搜索全库。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "keyword",
|
||||
"type": "string",
|
||||
"description": "关键词(黑话/俚语/缩写),支持模糊搜索",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "fuzzy",
|
||||
"type": "boolean",
|
||||
"description": "是否使用模糊搜索(部分匹配),默认False(精确匹配)。当精确匹配找不到时,可以尝试使用模糊搜索。",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "search_all",
|
||||
"type": "boolean",
|
||||
"description": "是否搜索全库(不限chat_id),默认False(仅搜索当前会话或global的jargon)。当在当前会话中找不到时,可以尝试搜索全库。",
|
||||
"required": False
|
||||
}
|
||||
],
|
||||
execute_func=query_jargon
|
||||
)
|
||||
|
||||
114
src/memory_system/retrieval_tools/tool_registry.py
Normal file
114
src/memory_system/retrieval_tools/tool_registry.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
工具注册系统
|
||||
提供统一的工具注册和管理接口
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Callable, Awaitable
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
class MemoryRetrievalTool:
|
||||
"""记忆检索工具基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]]
|
||||
):
|
||||
"""
|
||||
初始化工具
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
parameters: 参数定义列表,格式:[{"name": "param_name", "type": "string", "description": "参数描述", "required": True}]
|
||||
execute_func: 执行函数,必须是异步函数
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
def get_tool_description(self) -> str:
|
||||
"""获取工具的文本描述,用于prompt"""
|
||||
param_descriptions = []
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type = param.get("type", "string")
|
||||
param_desc = param.get("description", "")
|
||||
required = param.get("required", True)
|
||||
required_str = "必填" if required else "可选"
|
||||
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
|
||||
|
||||
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||
|
||||
async def execute(self, **kwargs) -> str:
|
||||
"""执行工具"""
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
|
||||
class MemoryRetrievalToolRegistry:
|
||||
"""工具注册器"""
|
||||
|
||||
def __init__(self):
|
||||
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||
|
||||
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||
"""注册工具"""
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册记忆检索工具: {tool.name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
|
||||
"""获取工具"""
|
||||
return self.tools.get(name)
|
||||
|
||||
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
|
||||
"""获取所有工具"""
|
||||
return self.tools.copy()
|
||||
|
||||
def get_tools_description(self) -> str:
|
||||
"""获取所有工具的描述,用于prompt"""
|
||||
descriptions = []
|
||||
for i, tool in enumerate(self.tools.values(), 1):
|
||||
descriptions.append(f"{i}. {tool.get_tool_description()}")
|
||||
return "\n".join(descriptions)
|
||||
|
||||
def get_action_types_list(self) -> str:
|
||||
"""获取所有动作类型的列表,用于prompt"""
|
||||
action_types = [tool.name for tool in self.tools.values()]
|
||||
action_types.append("final_answer")
|
||||
action_types.append("no_answer")
|
||||
return " 或 ".join([f'"{at}"' for at in action_types])
|
||||
|
||||
|
||||
# 全局工具注册器实例
|
||||
_tool_registry = MemoryRetrievalToolRegistry()
|
||||
|
||||
|
||||
def register_memory_retrieval_tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]]
|
||||
) -> None:
|
||||
"""注册记忆检索工具的便捷函数
|
||||
|
||||
Args:
|
||||
name: 工具名称
|
||||
description: 工具描述
|
||||
parameters: 参数定义列表
|
||||
execute_func: 执行函数
|
||||
"""
|
||||
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
|
||||
_tool_registry.register_tool(tool)
|
||||
|
||||
|
||||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||
"""获取工具注册器实例"""
|
||||
return _tool_registry
|
||||
|
||||
64
src/memory_system/retrieval_tools/tool_utils.py
Normal file
64
src/memory_system/retrieval_tools/tool_utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
工具函数库
|
||||
包含所有工具共用的工具函数
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
|
||||
|
||||
def parse_time_range(time_range: str) -> Tuple[float, float]:
|
||||
"""
|
||||
解析时间范围字符串,返回开始和结束时间戳
|
||||
|
||||
Args:
|
||||
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
|
||||
|
||||
Returns:
|
||||
Tuple[float, float]: (开始时间戳, 结束时间戳)
|
||||
"""
|
||||
if " - " not in time_range:
|
||||
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
|
||||
|
||||
parts = time_range.split(" - ", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"时间范围格式错误: {time_range}")
|
||||
|
||||
start_str = parts[0].strip()
|
||||
end_str = parts[1].strip()
|
||||
|
||||
start_timestamp = parse_datetime_to_timestamp(start_str)
|
||||
end_timestamp = parse_datetime_to_timestamp(end_str)
|
||||
|
||||
return start_timestamp, end_timestamp
|
||||
|
||||
Reference in New Issue
Block a user