This commit is contained in:
墨梓柒
2025-11-13 13:24:55 +08:00
parent e78a070fbd
commit 7839acd25d
52 changed files with 1322 additions and 1408 deletions

View File

@@ -14,20 +14,16 @@ from .tool_utils import parse_datetime_to_timestamp, parse_time_range
logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_range: Optional[str] = None
) -> str:
async def query_chat_history(chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔)
time_range: 时间范围或时间点,格式:
- 时间范围:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
- 时间点:"YYYY-MM-DD HH:MM:SS"(查询包含该时间点的记录)
Returns:
str: 查询结果
"""
@@ -35,10 +31,10 @@ async def query_chat_history(
# 检查参数
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
if time_range:
# 判断是时间点还是时间范围
@@ -46,73 +42,71 @@ async def query_chat_history(
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
query = query.where(time_filter)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
if not records:
return "未找到相关聊天记录概述"
# 如果有关键词,进一步过滤
if keyword:
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 检查是否包含任意一个关键词OR关系
matched = False
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True
break
if matched:
filtered_records.append(record)
if not filtered_records:
keywords_str = "".join(keywords_list)
return f"未找到包含关键词'{keywords_str}'的聊天记录概述"
records = filtered_records
# 对即将返回的记录增加使用计数
@@ -123,22 +117,23 @@ async def query_chat_history(
record.count = (record.count or 0) + 1
except Exception as update_error:
logger.error(f"更新聊天记录概述计数失败: {update_error}")
# 构建结果文本
results = []
for record in records_to_use: # 最多返回3条记录
result_parts = []
# 添加主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
# 添加概括优先使用summary如果没有则使用original_text的前200字符
if record.summary:
result_parts.append(f"概括:{record.summary}")
@@ -147,18 +142,18 @@ async def query_chat_history(
if len(record.original_text) > 200:
text_preview += "..."
result_parts.append(f"内容:{text_preview}")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录概述"
response_text = "\n\n---\n\n".join(results)
if len(records) > len(records_to_use):
omitted_count = len(records) - len(records_to_use)
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
return response_text
except Exception as e:
logger.error(f"查询聊天历史概述失败: {e}")
return f"查询失败: {str(e)}"
@@ -174,14 +169,14 @@ def register_tool():
"name": "keyword",
"type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索,只要包含任意一个关键词即匹配)",
"required": False
"required": False,
},
{
"name": "time_range",
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
}
"required": False,
},
],
execute_func=query_chat_history
execute_func=query_chat_history,
)

View File

@@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str
) -> str:
async def query_jargon(keyword: str, chat_id: str) -> str:
"""根据关键词在jargon库中查询
Args:
keyword: 关键词(黑话/俚语/缩写)
chat_id: 聊天ID
Returns:
str: 查询结果
"""
@@ -26,29 +23,17 @@ async def query_jargon(
content = str(keyword).strip()
if not content:
return "关键词为空"
# 先尝试精确匹配
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索
if not results:
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
is_fuzzy_match = True
if results:
# 如果是模糊匹配显示找到的实际jargon内容
if is_fuzzy_match:
@@ -71,11 +56,11 @@ async def query_jargon(
output = "".join(output_parts) if len(output_parts) > 1 else output_parts[0]
logger.info(f"在jargon库中找到匹配当前会话或全局精确匹配: {content},找到{len(results)}条结果")
return output
# 未命中
logger.info(f"在jargon库中未找到匹配当前会话或全局精确匹配和模糊搜索都未找到: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
logger.error(f"查询jargon失败: {e}")
return f"查询失败: {str(e)}"
@@ -86,14 +71,6 @@ def register_tool():
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写)",
"required": True
}
],
execute_func=query_jargon
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
execute_func=query_jargon,
)

View File

@@ -11,17 +11,13 @@ logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
Args:
name: 工具名称
description: 工具描述
@@ -32,7 +28,7 @@ class MemoryRetrievalTool:
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
@@ -43,10 +39,10 @@ class MemoryRetrievalTool:
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
@@ -54,30 +50,30 @@ class MemoryRetrievalTool:
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt"""
action_types = [tool.name for tool in self.tools.values()]
@@ -91,13 +87,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
@@ -111,4 +104,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry

View File

@@ -40,25 +40,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
def parse_time_range(time_range: str) -> Tuple[float, float]:
"""
解析时间范围字符串,返回开始和结束时间戳
Args:
time_range: 时间范围字符串,格式:"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"
Returns:
Tuple[float, float]: (开始时间戳, 结束时间戳)
"""
if " - " not in time_range:
raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")
parts = time_range.split(" - ", 1)
if len(parts) != 2:
raise ValueError(f"时间范围格式错误: {time_range}")
start_str = parts[0].strip()
end_str = parts[1].strip()
start_timestamp = parse_datetime_to_timestamp(start_str)
end_timestamp = parse_datetime_to_timestamp(end_str)
return start_timestamp, end_timestamp
return start_timestamp, end_timestamp