feat:显示实时占用上下文,移除旧记忆系统

This commit is contained in:
SengokuCola
2026-04-01 13:18:17 +08:00
parent 503a257d66
commit d713aa9576
11 changed files with 55 additions and 2663 deletions

View File

@@ -66,6 +66,9 @@ class LLMResponseResult(BaseDataModel):
reasoning: str = field(default_factory=str) reasoning: str = field(default_factory=str)
model_name: str = field(default_factory=str) model_name: str = field(default_factory=str)
tool_calls: List[ToolCall] | None = None tool_calls: List[ToolCall] | None = None
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
@dataclass(slots=True) @dataclass(slots=True)
@@ -120,6 +123,9 @@ class LLMServiceResult(BaseDataModel):
"response": self.completion.response, "response": self.completion.response,
"reasoning": self.completion.reasoning, "reasoning": self.completion.reasoning,
"model_name": self.completion.model_name, "model_name": self.completion.model_name,
"prompt_tokens": self.completion.prompt_tokens,
"completion_tokens": self.completion.completion_tokens,
"total_tokens": self.completion.total_tokens,
} }
if self.completion.tool_calls is not None: if self.completion.tool_calls is not None:
payload["tool_calls"] = [ payload["tool_calls"] = [

View File

@@ -34,6 +34,7 @@ from src.llm_models.model_client.base_client import (
ClientRequest, ClientRequest,
EmbeddingRequest, EmbeddingRequest,
ResponseRequest, ResponseRequest,
UsageRecord,
client_registry, client_registry,
) )
from src.llm_models.payload_content.message import Message, MessageBuilder from src.llm_models.payload_content.message import Message, MessageBuilder
@@ -137,6 +138,7 @@ class LLMOrchestrator:
reasoning_content: str, reasoning_content: str,
model_name: str, model_name: str,
tool_calls: List[ToolCall] | None, tool_calls: List[ToolCall] | None,
usage: UsageRecord | None = None,
) -> LLMResponseResult: ) -> LLMResponseResult:
"""构建统一的文本响应结果。 """构建统一的文本响应结果。
@@ -154,6 +156,9 @@ class LLMOrchestrator:
reasoning=reasoning_content, reasoning=reasoning_content,
model_name=model_name, model_name=model_name,
tool_calls=tool_calls, tool_calls=tool_calls,
prompt_tokens=usage.prompt_tokens if usage is not None else 0,
completion_tokens=usage.completion_tokens if usage is not None else 0,
total_tokens=usage.total_tokens if usage is not None else 0,
) )
async def generate_response_for_image( async def generate_response_for_image(
@@ -215,7 +220,13 @@ class LLMOrchestrator:
endpoint="/chat/completions", endpoint="/chat/completions",
time_cost=time_cost, time_cost=time_cost,
) )
return self._build_generation_result(content, reasoning_content, model_info.name, tool_calls) return self._build_generation_result(
content,
reasoning_content,
model_info.name,
tool_calls,
response.usage,
)
async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult: async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult:
"""为语音生成转录响应。 """为语音生成转录响应。
@@ -298,7 +309,13 @@ class LLMOrchestrator:
endpoint="/chat/completions", endpoint="/chat/completions",
time_cost=time.time() - start_time, time_cost=time.time() - start_time,
) )
return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) return self._build_generation_result(
content or "",
reasoning_content,
model_info.name,
tool_calls,
response.usage,
)
async def generate_response_with_message_async( async def generate_response_with_message_async(
self, self,
@@ -364,7 +381,13 @@ class LLMOrchestrator:
endpoint="/chat/completions", endpoint="/chat/completions",
time_cost=time_cost, time_cost=time_cost,
) )
return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls) return self._build_generation_result(
content or "",
reasoning_content,
model_info.name,
tool_calls,
response.usage,
)
async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult: async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult:
"""获取嵌入向量。 """获取嵌入向量。

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Sequence
import asyncio import asyncio
import json import json
import random import random
import re
from PIL import Image as PILImage from PIL import Image as PILImage
from pydantic import BaseModel, Field as PydanticField from pydantic import BaseModel, Field as PydanticField
@@ -28,7 +27,7 @@ from src.config.config import global_config
from src.core.tooling import ToolRegistry, ToolSpec from src.core.tooling import ToolRegistry, ToolSpec
from src.know_u.knowledge import extract_category_ids_from_result from src.know_u.knowledge import extract_category_ids_from_result
from src.llm_models.model_client.base_client import BaseClient from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.services.llm_service import LLMServiceClient from src.services.llm_service import LLMServiceClient
@@ -697,58 +696,29 @@ class MaisakaChatLoopService:
) )
@staticmethod @staticmethod
def _estimate_text_tokens(text: str) -> int: def _format_token_count(token_count: int) -> str:
"""估算单段文本的输入 token 数。""" """格式化 token 数量展示文本"""
normalized_text = text.strip() if token_count >= 10_000:
if not normalized_text: return f"{token_count / 1000:.1f}k"
return 0 return str(token_count)
cjk_char_count = sum(1 for char in normalized_text if "\u4e00" <= char <= "\u9fff")
latin_chunks = re.findall(r"[A-Za-z0-9_]+", normalized_text)
latin_token_count = sum(max(1, (len(chunk) + 3) // 4) for chunk in latin_chunks)
punctuation_count = len(re.findall(r"[^\w\s]", normalized_text))
whitespace_bonus = max(1, normalized_text.count("\n"))
return cjk_char_count + latin_token_count + punctuation_count + whitespace_bonus
@classmethod @classmethod
def _estimate_request_tokens(cls, messages: Sequence[Message]) -> int:
"""估算本轮请求消息的总输入 token 数。"""
total_tokens = 0
for message in messages:
total_tokens += 4
total_tokens += cls._estimate_text_tokens(str(message.role.value))
if message.tool_call_id:
total_tokens += cls._estimate_text_tokens(message.tool_call_id)
if message.tool_calls:
for tool_call in message.tool_calls:
total_tokens += cls._estimate_text_tokens(getattr(tool_call, "func_name", "") or "")
total_tokens += cls._estimate_text_tokens(
json.dumps(getattr(tool_call, "args", {}) or {}, ensure_ascii=False)
)
for part in message.parts:
if isinstance(part, TextMessagePart):
total_tokens += cls._estimate_text_tokens(part.text)
continue
if isinstance(part, ImageMessagePart):
total_tokens += max(256, len(part.image_base64) // 12)
return total_tokens
@staticmethod
def _build_prompt_stats_text( def _build_prompt_stats_text(
cls,
*, *,
selected_history_count: int, selected_history_count: int,
built_message_count: int, built_message_count: int,
input_token_count: int, prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
) -> str: ) -> str:
"""构造本轮 prompt 的统计信息文本。""" """构造本轮 prompt 的统计信息文本。"""
if input_token_count >= 10_000:
input_token_text = f"{input_token_count / 1000:.1f}k"
else:
input_token_text = str(input_token_count)
return ( return (
f"已选上下文消息数={selected_history_count} " f"已选上下文消息数={selected_history_count} "
f"大模型消息数={built_message_count} " f"大模型消息数={built_message_count} "
f"估算输入Token={input_token_text}" f"实际输入Token={cls._format_token_count(prompt_tokens)} "
f"输出Token={cls._format_token_count(completion_tokens)} "
f"总Token={cls._format_token_count(total_tokens)}"
) )
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse: async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
@@ -764,13 +734,6 @@ class MaisakaChatLoopService:
await self.ensure_chat_prompt_loaded() await self.ensure_chat_prompt_loaded()
selected_history, selection_reason = self._select_llm_context_messages(chat_history) selected_history, selection_reason = self._select_llm_context_messages(chat_history)
built_messages = self._build_request_messages(selected_history) built_messages = self._build_request_messages(selected_history)
input_token_count = self._estimate_request_tokens(built_messages)
prompt_stats_text = self._build_prompt_stats_text(
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
input_token_count=input_token_count,
)
display_subtitle = f"{selection_reason} | {prompt_stats_text}"
def message_factory(_client: BaseClient) -> List[Message]: def message_factory(_client: BaseClient) -> List[Message]:
"""返回当前轮次已经构建好的请求消息。 """返回当前轮次已经构建好的请求消息。
@@ -806,7 +769,7 @@ class MaisakaChatLoopService:
Panel( Panel(
Group(*ordered_panels), Group(*ordered_panels),
title="MaiSaka 大模型请求 - 对话单步", title="MaiSaka 大模型请求 - 对话单步",
subtitle=display_subtitle, subtitle=selection_reason,
border_style="cyan", border_style="cyan",
padding=(0, 1), padding=(0, 1),
) )
@@ -820,7 +783,6 @@ class MaisakaChatLoopService:
f"工具数={len(all_tools)} " f"工具数={len(all_tools)} "
f"启用打断={self._interrupt_flag is not None}" f"启用打断={self._interrupt_flag is not None}"
) )
logger.info(f"??Prompt??: {prompt_stats_text}")
generation_result = await self._llm_chat.generate_response_with_messages( generation_result = await self._llm_chat.generate_response_with_messages(
message_factory=message_factory, message_factory=message_factory,
options=LLMGenerationOptions( options=LLMGenerationOptions(
@@ -833,6 +795,15 @@ class MaisakaChatLoopService:
request_elapsed = perf_counter() - request_started_at request_elapsed = perf_counter() - request_started_at
logger.info(f"规划器请求完成,耗时={request_elapsed:.3f}") logger.info(f"规划器请求完成,耗时={request_elapsed:.3f}")
prompt_stats_text = self._build_prompt_stats_text(
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
prompt_tokens=generation_result.prompt_tokens,
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
)
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
tool_call_summaries = [ tool_call_summaries = [
{ {
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)), "调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
"""
记忆系统工具函数
包含模糊查找、相似度计算等工具函数
"""
import json
import re
from datetime import datetime
from typing import Tuple
from typing import List
from json_repair import repair_json
from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表
Args:
response: LLM返回的响应
Returns:
Tuple[List[str], List[str]]: (概念列表, 问题列表)
"""
try:
# 尝试提取JSON可能包含在```json代码块中
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
json_str = matches[0]
else:
# 尝试直接解析整个响应
json_str = response.strip()
# 修复可能的JSON错误
repaired_json = repair_json(json_str)
# 解析JSON
parsed = json.loads(repaired_json)
# 只支持新格式包含concepts和questions的对象
if not isinstance(parsed, dict):
logger.warning(f"解析的JSON不是对象格式: {parsed}")
return [], []
concepts_raw = parsed.get("concepts", [])
questions_raw = parsed.get("questions", [])
# 确保是列表
if not isinstance(concepts_raw, list):
concepts_raw = []
if not isinstance(questions_raw, list):
questions_raw = []
# 确保所有元素都是字符串
concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
return concepts, questions
except Exception as e:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], []
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳(秒)
支持示例:
- 2025-09-29
- 2025-09-29 00:00:00
- 2025/09/29 00:00
- 2025-09-29T00:00:00
"""
value = value.strip()
fmts = [
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d",
"%Y/%m/%d",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
]
last_err = None
for fmt in fmts:
try:
dt = datetime.strptime(value, fmt)
return dt.timestamp()
except Exception as e:
last_err = e
raise ValueError(f"无法解析时间: {value} ({last_err})")

View File

@@ -1,36 +0,0 @@
"""
记忆检索工具模块
提供统一的工具注册和管理系统
"""
from .tool_registry import (
MemoryRetrievalTool,
MemoryRetrievalToolRegistry,
register_memory_retrieval_tool,
get_tool_registry,
)
# 导入所有工具的注册函数
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_words import register_tool as register_query_words
from .return_information import register_tool as register_return_information
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_words()
register_return_information()
# LPMM知识库检索工具
if global_config.lpmm_knowledge.lpmm_mode == "agent":
register_lpmm_knowledge()
__all__ = [
"MemoryRetrievalTool",
"MemoryRetrievalToolRegistry",
"register_memory_retrieval_tool",
"get_tool_registry",
"init_all_tools",
]

View File

@@ -1,75 +0,0 @@
"""
通过LPMM知识库查询信息 - 工具实现
"""
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
"""在LPMM知识库中查询相关信息
Args:
query: 查询关键词
Returns:
str: 查询结果
"""
try:
content = str(query).strip()
if not content:
return "查询关键词为空"
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用")
return "LPMM知识库未启用"
qa_manager = get_qa_manager()
if qa_manager is None:
logger.debug("LPMM知识库未初始化跳过查询")
return "LPMM知识库未初始化"
knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
if knowledge_info:
return f"你从LPMM知识库中找到以下信息\n{knowledge_info}"
return f"在LPMM知识库中未找到与“{content}”相关的信息"
except Exception as e:
logger.error(f"LPMM知识库查询失败: {e}")
return f"LPMM知识库查询失败{str(e)}"
def register_tool():
"""注册LPMM知识库查询工具"""
register_memory_retrieval_tool(
name="lpmm_search_knowledge",
description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
parameters=[
{
"name": "query",
"type": "string",
"description": "需要查询的问题使用一句疑问句提问例如什么是AI",
"required": True,
},
{
"name": "limit",
"type": "integer",
"description": "希望返回的相关知识条数默认为5",
"required": False,
},
],
execute_func=query_lpmm_knowledge,
)

View File

@@ -1,78 +0,0 @@
"""
查询黑话/概念含义 - 工具实现
用于在记忆检索过程中主动查询未知词语或黑话的含义
"""
from src.common.logger import get_logger
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_words(chat_id: str, words: str) -> str:
"""查询词语或黑话的含义
Args:
chat_id: 聊天ID
words: 要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔)
Returns:
str: 查询结果,包含词语的含义解释
"""
try:
if not words or not words.strip():
return "未提供要查询的词语"
# 解析词语列表(支持逗号、空格等分隔符)
words_list = []
for separator in [",", "", " ", "\n", "\t"]:
if separator in words:
words_list = [w.strip() for w in words.split(separator) if w.strip()]
break
# 如果没有找到分隔符,整个字符串作为一个词语
if not words_list:
words_list = [words.strip()]
# 去重
unique_words = []
seen = set()
for word in words_list:
if word and word not in seen:
unique_words.append(word)
seen.add(word)
if not unique_words:
return "未提供有效的词语"
logger.info(f"查询词语含义: {unique_words}")
# 调用检索函数
result = await retrieve_concepts_with_jargon(unique_words, chat_id)
if result:
return result
else:
return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"
except Exception as e:
logger.error(f"查询词语含义失败: {e}")
return f"查询失败: {str(e)}"
def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_words",
description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
parameters=[
{
"name": "words",
"type": "string",
"description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS''YYDS,内卷,996'",
"required": True,
},
],
execute_func=query_words,
)

View File

@@ -1,42 +0,0 @@
"""
return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询
"""
from src.common.logger import get_logger
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def return_information(information: str) -> str:
"""返回总结信息并结束查询
Args:
information: 基于已收集信息总结出的相关信息,用于帮助回复。如果收集的信息对当前聊天没有帮助,可以返回空字符串。
Returns:
str: 确认信息
"""
if information and information.strip():
logger.info(f"返回总结信息: {information}")
return f"已确认返回信息: {information}"
else:
logger.info("未收集到相关信息,结束查询")
return "未收集到相关信息,查询结束"
def register_tool():
"""注册return_information工具"""
register_memory_retrieval_tool(
name="return_information",
description="当你决定结束查询时调用此工具。基于已收集的信息总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助在information参数中提供总结信息如果信息无关或没有帮助可以提供空字符串。",
parameters=[
{
"name": "information",
"type": "string",
"description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。",
"required": True,
},
],
execute_func=return_information,
)

View File

@@ -1,167 +0,0 @@
"""工具注册系统。
提供统一的工具注册和管理接口。
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option
logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
"""记忆检索工具基类"""
def __init__(
self,
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]],
) -> None:
"""初始化工具。
Args:
name: 工具名称。
description: 工具描述。
parameters: 参数定义列表。
execute_func: 执行函数,必须是异步函数。
"""
self.name = name
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_description(self) -> str:
"""获取工具的文本描述用于prompt"""
param_descriptions = []
for param in self.parameters:
param_name = param.get("name", "")
param_type = param.get("type", "string")
param_desc = param.get("description", "")
required = param.get("required", True)
required_str = "必填" if required else "可选"
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
async def execute(self, **kwargs: Any) -> str:
"""执行工具。"""
return await self.execute_func(**kwargs)
def get_tool_definition(self) -> Dict[str, Any]:
"""获取规范化的工具定义。
Returns:
Dict[str, Any]: 统一工具定义字典。
"""
legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
for param in self.parameters:
param_name = param.get("name", "")
param_type_str = param.get("type", "string").lower()
param_desc = param.get("description", "")
is_required = param.get("required", False)
enum_values = param.get("enum", None)
# 转换类型字符串到ToolParamType
type_mapping = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"int": ToolParamType.INTEGER,
"float": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
"bool": ToolParamType.BOOLEAN,
}
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values))
normalized_option = normalize_tool_option(
{
"name": self.name,
"description": self.description,
"parameters": legacy_parameters,
}
)
return {
"name": normalized_option.name,
"description": normalized_option.description,
"parameters_schema": normalized_option.parameters_schema,
}
class MemoryRetrievalToolRegistry:
"""工具注册器"""
def __init__(self) -> None:
"""初始化工具注册器。"""
self.tools: Dict[str, MemoryRetrievalTool] = {}
def register_tool(self, tool: MemoryRetrievalTool) -> None:
"""注册工具"""
if tool.name in self.tools:
logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
return
self.tools[tool.name] = tool
logger.info(f"注册记忆检索工具: {tool.name}")
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
"""获取工具"""
return self.tools.get(name)
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
"""获取所有工具"""
return self.tools.copy()
def get_tools_description(self) -> str:
"""获取所有工具的描述用于prompt"""
descriptions = []
for i, tool in enumerate(self.tools.values(), 1):
descriptions.append(f"{i}. {tool.get_tool_description()}")
return "\n".join(descriptions)
def get_action_types_list(self) -> str:
"""获取所有动作类型的列表用于prompt已废弃保留用于兼容"""
action_types = [tool.name for tool in self.tools.values()]
action_types.append("final_answer")
action_types.append("no_answer")
return "".join([f'"{at}"' for at in action_types])
def get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取所有工具的定义列表用于LLM function calling
Returns:
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
"""
return [tool.get_tool_definition() for tool in self.tools.values()]
# 全局工具注册器实例
_tool_registry = MemoryRetrievalToolRegistry()
def register_memory_retrieval_tool(
name: str,
description: str,
parameters: List[Dict[str, Any]],
execute_func: Callable[..., Awaitable[str]],
) -> None:
"""注册记忆检索工具的便捷函数。
Args:
name: 工具名称。
description: 工具描述。
parameters: 参数定义列表。
execute_func: 执行函数。
"""
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
_tool_registry.register_tool(tool)
def get_tool_registry() -> MemoryRetrievalToolRegistry:
"""获取工具注册器实例"""
return _tool_registry