feat:显示实时占用上下文,移除旧记忆系统
This commit is contained in:
@@ -66,6 +66,9 @@ class LLMResponseResult(BaseDataModel):
|
||||
reasoning: str = field(default_factory=str)
|
||||
model_name: str = field(default_factory=str)
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -120,6 +123,9 @@ class LLMServiceResult(BaseDataModel):
|
||||
"response": self.completion.response,
|
||||
"reasoning": self.completion.reasoning,
|
||||
"model_name": self.completion.model_name,
|
||||
"prompt_tokens": self.completion.prompt_tokens,
|
||||
"completion_tokens": self.completion.completion_tokens,
|
||||
"total_tokens": self.completion.total_tokens,
|
||||
}
|
||||
if self.completion.tool_calls is not None:
|
||||
payload["tool_calls"] = [
|
||||
|
||||
@@ -34,6 +34,7 @@ from src.llm_models.model_client.base_client import (
|
||||
ClientRequest,
|
||||
EmbeddingRequest,
|
||||
ResponseRequest,
|
||||
UsageRecord,
|
||||
client_registry,
|
||||
)
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder
|
||||
@@ -137,6 +138,7 @@ class LLMOrchestrator:
|
||||
reasoning_content: str,
|
||||
model_name: str,
|
||||
tool_calls: List[ToolCall] | None,
|
||||
usage: UsageRecord | None = None,
|
||||
) -> LLMResponseResult:
|
||||
"""构建统一的文本响应结果。
|
||||
|
||||
@@ -154,6 +156,9 @@ class LLMOrchestrator:
|
||||
reasoning=reasoning_content,
|
||||
model_name=model_name,
|
||||
tool_calls=tool_calls,
|
||||
prompt_tokens=usage.prompt_tokens if usage is not None else 0,
|
||||
completion_tokens=usage.completion_tokens if usage is not None else 0,
|
||||
total_tokens=usage.total_tokens if usage is not None else 0,
|
||||
)
|
||||
|
||||
async def generate_response_for_image(
|
||||
@@ -215,7 +220,13 @@ class LLMOrchestrator:
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time_cost,
|
||||
)
|
||||
return self._build_generation_result(content, reasoning_content, model_info.name, tool_calls)
|
||||
return self._build_generation_result(
|
||||
content,
|
||||
reasoning_content,
|
||||
model_info.name,
|
||||
tool_calls,
|
||||
response.usage,
|
||||
)
|
||||
|
||||
async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult:
|
||||
"""为语音生成转录响应。
|
||||
@@ -298,7 +309,13 @@ class LLMOrchestrator:
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time.time() - start_time,
|
||||
)
|
||||
return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls)
|
||||
return self._build_generation_result(
|
||||
content or "",
|
||||
reasoning_content,
|
||||
model_info.name,
|
||||
tool_calls,
|
||||
response.usage,
|
||||
)
|
||||
|
||||
async def generate_response_with_message_async(
|
||||
self,
|
||||
@@ -364,7 +381,13 @@ class LLMOrchestrator:
|
||||
endpoint="/chat/completions",
|
||||
time_cost=time_cost,
|
||||
)
|
||||
return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls)
|
||||
return self._build_generation_result(
|
||||
content or "",
|
||||
reasoning_content,
|
||||
model_info.name,
|
||||
tool_calls,
|
||||
response.usage,
|
||||
)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult:
|
||||
"""获取嵌入向量。
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Sequence
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
@@ -28,7 +27,7 @@ from src.config.config import global_config
|
||||
from src.core.tooling import ToolRegistry, ToolSpec
|
||||
from src.know_u.knowledge import extract_category_ids_from_result
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
@@ -697,58 +696,29 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _estimate_text_tokens(text: str) -> int:
|
||||
"""估算单段文本的输入 token 数。"""
|
||||
normalized_text = text.strip()
|
||||
if not normalized_text:
|
||||
return 0
|
||||
|
||||
cjk_char_count = sum(1 for char in normalized_text if "\u4e00" <= char <= "\u9fff")
|
||||
latin_chunks = re.findall(r"[A-Za-z0-9_]+", normalized_text)
|
||||
latin_token_count = sum(max(1, (len(chunk) + 3) // 4) for chunk in latin_chunks)
|
||||
punctuation_count = len(re.findall(r"[^\w\s]", normalized_text))
|
||||
whitespace_bonus = max(1, normalized_text.count("\n"))
|
||||
return cjk_char_count + latin_token_count + punctuation_count + whitespace_bonus
|
||||
def _format_token_count(token_count: int) -> str:
|
||||
"""格式化 token 数量展示文本。"""
|
||||
if token_count >= 10_000:
|
||||
return f"{token_count / 1000:.1f}k"
|
||||
return str(token_count)
|
||||
|
||||
@classmethod
|
||||
def _estimate_request_tokens(cls, messages: Sequence[Message]) -> int:
|
||||
"""估算本轮请求消息的总输入 token 数。"""
|
||||
total_tokens = 0
|
||||
for message in messages:
|
||||
total_tokens += 4
|
||||
total_tokens += cls._estimate_text_tokens(str(message.role.value))
|
||||
if message.tool_call_id:
|
||||
total_tokens += cls._estimate_text_tokens(message.tool_call_id)
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
total_tokens += cls._estimate_text_tokens(getattr(tool_call, "func_name", "") or "")
|
||||
total_tokens += cls._estimate_text_tokens(
|
||||
json.dumps(getattr(tool_call, "args", {}) or {}, ensure_ascii=False)
|
||||
)
|
||||
for part in message.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
total_tokens += cls._estimate_text_tokens(part.text)
|
||||
continue
|
||||
if isinstance(part, ImageMessagePart):
|
||||
total_tokens += max(256, len(part.image_base64) // 12)
|
||||
return total_tokens
|
||||
|
||||
@staticmethod
|
||||
def _build_prompt_stats_text(
|
||||
cls,
|
||||
*,
|
||||
selected_history_count: int,
|
||||
built_message_count: int,
|
||||
input_token_count: int,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
) -> str:
|
||||
"""构造本轮 prompt 的统计信息文本。"""
|
||||
if input_token_count >= 10_000:
|
||||
input_token_text = f"{input_token_count / 1000:.1f}k"
|
||||
else:
|
||||
input_token_text = str(input_token_count)
|
||||
return (
|
||||
f"已选上下文消息数={selected_history_count} "
|
||||
f"大模型消息数={built_message_count} "
|
||||
f"估算输入Token={input_token_text}"
|
||||
f"实际输入Token={cls._format_token_count(prompt_tokens)} "
|
||||
f"输出Token={cls._format_token_count(completion_tokens)} "
|
||||
f"总Token={cls._format_token_count(total_tokens)}"
|
||||
)
|
||||
|
||||
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
|
||||
@@ -764,13 +734,6 @@ class MaisakaChatLoopService:
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
selected_history, selection_reason = self._select_llm_context_messages(chat_history)
|
||||
built_messages = self._build_request_messages(selected_history)
|
||||
input_token_count = self._estimate_request_tokens(built_messages)
|
||||
prompt_stats_text = self._build_prompt_stats_text(
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
input_token_count=input_token_count,
|
||||
)
|
||||
display_subtitle = f"{selection_reason} | {prompt_stats_text}"
|
||||
|
||||
def message_factory(_client: BaseClient) -> List[Message]:
|
||||
"""返回当前轮次已经构建好的请求消息。
|
||||
@@ -806,7 +769,7 @@ class MaisakaChatLoopService:
|
||||
Panel(
|
||||
Group(*ordered_panels),
|
||||
title="MaiSaka 大模型请求 - 对话单步",
|
||||
subtitle=display_subtitle,
|
||||
subtitle=selection_reason,
|
||||
border_style="cyan",
|
||||
padding=(0, 1),
|
||||
)
|
||||
@@ -820,7 +783,6 @@ class MaisakaChatLoopService:
|
||||
f"工具数={len(all_tools)} "
|
||||
f"启用打断={self._interrupt_flag is not None}"
|
||||
)
|
||||
logger.info(f"??Prompt??: {prompt_stats_text}")
|
||||
generation_result = await self._llm_chat.generate_response_with_messages(
|
||||
message_factory=message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
@@ -833,6 +795,15 @@ class MaisakaChatLoopService:
|
||||
request_elapsed = perf_counter() - request_started_at
|
||||
logger.info(f"规划器请求完成,耗时={request_elapsed:.3f} 秒")
|
||||
|
||||
prompt_stats_text = self._build_prompt_stats_text(
|
||||
selected_history_count=len(selected_history),
|
||||
built_message_count=len(built_messages),
|
||||
prompt_tokens=generation_result.prompt_tokens,
|
||||
completion_tokens=generation_result.completion_tokens,
|
||||
total_tokens=generation_result.total_tokens,
|
||||
)
|
||||
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
|
||||
|
||||
tool_call_summaries = [
|
||||
{
|
||||
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,98 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
记忆系统工具函数
|
||||
包含模糊查找、相似度计算等工具函数
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from typing import List
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("memory_utils")
|
||||
|
||||
|
||||
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
|
||||
"""解析问题JSON,返回概念列表和问题列表
|
||||
|
||||
Args:
|
||||
response: LLM返回的响应
|
||||
|
||||
Returns:
|
||||
Tuple[List[str], List[str]]: (概念列表, 问题列表)
|
||||
"""
|
||||
try:
|
||||
# 尝试提取JSON(可能包含在```json代码块中)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
json_str = matches[0]
|
||||
else:
|
||||
# 尝试直接解析整个响应
|
||||
json_str = response.strip()
|
||||
|
||||
# 修复可能的JSON错误
|
||||
repaired_json = repair_json(json_str)
|
||||
|
||||
# 解析JSON
|
||||
parsed = json.loads(repaired_json)
|
||||
|
||||
# 只支持新格式:包含concepts和questions的对象
|
||||
if not isinstance(parsed, dict):
|
||||
logger.warning(f"解析的JSON不是对象格式: {parsed}")
|
||||
return [], []
|
||||
|
||||
concepts_raw = parsed.get("concepts", [])
|
||||
questions_raw = parsed.get("questions", [])
|
||||
|
||||
# 确保是列表
|
||||
if not isinstance(concepts_raw, list):
|
||||
concepts_raw = []
|
||||
if not isinstance(questions_raw, list):
|
||||
questions_raw = []
|
||||
|
||||
# 确保所有元素都是字符串
|
||||
concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
|
||||
questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
|
||||
|
||||
return concepts, questions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
|
||||
return [], []
|
||||
|
||||
|
||||
def parse_datetime_to_timestamp(value: str) -> float:
|
||||
"""
|
||||
接受多种常见格式并转换为时间戳(秒)
|
||||
支持示例:
|
||||
- 2025-09-29
|
||||
- 2025-09-29 00:00:00
|
||||
- 2025/09/29 00:00
|
||||
- 2025-09-29T00:00:00
|
||||
"""
|
||||
value = value.strip()
|
||||
fmts = [
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
"%Y-%m-%d %H:%M",
|
||||
"%Y/%m/%d %H:%M:%S",
|
||||
"%Y/%m/%d %H:%M",
|
||||
"%Y-%m-%d",
|
||||
"%Y/%m/%d",
|
||||
"%Y-%m-%dT%H:%M:%S",
|
||||
"%Y-%m-%dT%H:%M",
|
||||
]
|
||||
last_err = None
|
||||
for fmt in fmts:
|
||||
try:
|
||||
dt = datetime.strptime(value, fmt)
|
||||
return dt.timestamp()
|
||||
except Exception as e:
|
||||
last_err = e
|
||||
raise ValueError(f"无法解析时间: {value} ({last_err})")
|
||||
@@ -1,36 +0,0 @@
|
||||
"""
|
||||
记忆检索工具模块
|
||||
提供统一的工具注册和管理系统
|
||||
"""
|
||||
|
||||
from .tool_registry import (
|
||||
MemoryRetrievalTool,
|
||||
MemoryRetrievalToolRegistry,
|
||||
register_memory_retrieval_tool,
|
||||
get_tool_registry,
|
||||
)
|
||||
|
||||
# 导入所有工具的注册函数
|
||||
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
|
||||
from .query_words import register_tool as register_query_words
|
||||
from .return_information import register_tool as register_return_information
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def init_all_tools():
|
||||
"""初始化并注册所有记忆检索工具"""
|
||||
register_query_words()
|
||||
register_return_information()
|
||||
|
||||
# LPMM知识库检索工具
|
||||
if global_config.lpmm_knowledge.lpmm_mode == "agent":
|
||||
register_lpmm_knowledge()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryRetrievalTool",
|
||||
"MemoryRetrievalToolRegistry",
|
||||
"register_memory_retrieval_tool",
|
||||
"get_tool_registry",
|
||||
"init_all_tools",
|
||||
]
|
||||
@@ -1,75 +0,0 @@
|
||||
"""
|
||||
通过LPMM知识库查询信息 - 工具实现
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.knowledge import get_qa_manager
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
|
||||
"""在LPMM知识库中查询相关信息
|
||||
|
||||
Args:
|
||||
query: 查询关键词
|
||||
|
||||
Returns:
|
||||
str: 查询结果
|
||||
"""
|
||||
try:
|
||||
content = str(query).strip()
|
||||
if not content:
|
||||
return "查询关键词为空"
|
||||
|
||||
try:
|
||||
limit_value = int(limit)
|
||||
except (TypeError, ValueError):
|
||||
limit_value = 5
|
||||
limit_value = max(1, limit_value)
|
||||
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
logger.debug("LPMM知识库未启用")
|
||||
return "LPMM知识库未启用"
|
||||
|
||||
qa_manager = get_qa_manager()
|
||||
if qa_manager is None:
|
||||
logger.debug("LPMM知识库未初始化,跳过查询")
|
||||
return "LPMM知识库未初始化"
|
||||
|
||||
knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
|
||||
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
|
||||
|
||||
if knowledge_info:
|
||||
return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}"
|
||||
|
||||
return f"在LPMM知识库中未找到与“{content}”相关的信息"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LPMM知识库查询失败: {e}")
|
||||
return f"LPMM知识库查询失败:{str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册LPMM知识库查询工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="lpmm_search_knowledge",
|
||||
description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
|
||||
parameters=[
|
||||
{
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "需要查询的问题,使用一句疑问句提问,例如:什么是AI?",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "limit",
|
||||
"type": "integer",
|
||||
"description": "希望返回的相关知识条数,默认为5",
|
||||
"required": False,
|
||||
},
|
||||
],
|
||||
execute_func=query_lpmm_knowledge,
|
||||
)
|
||||
@@ -1,78 +0,0 @@
|
||||
"""
|
||||
查询黑话/概念含义 - 工具实现
|
||||
用于在记忆检索过程中主动查询未知词语或黑话的含义
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def query_words(chat_id: str, words: str) -> str:
|
||||
"""查询词语或黑话的含义
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
words: 要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔)
|
||||
|
||||
Returns:
|
||||
str: 查询结果,包含词语的含义解释
|
||||
"""
|
||||
try:
|
||||
if not words or not words.strip():
|
||||
return "未提供要查询的词语"
|
||||
|
||||
# 解析词语列表(支持逗号、空格等分隔符)
|
||||
words_list = []
|
||||
for separator in [",", ",", " ", "\n", "\t"]:
|
||||
if separator in words:
|
||||
words_list = [w.strip() for w in words.split(separator) if w.strip()]
|
||||
break
|
||||
|
||||
# 如果没有找到分隔符,整个字符串作为一个词语
|
||||
if not words_list:
|
||||
words_list = [words.strip()]
|
||||
|
||||
# 去重
|
||||
unique_words = []
|
||||
seen = set()
|
||||
for word in words_list:
|
||||
if word and word not in seen:
|
||||
unique_words.append(word)
|
||||
seen.add(word)
|
||||
|
||||
if not unique_words:
|
||||
return "未提供有效的词语"
|
||||
|
||||
logger.info(f"查询词语含义: {unique_words}")
|
||||
|
||||
# 调用检索函数
|
||||
result = await retrieve_concepts_with_jargon(unique_words, chat_id)
|
||||
|
||||
if result:
|
||||
return result
|
||||
else:
|
||||
return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询词语含义失败: {e}")
|
||||
return f"查询失败: {str(e)}"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="query_words",
|
||||
description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "words",
|
||||
"type": "string",
|
||||
"description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS' 或 'YYDS,内卷,996')",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=query_words,
|
||||
)
|
||||
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询
|
||||
"""
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .tool_registry import register_memory_retrieval_tool
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
async def return_information(information: str) -> str:
|
||||
"""返回总结信息并结束查询
|
||||
|
||||
Args:
|
||||
information: 基于已收集信息总结出的相关信息,用于帮助回复。如果收集的信息对当前聊天没有帮助,可以返回空字符串。
|
||||
|
||||
Returns:
|
||||
str: 确认信息
|
||||
"""
|
||||
if information and information.strip():
|
||||
logger.info(f"返回总结信息: {information}")
|
||||
return f"已确认返回信息: {information}"
|
||||
else:
|
||||
logger.info("未收集到相关信息,结束查询")
|
||||
return "未收集到相关信息,查询结束"
|
||||
|
||||
|
||||
def register_tool():
|
||||
"""注册return_information工具"""
|
||||
register_memory_retrieval_tool(
|
||||
name="return_information",
|
||||
description="当你决定结束查询时,调用此工具。基于已收集的信息,总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助,在information参数中提供总结信息;如果信息无关或没有帮助,可以提供空字符串。",
|
||||
parameters=[
|
||||
{
|
||||
"name": "information",
|
||||
"type": "string",
|
||||
"description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。",
|
||||
"required": True,
|
||||
},
|
||||
],
|
||||
execute_func=return_information,
|
||||
)
|
||||
@@ -1,167 +0,0 @@
|
||||
"""工具注册系统。
|
||||
|
||||
提供统一的工具注册和管理接口。
|
||||
"""
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option
|
||||
|
||||
logger = get_logger("memory_retrieval_tools")
|
||||
|
||||
|
||||
class MemoryRetrievalTool:
|
||||
"""记忆检索工具基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]],
|
||||
) -> None:
|
||||
"""初始化工具。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
description: 工具描述。
|
||||
parameters: 参数定义列表。
|
||||
execute_func: 执行函数,必须是异步函数。
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.parameters = parameters
|
||||
self.execute_func = execute_func
|
||||
|
||||
def get_tool_description(self) -> str:
|
||||
"""获取工具的文本描述,用于prompt"""
|
||||
param_descriptions = []
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type = param.get("type", "string")
|
||||
param_desc = param.get("description", "")
|
||||
required = param.get("required", True)
|
||||
required_str = "必填" if required else "可选"
|
||||
param_descriptions.append(f" - {param_name} ({param_type}, {required_str}): {param_desc}")
|
||||
|
||||
params_str = "\n".join(param_descriptions) if param_descriptions else " 无参数"
|
||||
return f"{self.name}({', '.join([p['name'] for p in self.parameters])}): {self.description}\n{params_str}"
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""执行工具。"""
|
||||
return await self.execute_func(**kwargs)
|
||||
|
||||
def get_tool_definition(self) -> Dict[str, Any]:
|
||||
"""获取规范化的工具定义。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 统一工具定义字典。
|
||||
"""
|
||||
legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = []
|
||||
|
||||
for param in self.parameters:
|
||||
param_name = param.get("name", "")
|
||||
param_type_str = param.get("type", "string").lower()
|
||||
param_desc = param.get("description", "")
|
||||
is_required = param.get("required", False)
|
||||
enum_values = param.get("enum", None)
|
||||
|
||||
# 转换类型字符串到ToolParamType
|
||||
type_mapping = {
|
||||
"string": ToolParamType.STRING,
|
||||
"integer": ToolParamType.INTEGER,
|
||||
"int": ToolParamType.INTEGER,
|
||||
"float": ToolParamType.FLOAT,
|
||||
"boolean": ToolParamType.BOOLEAN,
|
||||
"bool": ToolParamType.BOOLEAN,
|
||||
}
|
||||
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
|
||||
|
||||
legacy_parameters.append((param_name, param_type, param_desc, is_required, enum_values))
|
||||
|
||||
normalized_option = normalize_tool_option(
|
||||
{
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": legacy_parameters,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"name": normalized_option.name,
|
||||
"description": normalized_option.description,
|
||||
"parameters_schema": normalized_option.parameters_schema,
|
||||
}
|
||||
|
||||
|
||||
class MemoryRetrievalToolRegistry:
|
||||
"""工具注册器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化工具注册器。"""
|
||||
self.tools: Dict[str, MemoryRetrievalTool] = {}
|
||||
|
||||
def register_tool(self, tool: MemoryRetrievalTool) -> None:
|
||||
"""注册工具"""
|
||||
if tool.name in self.tools:
|
||||
logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
|
||||
return
|
||||
self.tools[tool.name] = tool
|
||||
logger.info(f"注册记忆检索工具: {tool.name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
|
||||
"""获取工具"""
|
||||
return self.tools.get(name)
|
||||
|
||||
def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
|
||||
"""获取所有工具"""
|
||||
return self.tools.copy()
|
||||
|
||||
def get_tools_description(self) -> str:
|
||||
"""获取所有工具的描述,用于prompt"""
|
||||
descriptions = []
|
||||
for i, tool in enumerate(self.tools.values(), 1):
|
||||
descriptions.append(f"{i}. {tool.get_tool_description()}")
|
||||
return "\n".join(descriptions)
|
||||
|
||||
def get_action_types_list(self) -> str:
|
||||
"""获取所有动作类型的列表,用于prompt(已废弃,保留用于兼容)"""
|
||||
action_types = [tool.name for tool in self.tools.values()]
|
||||
action_types.append("final_answer")
|
||||
action_types.append("no_answer")
|
||||
return " 或 ".join([f'"{at}"' for at in action_types])
|
||||
|
||||
def get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有工具的定义列表,用于LLM function calling
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 工具定义列表,每个元素是一个工具定义字典
|
||||
"""
|
||||
return [tool.get_tool_definition() for tool in self.tools.values()]
|
||||
|
||||
|
||||
# 全局工具注册器实例
|
||||
_tool_registry = MemoryRetrievalToolRegistry()
|
||||
|
||||
|
||||
def register_memory_retrieval_tool(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: List[Dict[str, Any]],
|
||||
execute_func: Callable[..., Awaitable[str]],
|
||||
) -> None:
|
||||
"""注册记忆检索工具的便捷函数。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
description: 工具描述。
|
||||
parameters: 参数定义列表。
|
||||
execute_func: 执行函数。
|
||||
"""
|
||||
tool = MemoryRetrievalTool(name, description, parameters, execute_func)
|
||||
_tool_registry.register_tool(tool)
|
||||
|
||||
|
||||
def get_tool_registry() -> MemoryRetrievalToolRegistry:
|
||||
"""获取工具注册器实例"""
|
||||
return _tool_registry
|
||||
Reference in New Issue
Block a user