This commit is contained in:
SengokuCola
2025-11-13 19:00:59 +08:00
46 changed files with 1000 additions and 1041 deletions

View File

@@ -9,16 +9,13 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_jargon(
keyword: str,
chat_id: str
) -> str:
async def query_jargon(keyword: str, chat_id: str) -> str:
"""根据关键词在jargon库中查询
Args:
keyword: 关键词(黑话/俚语/缩写)
chat_id: 聊天ID
Returns:
str: 查询结果
"""
@@ -26,29 +23,17 @@ async def query_jargon(
content = str(keyword).strip()
if not content:
return "关键词为空"
# 先尝试精确匹配
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=False
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
is_fuzzy_match = False
# 如果精确匹配未找到,尝试模糊搜索
if not results:
results = search_jargon(
keyword=content,
chat_id=chat_id,
limit=10,
case_sensitive=False,
fuzzy=True
)
results = search_jargon(keyword=content, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
is_fuzzy_match = True
if results:
# 如果是模糊匹配显示找到的实际jargon内容
if is_fuzzy_match:
@@ -71,11 +56,11 @@ async def query_jargon(
output = "".join(output_parts) if len(output_parts) > 1 else output_parts[0]
logger.info(f"在jargon库中找到匹配当前会话或全局精确匹配: {content},找到{len(results)}条结果")
return output
# 未命中
logger.info(f"在jargon库中未找到匹配当前会话或全局精确匹配和模糊搜索都未找到: {content}")
return f"未在jargon库中找到'{content}'的解释"
except Exception as e:
logger.error(f"查询jargon失败: {e}")
return f"查询失败: {str(e)}"
@@ -86,14 +71,6 @@ def register_tool():
register_memory_retrieval_tool(
name="query_jargon",
description="根据关键词在jargon库中查询黑话/俚语/缩写的含义。支持大小写不敏感搜索默认会先尝试精确匹配如果找不到则自动使用模糊搜索。仅搜索当前会话或全局jargon。",
parameters=[
{
"name": "keyword",
"type": "string",
"description": "关键词(黑话/俚语/缩写)",
"required": True
}
],
execute_func=query_jargon
parameters=[{"name": "keyword", "type": "string", "description": "关键词(黑话/俚语/缩写)", "required": True}],
execute_func=query_jargon,
)

View File

@@ -12,17 +12,13 @@ logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
self, name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
):
"""
初始化工具
Args:
name: 工具名称
description: 工具描述
@@ -33,7 +29,7 @@ class MemoryRetrievalTool:
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
@@ -44,10 +40,10 @@ class MemoryRetrievalTool:
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs) -> str:
"""执行工具"""
return await self.execute_func(**kwargs)
@@ -97,10 +93,10 @@ class MemoryRetrievalTool:
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self):
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
if tool.name in self.tools:
@@ -108,22 +104,22 @@ class MemoryRetrievalToolRegistry:
return
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt已废弃保留用于兼容"""
action_types = [tool.name for tool in self.tools.values()]
@@ -145,13 +141,10 @@ _tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]]
name: str, description: str, parameters: List[Dict[str, Any]], execute_func: Callable[..., Awaitable[str]]
) -> None:
"""注册记忆检索工具的便捷函数
Args:
name: 工具名称
description: 工具描述
@@ -165,4 +158,3 @@ def register_memory_retrieval_tool(
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry