feat:增加记忆提取能力
This commit is contained in:
@@ -1,161 +0,0 @@
|
||||
# 记忆检索工具模块
|
||||
|
||||
这个模块提供了统一的工具注册和管理系统,用于记忆检索功能。
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
retrieval_tools/
|
||||
├── __init__.py # 模块导出
|
||||
├── tool_registry.py # 工具注册系统
|
||||
├── tool_utils.py # 工具函数库(共用函数)
|
||||
├── query_jargon.py # 查询jargon工具
|
||||
├── query_chat_history.py # 查询聊天历史工具
|
||||
├── query_lpmm_knowledge.py # 查询LPMM知识库工具
|
||||
└── README.md # 本文件
|
||||
```
|
||||
|
||||
## 模块说明
|
||||
|
||||
### `tool_registry.py`
|
||||
包含工具注册系统的核心类:
|
||||
- `MemoryRetrievalTool`: 工具基类
|
||||
- `MemoryRetrievalToolRegistry`: 工具注册器
|
||||
- `register_memory_retrieval_tool()`: 便捷注册函数
|
||||
- `get_tool_registry()`: 获取注册器实例
|
||||
|
||||
### `tool_utils.py`
|
||||
包含所有工具共用的工具函数:
|
||||
- `parse_datetime_to_timestamp()`: 解析时间字符串为时间戳
|
||||
- `parse_time_range()`: 解析时间范围字符串
|
||||
|
||||
### 工具文件
|
||||
每个工具都有独立的文件:
|
||||
- `query_jargon.py`: 根据关键词在jargon库中查询
|
||||
- `query_chat_history.py`: 根据时间或关键词在chat_history中查询(支持查询时间点事件、时间范围事件、关键词搜索)
|
||||
|
||||
## 如何添加新工具
|
||||
|
||||
1. 创建新的工具文件,例如 `query_new_tool.py`:
|
||||
|
||||
```python
|
||||
"""
|
||||
新工具 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
from .tool_utils import parse_datetime_to_timestamp # 如果需要使用工具函数
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_new_tool(param1: str, param2: str, chat_id: str) -> str:
|
||||
"""新工具的实现
|
||||
|
||||
Args:
|
||||
param1: 参数1
|
||||
param2: 参数2
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
# 实现逻辑
|
||||
return "结果"
|
||||
except Exception as e:
|
||||
logger.error(f"新工具执行失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_new_tool",
|
||||
description="新工具的描述",
|
||||
parameters=[
|
||||
{
|
||||
"name": "param1",
|
||||
"type": "string",
|
||||
"description": "参数1的描述",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "param2",
|
||||
"type": "string",
|
||||
"description": "参数2的描述",
|
||||
"required": True
|
||||
}
|
||||
],
|
||||
execute_func=query_new_tool
|
||||
)
|
||||
```
|
||||
|
||||
2. 在 `__init__.py` 中导入并注册新工具:
|
||||
|
||||
```python
|
||||
from .query_new_tool import register_tool as register_query_new_tool
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
register_query_chat_history()
|
||||
register_query_new_tool() # 添加新工具
|
||||
```
|
||||
|
||||
3. 工具会自动:
|
||||
- 出现在 ReAct Agent 的 prompt 中
|
||||
- 在动作类型列表中可用
|
||||
- 被 ReAct Agent 自动调用
|
||||
|
||||
## 使用示例
|
||||
|
||||
```python
|
||||
from src.memory_system.retrieval_tools import init_all_tools, get_tool_registry
|
||||
|
||||
# 初始化所有工具
|
||||
init_all_tools()
|
||||
|
||||
# 获取工具注册器
|
||||
registry = get_tool_registry()
|
||||
|
||||
# 获取特定工具
|
||||
tool = registry.get_tool("query_chat_history")
|
||||
|
||||
# 执行工具(查询时间点事件)
|
||||
result = await tool.execute(time_point="2025-01-15 14:30:00", chat_id="chat123")
|
||||
|
||||
# 或者查询关键词
|
||||
result = await tool.execute(keyword="小丑AI", chat_id="chat123")
|
||||
|
||||
# 或者查询时间范围
|
||||
result = await tool.execute(time_range="2025-01-15 10:00:00 - 2025-01-15 20:00:00", chat_id="chat123")
|
||||
```
|
||||
|
||||
## 现有工具说明
|
||||
|
||||
### query_jargon
|
||||
根据关键词在jargon库中查询黑话/俚语/缩写的含义
|
||||
- 参数:`keyword` (必填) - 关键词
|
||||
|
||||
### query_chat_history
|
||||
根据时间或关键词在chat_history中查询相关聊天记录。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息
|
||||
- 参数:
|
||||
- `keyword` (可选) - 关键词,用于搜索消息内容
|
||||
- `time_point` (可选) - 时间点,格式:YYYY-MM-DD HH:MM:SS,用于查询某个时间点附近发生了什么(与time_range二选一)
|
||||
- `time_range` (可选) - 时间范围,格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(与time_point二选一)
|
||||
|
||||
### query_lpmm_knowledge
|
||||
从LPMM知识库中检索与关键词相关的知识内容
|
||||
- 参数:
|
||||
- `query` (必填) - 查询的关键词或问题描述
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 所有工具函数必须是异步函数(`async def`)
|
||||
- 如果工具函数签名需要 `chat_id` 参数,系统会自动添加(通过函数签名检测)
|
||||
- 工具参数定义中的 `required` 字段用于生成 prompt 描述
|
||||
- 工具执行失败时应返回错误信息字符串,而不是抛出异常
|
||||
- 共用函数放在 `tool_utils.py` 中,避免代码重复
|
||||
|
||||
@@ -11,7 +11,6 @@ from .tool_registry import (
|
||||
)
|
||||
|
||||
# 导入所有工具的注册函数
|
||||
from .query_jargon import register_tool as register_query_jargon
|
||||
from .query_chat_history import register_tool as register_query_chat_history
|
||||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_person_info import register_tool as register_query_person_info
|
||||
@@ -20,7 +19,6 @@ from src.config.config import global_config
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_jargon()
|
||||
register_query_chat_history()
|
||||
register_query_person_info()
|
||||
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
根据关键词在jargon库中查询 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.jargon.jargon_miner import search_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_jargon(keyword: str, chat_id: str) -> str:
|
||||
"""根据关键词在jargon库中查询
|
||||
|
||||
Args:
|
||||
keyword: 关键词(黑话/俚语/缩写)
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
content = str(keyword).strip()
|
||||
if not content:
|
||||
return "关键词为空"
|
||||
|
||||
# 先尝试精确匹配
|
||||
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not results:
|
||||
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
if results:
|
||||
# 如果是模糊匹配,显示找到的实际jargon内容
|
||||
if is_fuzzy_match:
|
||||
# 处理多个结果
|
||||
output_parts = [f"未精确匹配到'{content}'"]
|
||||
for result in results:
|
||||
found_content = result.get("content", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if found_content and meaning:
|
||||
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
|
||||
output = ",".join(output_parts)
|
||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,模糊搜索): {content},找到{len(results)}条结果")
|
||||
else:
|
||||
# 精确匹配,可能有多条(相同content但不同chat_id的情况)
|
||||
output_parts = []
|
||||
for result in results:
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if meaning:
|
||||
output_parts.append(f"'{content}' 为黑话或者网络简写,含义为:{meaning}")
|
||||
output = ";".join(output_parts) if len(output_parts) > 1 else output_parts[0]
|
||||
logger.info(f"在jargon库中找到匹配(当前会话或全局,精确匹配): {content},找到{len(results)}条结果")
|
||||
return output
|
||||
|
||||
# 未命中
|
||||
logger.info(f"在jargon库中未找到匹配(当前会话或全局,精确匹配和模糊搜索都未找到): {content}")
|
||||
return f"未在jargon库中找到'{content}'的解释"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询jargon失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_jargon",
|
||||
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索,默认会先尝试精确匹配,如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
|
||||
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
|
||||
execute_func=query_jargon,
|
||||
)
|
||||
@@ -12,6 +12,25 @@ from .tool_registry import register_memory_retrieval_tool
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
def _calculate_similarity(query: str, target: str) -> float:
|
||||
"""计算查询词在目标字符串中的相似度比例
|
||||
|
||||
Args:
|
||||
query: 查询词
|
||||
target: 目标字符串
|
||||
|
||||
Returns:
|
||||
float: 相似度比例(0.0-1.0),查询词长度 / 目标字符串长度
|
||||
"""
|
||||
if not query or not target:
|
||||
return 0.0
|
||||
query_len = len(query)
|
||||
target_len = len(target)
|
||||
if target_len == 0:
|
||||
return 0.0
|
||||
return query_len / target_len
|
||||
|
||||
|
||||
def _format_group_nick_names(group_nick_name_field) -> str:
|
||||
"""格式化群昵称信息
|
||||
|
||||
@@ -81,11 +100,29 @@ async def query_person_info(person_name: str) -> str:
|
||||
if not records:
|
||||
return f"未找到模糊匹配'{person_name}'的用户信息"
|
||||
|
||||
# 根据相似度过滤结果:查询词在目标字符串中至少占50%
|
||||
SIMILARITY_THRESHOLD = 0.5
|
||||
filtered_records = []
|
||||
for record in records:
|
||||
if not record.person_name:
|
||||
continue
|
||||
# 精确匹配总是保留(相似度100%)
|
||||
if record.person_name.strip() == person_name:
|
||||
filtered_records.append(record)
|
||||
else:
|
||||
# 模糊匹配需要检查相似度
|
||||
similarity = _calculate_similarity(person_name, record.person_name.strip())
|
||||
if similarity >= SIMILARITY_THRESHOLD:
|
||||
filtered_records.append(record)
|
||||
|
||||
if not filtered_records:
|
||||
return f"未找到相似度≥50%的匹配'{person_name}'的用户信息"
|
||||
|
||||
# 区分精确匹配和模糊匹配的结果
|
||||
exact_matches = []
|
||||
fuzzy_matches = []
|
||||
|
||||
for record in records:
|
||||
for record in filtered_records:
|
||||
# 检查是否是精确匹配
|
||||
if record.person_name and record.person_name.strip() == person_name:
|
||||
exact_matches.append(record)
|
||||
@@ -248,7 +285,7 @@ async def query_person_info(person_name: str) -> str:
|
||||
response_text = "\n\n---\n\n".join(results)
|
||||
|
||||
# 添加统计信息
|
||||
total_count = len(records)
|
||||
total_count = len(filtered_records)
|
||||
exact_count = len(exact_matches)
|
||||
fuzzy_count = len(fuzzy_matches)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user