fix:工具调用存储问题

This commit is contained in:
SengokuCola
2026-04-13 19:54:38 +08:00
parent 8f2337fe99
commit 2471a2c4a4
3 changed files with 18 additions and 13 deletions

View File

@@ -6,6 +6,7 @@ import json
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from uuid import uuid4
from json_repair import repair_json from json_repair import repair_json
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
@@ -119,6 +120,13 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
"""OpenAI 非流式响应解析函数类型。""" """OpenAI 非流式响应解析函数类型。"""
def _build_fallback_tool_call_id(prefix: str) -> str:
"""为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
normalized_prefix = str(prefix).strip() or "tool_call"
return f"{normalized_prefix}_{uuid4().hex}"
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode: def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
"""将配置中的推理解析模式收敛为枚举值。 """将配置中的推理解析模式收敛为枚举值。
@@ -609,7 +617,7 @@ def _extract_xml_tool_calls(
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {} arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
tool_calls.append( tool_calls.append(
ToolCall( ToolCall(
call_id=f"xml_tool_call_{len(tool_calls) + 1}", call_id=_build_fallback_tool_call_id("xml_tool_call"),
func_name=function_name, func_name=function_name,
args=arguments, args=arguments,
) )
@@ -855,7 +863,7 @@ class _OpenAIStreamAccumulator:
if raw_arguments if raw_arguments
else None else None
) )
call_id = state.call_id or f"tool_call_{index}" call_id = state.call_id or _build_fallback_tool_call_id(f"tool_call_{index}")
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments)) response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
response.raw_data = {"model": self.model_name} if self.model_name else None response.raw_data = {"model": self.model_name} if self.model_name else None

View File

@@ -4,16 +4,18 @@ import json
import time import time
import traceback import traceback
from datetime import datetime from datetime import datetime
from typing import Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import delete, func, select from sqlalchemy import delete, func
from sqlmodel import SQLModel from sqlmodel import SQLModel, select
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import ToolRecord from src.common.database.database_model import ToolRecord
from src.common.logger import get_logger from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
logger = get_logger("database_service") logger = get_logger("database_service")
@@ -158,7 +160,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
async def store_tool_info( async def store_tool_info(
chat_stream: BotChatSession, chat_stream: "BotChatSession",
builtin_prompt: Optional[str] = None, builtin_prompt: Optional[str] = None,
display_prompt: str = "", display_prompt: str = "",
tool_id: str = "", tool_id: str = "",
@@ -191,7 +193,7 @@ async def store_tool_info(
async def store_action_info( async def store_action_info(
chat_stream: BotChatSession, chat_stream: "BotChatSession",
builtin_prompt: Optional[str] = None, builtin_prompt: Optional[str] = None,
display_prompt: str = "", display_prompt: str = "",
thinking_id: str = "", thinking_id: str = "",

View File

@@ -227,11 +227,6 @@ class MemoryAutomationService:
await self.fact_writeback.shutdown() await self.fact_writeback.shutdown()
self._started = False self._started = False
async def on_incoming_message(self, message: Any) -> None:
del message
if not self._started:
await self.start()
async def on_message_sent(self, message: Any) -> None: async def on_message_sent(self, message: Any) -> None:
if not self._started: if not self._started:
await self.start() await self.start()