feat:显示实时占用上下文,移除旧记忆系统

This commit is contained in:
SengokuCola
2026-04-01 13:18:17 +08:00
parent 503a257d66
commit d713aa9576
11 changed files with 55 additions and 2663 deletions

View File

@@ -66,6 +66,9 @@ class LLMResponseResult(BaseDataModel):
reasoning: str = field(default_factory=str)
model_name: str = field(default_factory=str)
tool_calls: List[ToolCall] | None = None
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
@dataclass(slots=True)
@@ -120,6 +123,9 @@ class LLMServiceResult(BaseDataModel):
"response": self.completion.response,
"reasoning": self.completion.reasoning,
"model_name": self.completion.model_name,
"prompt_tokens": self.completion.prompt_tokens,
"completion_tokens": self.completion.completion_tokens,
"total_tokens": self.completion.total_tokens,
}
if self.completion.tool_calls is not None:
payload["tool_calls"] = [

View File

@@ -34,6 +34,7 @@ from src.llm_models.model_client.base_client import (
ClientRequest,
EmbeddingRequest,
ResponseRequest,
UsageRecord,
client_registry,
)
from src.llm_models.payload_content.message import Message, MessageBuilder
@@ -137,6 +138,7 @@ class LLMOrchestrator:
reasoning_content: str,
model_name: str,
tool_calls: List[ToolCall] | None,
usage: UsageRecord | None = None,
) -> LLMResponseResult:
"""构建统一的文本响应结果。
@@ -154,6 +156,9 @@ class LLMOrchestrator:
reasoning=reasoning_content,
model_name=model_name,
tool_calls=tool_calls,
prompt_tokens=usage.prompt_tokens if usage is not None else 0,
completion_tokens=usage.completion_tokens if usage is not None else 0,
total_tokens=usage.total_tokens if usage is not None else 0,
)
async def generate_response_for_image(
@@ -215,7 +220,13 @@ class LLMOrchestrator:
endpoint="/chat/completions",
time_cost=time_cost,
)
return self._build_generation_result(content, reasoning_content, model_info.name, tool_calls)
return self._build_generation_result(
content,
reasoning_content,
model_info.name,
tool_calls,
response.usage,
)
async def generate_response_for_voice(self, voice_base64: str) -> LLMAudioTranscriptionResult:
"""为语音生成转录响应。
@@ -298,7 +309,13 @@ class LLMOrchestrator:
endpoint="/chat/completions",
time_cost=time.time() - start_time,
)
return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls)
return self._build_generation_result(
content or "",
reasoning_content,
model_info.name,
tool_calls,
response.usage,
)
async def generate_response_with_message_async(
self,
@@ -364,7 +381,13 @@ class LLMOrchestrator:
endpoint="/chat/completions",
time_cost=time_cost,
)
return self._build_generation_result(content or "", reasoning_content, model_info.name, tool_calls)
return self._build_generation_result(
content or "",
reasoning_content,
model_info.name,
tool_calls,
response.usage,
)
async def get_embedding(self, embedding_input: str) -> LLMEmbeddingResult:
"""获取嵌入向量。

View File

@@ -10,7 +10,6 @@ from typing import Any, Dict, List, Optional, Sequence
import asyncio
import json
import random
import re
from PIL import Image as PILImage
from pydantic import BaseModel, Field as PydanticField
@@ -28,7 +27,7 @@ from src.config.config import global_config
from src.core.tooling import ToolRegistry, ToolSpec
from src.know_u.knowledge import extract_category_ids_from_result
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.services.llm_service import LLMServiceClient
@@ -697,58 +696,29 @@ class MaisakaChatLoopService:
)
@staticmethod
def _estimate_text_tokens(text: str) -> int:
"""Estimate the input-token count of a single text segment (heuristic, not a real tokenizer)."""
# Whitespace-only text costs nothing.
normalized_text = text.strip()
if not normalized_text:
return 0
# Roughly one token per CJK character (basic CJK Unified Ideographs range only).
cjk_char_count = sum(1 for char in normalized_text if "\u4e00" <= char <= "\u9fff")
# Latin/digit chunks: ~4 characters per token, minimum one token per chunk.
latin_chunks = re.findall(r"[A-Za-z0-9_]+", normalized_text)
latin_token_count = sum(max(1, (len(chunk) + 3) // 4) for chunk in latin_chunks)
# Every non-word, non-space character counts as one token.
punctuation_count = len(re.findall(r"[^\w\s]", normalized_text))
# Newlines add a small bonus, at least 1 even for single-line text.
whitespace_bonus = max(1, normalized_text.count("\n"))
return cjk_char_count + latin_token_count + punctuation_count + whitespace_bonus
def _format_token_count(token_count: int) -> str:
"""格式化 token 数量展示文本"""
if token_count >= 10_000:
return f"{token_count / 1000:.1f}k"
return str(token_count)
@classmethod
def _estimate_request_tokens(cls, messages: Sequence[Message]) -> int:
"""估算本轮请求消息的总输入 token 数。"""
total_tokens = 0
for message in messages:
total_tokens += 4
total_tokens += cls._estimate_text_tokens(str(message.role.value))
if message.tool_call_id:
total_tokens += cls._estimate_text_tokens(message.tool_call_id)
if message.tool_calls:
for tool_call in message.tool_calls:
total_tokens += cls._estimate_text_tokens(getattr(tool_call, "func_name", "") or "")
total_tokens += cls._estimate_text_tokens(
json.dumps(getattr(tool_call, "args", {}) or {}, ensure_ascii=False)
)
for part in message.parts:
if isinstance(part, TextMessagePart):
total_tokens += cls._estimate_text_tokens(part.text)
continue
if isinstance(part, ImageMessagePart):
total_tokens += max(256, len(part.image_base64) // 12)
return total_tokens
@classmethod
def _build_prompt_stats_text(
    cls,
    *,
    selected_history_count: int,
    built_message_count: int,
    prompt_tokens: int,
    completion_tokens: int,
    total_tokens: int,
) -> str:
    """Build the statistics line describing this turn's prompt usage.

    Args:
        selected_history_count: Number of context messages selected for the turn.
        built_message_count: Number of messages actually sent to the LLM.
        prompt_tokens: Actual input tokens reported by the provider.
        completion_tokens: Output tokens reported by the provider.
        total_tokens: Total tokens reported by the provider.

    Returns:
        str: A single display line with all counters formatted via
        ``_format_token_count``.

    Note:
        Declared ``@classmethod`` — the previous ``@staticmethod`` took a
        ``cls`` parameter, so every keyword-only call site (which passes no
        positional argument) would have raised TypeError. The unused
        ``input_token_count`` estimation remnants were dropped to match the
        caller's actual keyword arguments.
    """
    return (
        f"已选上下文消息数={selected_history_count} "
        f"大模型消息数={built_message_count} "
        f"实际输入Token={cls._format_token_count(prompt_tokens)} "
        f"输出Token={cls._format_token_count(completion_tokens)} "
        f"总Token={cls._format_token_count(total_tokens)}"
    )
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
@@ -764,13 +734,6 @@ class MaisakaChatLoopService:
await self.ensure_chat_prompt_loaded()
selected_history, selection_reason = self._select_llm_context_messages(chat_history)
built_messages = self._build_request_messages(selected_history)
input_token_count = self._estimate_request_tokens(built_messages)
prompt_stats_text = self._build_prompt_stats_text(
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
input_token_count=input_token_count,
)
display_subtitle = f"{selection_reason} | {prompt_stats_text}"
def message_factory(_client: BaseClient) -> List[Message]:
"""返回当前轮次已经构建好的请求消息。
@@ -806,7 +769,7 @@ class MaisakaChatLoopService:
Panel(
Group(*ordered_panels),
title="MaiSaka 大模型请求 - 对话单步",
subtitle=display_subtitle,
subtitle=selection_reason,
border_style="cyan",
padding=(0, 1),
)
@@ -820,7 +783,6 @@ class MaisakaChatLoopService:
f"工具数={len(all_tools)} "
f"启用打断={self._interrupt_flag is not None}"
)
logger.info(f"??Prompt??: {prompt_stats_text}")
generation_result = await self._llm_chat.generate_response_with_messages(
message_factory=message_factory,
options=LLMGenerationOptions(
@@ -833,6 +795,15 @@ class MaisakaChatLoopService:
request_elapsed = perf_counter() - request_started_at
logger.info(f"规划器请求完成,耗时={request_elapsed:.3f}")
prompt_stats_text = self._build_prompt_stats_text(
selected_history_count=len(selected_history),
built_message_count=len(built_messages),
prompt_tokens=generation_result.prompt_tokens,
completion_tokens=generation_result.completion_tokens,
total_tokens=generation_result.total_tokens,
)
logger.info(f"本轮Prompt统计: {prompt_stats_text}")
tool_call_summaries = [
{
"调用编号": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
"""
记忆系统工具函数
包含模糊查找、相似度计算等工具函数
"""
import json
import re
from datetime import datetime
from typing import Tuple
from typing import List
from json_repair import repair_json
from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
    """Parse the question JSON returned by the LLM into concept and question lists.

    Args:
        response: Raw LLM response text; the JSON payload may be wrapped in a
            ```json ...``` fenced code block.

    Returns:
        Tuple[List[str], List[str]]: (concept list, question list); both empty
        on any parse failure or unsupported format.
    """
    try:
        # Prefer the payload of a fenced ```json``` block when one is present;
        # otherwise try the whole response as-is.
        fenced = re.findall(r"```json\s*(.*?)\s*```", response, re.DOTALL)
        candidate = fenced[0] if fenced else response.strip()
        # repair_json tolerates the usual LLM JSON mistakes before parsing.
        parsed = json.loads(repair_json(candidate))
        # Only the new object format {"concepts": [...], "questions": [...]}
        # is supported.
        if not isinstance(parsed, dict):
            logger.warning(f"解析的JSON不是对象格式: {parsed}")
            return [], []
        raw_concepts = parsed.get("concepts", [])
        raw_questions = parsed.get("questions", [])
        if not isinstance(raw_concepts, list):
            raw_concepts = []
        if not isinstance(raw_questions, list):
            raw_questions = []
        # Keep only non-blank string entries.
        concepts = [item for item in raw_concepts if isinstance(item, str) and item.strip()]
        questions = [item for item in raw_questions if isinstance(item, str) and item.strip()]
        return concepts, questions
    except Exception as e:
        logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
        return [], []
def parse_datetime_to_timestamp(value: str) -> float:
    """Convert a date/time string in one of several common formats to a Unix timestamp (seconds).

    Supported examples:
        - 2025-09-29
        - 2025-09-29 00:00:00
        - 2025/09/29 00:00
        - 2025-09-29T00:00:00

    Raises:
        ValueError: If no known format matches the input.
    """
    text = value.strip()
    # Candidate formats, tried in order; the first successful parse wins.
    known_formats = (
        "%Y-%m-%d %H:%M:%S",
        "%Y-%m-%d %H:%M",
        "%Y/%m/%d %H:%M:%S",
        "%Y/%m/%d %H:%M",
        "%Y-%m-%d",
        "%Y/%m/%d",
        "%Y-%m-%dT%H:%M:%S",
        "%Y-%m-%dT%H:%M",
    )
    failure = None
    for candidate in known_formats:
        try:
            return datetime.strptime(text, candidate).timestamp()
        except Exception as err:  # broad catch kept to mirror the original behavior
            failure = err
    raise ValueError(f"无法解析时间: {text} ({failure})")

View File

@@ -1,36 +0,0 @@
"""
记忆检索工具模块
提供统一的工具注册和管理系统
"""
from .tool_registry import (
MemoryRetrievalTool,
MemoryRetrievalToolRegistry,
register_memory_retrieval_tool,
get_tool_registry,
)
# 导入所有工具的注册函数
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_words import register_tool as register_query_words
from .return_information import register_tool as register_return_information
from src.config.config import global_config
def init_all_tools():
    """Initialize and register every memory-retrieval tool."""
    # Tools that are always available.
    for register in (register_query_words, register_return_information):
        register()
    # The LPMM knowledge-base search tool is opt-in via "agent" mode.
    if global_config.lpmm_knowledge.lpmm_mode == "agent":
        register_lpmm_knowledge()
__all__ = [
"MemoryRetrievalTool",
"MemoryRetrievalToolRegistry",
"register_memory_retrieval_tool",
"get_tool_registry",
"init_all_tools",
]

View File

@@ -1,75 +0,0 @@
"""
通过LPMM知识库查询信息 - 工具实现
"""
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
    """Search the LPMM knowledge base for information related to *query*.

    Args:
        query: Search keyword / question.
        limit: Desired number of knowledge entries; coerced to int with a
            floor of 1, falling back to 5 on bad input.

    Returns:
        str: A human-readable result, or a notice describing why the lookup
        could not be performed.
    """
    try:
        keyword = str(query).strip()
        if not keyword:
            return "查询关键词为空"
        # Coerce defensively: tool arguments may arrive as strings or None.
        try:
            max_results = int(limit)
        except (TypeError, ValueError):
            max_results = 5
        max_results = max(1, max_results)
        if not global_config.lpmm_knowledge.enable:
            logger.debug("LPMM知识库未启用")
            return "LPMM知识库未启用"
        qa_manager = get_qa_manager()
        if qa_manager is None:
            logger.debug("LPMM知识库未初始化跳过查询")
            return "LPMM知识库未初始化"
        knowledge_info = await qa_manager.get_knowledge(keyword, limit=max_results)
        logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
        if knowledge_info:
            return f"你从LPMM知识库中找到以下信息\n{knowledge_info}"
        return f"在LPMM知识库中未找到与“{keyword}”相关的信息"
    except Exception as e:
        logger.error(f"LPMM知识库查询失败: {e}")
        return f"LPMM知识库查询失败{str(e)}"
def register_tool():
    """Register the LPMM knowledge-base search tool with the shared registry."""
    # Parameter schema declared up front for readability.
    parameters = [
        {
            "name": "query",
            "type": "string",
            "description": "需要查询的问题使用一句疑问句提问例如什么是AI",
            "required": True,
        },
        {
            "name": "limit",
            "type": "integer",
            "description": "希望返回的相关知识条数默认为5",
            "required": False,
        },
    ]
    register_memory_retrieval_tool(
        name="lpmm_search_knowledge",
        description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
        parameters=parameters,
        execute_func=query_lpmm_knowledge,
    )

View File

@@ -1,78 +0,0 @@
"""
查询黑话/概念含义 - 工具实现
用于在记忆检索过程中主动查询未知词语或黑话的含义
"""
from src.common.logger import get_logger
from src.learners.jargon_explainer_old import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_words(chat_id: str, words: str) -> str:
    """Look up the meaning of slang/jargon words.

    Args:
        chat_id: Chat session id, forwarded to the retrieval backend.
        words: One or more words; multiple words may be separated by commas,
            spaces, newlines or tabs.

    Returns:
        str: Explanation text from the retrieval backend, or a notice for
        empty input / no result / failure.
    """
    try:
        if not words or not words.strip():
            return "未提供要查询的词语"
        # Split on the first separator actually present in the input.
        # FIX: the previous separator list contained an empty string ""
        # (a mangled full-width comma), so `"" in words` was always true and
        # `words.split("")` raised ValueError for most inputs, turning every
        # such call into the generic failure message.
        words_list = []
        for separator in (",", ",", " ", "\n", "\t"):
            if separator in words:
                words_list = [w.strip() for w in words.split(separator) if w.strip()]
                break
        # No separator found: treat the whole string as a single word.
        if not words_list:
            words_list = [words.strip()]
        # De-duplicate while preserving first-seen order.
        seen = set()
        unique_words = []
        for word in words_list:
            if word and word not in seen:
                unique_words.append(word)
                seen.add(word)
        if not unique_words:
            return "未提供有效的词语"
        logger.info(f"查询词语含义: {unique_words}")
        result = await retrieve_concepts_with_jargon(unique_words, chat_id)
        if result:
            return result
        return f"未找到词语 '{', '.join(unique_words)}' 的含义或黑话解释"
    except Exception as e:
        logger.error(f"查询词语含义失败: {e}")
        return f"查询失败: {str(e)}"
def register_tool():
    """Register the word/jargon lookup tool with the shared registry."""
    word_param = {
        "name": "words",
        "type": "string",
        "description": "要查询的词语,可以是单个词语或多个词语(用逗号、空格等分隔,如:'YYDS''YYDS,内卷,996'",
        "required": True,
    }
    register_memory_retrieval_tool(
        name="query_words",
        description="查询词语或黑话的含义。当遇到不熟悉的词语、缩写、黑话或网络用语时,可以使用此工具查询其含义。支持查询单个或多个词语(用逗号、空格等分隔)。",
        parameters=[word_param],
        execute_func=query_words,
    )

View File

@@ -1,42 +0,0 @@
"""
return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询
"""
from src.common.logger import get_logger
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def return_information(information: str) -> str:
    """Finish the retrieval loop, returning the summarized information.

    Args:
        information: Summary built from the collected information; an empty
            string signals that nothing useful was found.

    Returns:
        str: Confirmation text for the caller.
    """
    # A blank or whitespace-only summary means the search yielded nothing useful.
    if not (information and information.strip()):
        logger.info("未收集到相关信息,结束查询")
        return "未收集到相关信息,查询结束"
    logger.info(f"返回总结信息: {information}")
    return f"已确认返回信息: {information}"
def register_tool():
    """Register the return_information terminator tool with the shared registry."""
    info_param = {
        "name": "information",
        "type": "string",
        "description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。",
        "required": True,
    }
    register_memory_retrieval_tool(
        name="return_information",
        description="当你决定结束查询时调用此工具。基于已收集的信息总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助在information参数中提供总结信息如果信息无关或没有帮助可以提供空字符串。",
        parameters=[info_param],
        execute_func=return_information,
    )

View File

@@ -1,167 +0,0 @@
"""工具注册系统。
提供统一的工具注册和管理接口。
"""
from typing import Any, Awaitable, Callable, Dict, List, Optional
from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolParamType, normalize_tool_option
logger = get_logger("memory_retrieval_tools")
class MemoryRetrievalTool:
    """Wrapper describing one memory-retrieval tool.

    Holds the tool's metadata (name, description, parameter specs) together
    with the async callable that implements it, and renders that metadata
    either as prompt text or as a normalized function-calling definition.
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameters: List[Dict[str, Any]],
        execute_func: Callable[..., Awaitable[str]],
    ) -> None:
        """Store the tool metadata and its async executor.

        Args:
            name: Tool name.
            description: Tool description.
            parameters: Parameter definition dicts
                (keys: name / type / description / required / enum).
            execute_func: Async function that executes the tool.
        """
        self.name = name
        self.description = description
        self.parameters = parameters
        self.execute_func = execute_func

    def get_tool_description(self) -> str:
        """Render a plain-text description of the tool for prompt injection."""
        rendered = []
        for spec in self.parameters:
            # NOTE(review): "required" defaults to True here but to False in
            # get_tool_definition — inconsistency preserved as-is.
            flag = "必填" if spec.get("required", True) else "可选"
            rendered.append(
                f" - {spec.get('name', '')} ({spec.get('type', 'string')}, {flag}): {spec.get('description', '')}"
            )
        params_str = "\n".join(rendered) if rendered else " 无参数"
        arg_names = ", ".join(p["name"] for p in self.parameters)
        return f"{self.name}({arg_names}): {self.description}\n{params_str}"

    async def execute(self, **kwargs: Any) -> str:
        """Run the tool by delegating to the stored async executor."""
        return await self.execute_func(**kwargs)

    def get_tool_definition(self) -> Dict[str, Any]:
        """Build a normalized tool definition for LLM function calling.

        Returns:
            Dict[str, Any]: Keys ``name``, ``description``, ``parameters_schema``.
        """
        # Map loose type strings onto the project's ToolParamType enum;
        # anything unrecognized degrades to STRING.
        type_mapping = {
            "string": ToolParamType.STRING,
            "integer": ToolParamType.INTEGER,
            "int": ToolParamType.INTEGER,
            "float": ToolParamType.FLOAT,
            "boolean": ToolParamType.BOOLEAN,
            "bool": ToolParamType.BOOLEAN,
        }
        legacy_parameters: list[tuple[str, ToolParamType, str, bool, list[str] | None]] = [
            (
                spec.get("name", ""),
                type_mapping.get(spec.get("type", "string").lower(), ToolParamType.STRING),
                spec.get("description", ""),
                spec.get("required", False),
                spec.get("enum", None),
            )
            for spec in self.parameters
        ]
        normalized = normalize_tool_option(
            {
                "name": self.name,
                "description": self.description,
                "parameters": legacy_parameters,
            }
        )
        return {
            "name": normalized.name,
            "description": normalized.description,
            "parameters_schema": normalized.parameters_schema,
        }
class MemoryRetrievalToolRegistry:
    """Registry mapping tool names to MemoryRetrievalTool instances."""

    def __init__(self) -> None:
        """Create an empty registry."""
        self.tools: Dict[str, MemoryRetrievalTool] = {}

    def register_tool(self, tool: MemoryRetrievalTool) -> None:
        """Register *tool*; duplicate names are ignored (first registration wins)."""
        if tool.name in self.tools:
            logger.debug(f"记忆检索工具 {tool.name} 已存在,跳过重复注册")
            return
        self.tools[tool.name] = tool
        logger.info(f"注册记忆检索工具: {tool.name}")

    def get_tool(self, name: str) -> Optional[MemoryRetrievalTool]:
        """Return the tool registered under *name*, or None."""
        return self.tools.get(name)

    def get_all_tools(self) -> Dict[str, MemoryRetrievalTool]:
        """Return a shallow copy of the name-to-tool mapping."""
        return dict(self.tools)

    def get_tools_description(self) -> str:
        """Join the numbered descriptions of all registered tools for prompts."""
        return "\n".join(
            f"{index}. {tool.get_tool_description()}"
            for index, tool in enumerate(self.tools.values(), 1)
        )

    def get_action_types_list(self) -> str:
        """Return the quoted action-type list (deprecated, kept for compatibility)."""
        names = [tool.name for tool in self.tools.values()]
        names += ["final_answer", "no_answer"]
        return "".join(f'"{name}"' for name in names)

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """Return normalized tool definitions for LLM function calling.

        Returns:
            List[Dict[str, Any]]: One definition dict per registered tool.
        """
        return [tool.get_tool_definition() for tool in self.tools.values()]
# Process-wide registry singleton shared by all registration helpers.
_tool_registry = MemoryRetrievalToolRegistry()


def register_memory_retrieval_tool(
    name: str,
    description: str,
    parameters: List[Dict[str, Any]],
    execute_func: Callable[..., Awaitable[str]],
) -> None:
    """Convenience wrapper: build a MemoryRetrievalTool and register it globally.

    Args:
        name: Tool name.
        description: Tool description.
        parameters: Parameter definition list.
        execute_func: Async execute function.
    """
    _tool_registry.register_tool(
        MemoryRetrievalTool(name, description, parameters, execute_func)
    )


def get_tool_registry() -> MemoryRetrievalToolRegistry:
    """Return the process-wide registry instance."""
    return _tool_registry