fix:工具调用存储问题
This commit is contained in:
@@ -6,6 +6,7 @@ import json
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
|
from openai import APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
|
||||||
@@ -119,6 +120,13 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
|
|||||||
"""OpenAI 非流式响应解析函数类型。"""
|
"""OpenAI 非流式响应解析函数类型。"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_fallback_tool_call_id(prefix: str) -> str:
|
||||||
|
"""为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
|
||||||
|
|
||||||
|
normalized_prefix = str(prefix).strip() or "tool_call"
|
||||||
|
return f"{normalized_prefix}_{uuid4().hex}"
|
||||||
|
|
||||||
|
|
||||||
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
|
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
|
||||||
"""将配置中的推理解析模式收敛为枚举值。
|
"""将配置中的推理解析模式收敛为枚举值。
|
||||||
|
|
||||||
@@ -609,7 +617,7 @@ def _extract_xml_tool_calls(
|
|||||||
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
|
arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
|
||||||
tool_calls.append(
|
tool_calls.append(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
call_id=f"xml_tool_call_{len(tool_calls) + 1}",
|
call_id=_build_fallback_tool_call_id("xml_tool_call"),
|
||||||
func_name=function_name,
|
func_name=function_name,
|
||||||
args=arguments,
|
args=arguments,
|
||||||
)
|
)
|
||||||
@@ -855,7 +863,7 @@ class _OpenAIStreamAccumulator:
|
|||||||
if raw_arguments
|
if raw_arguments
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
call_id = state.call_id or f"tool_call_{index}"
|
call_id = state.call_id or _build_fallback_tool_call_id(f"tool_call_{index}")
|
||||||
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
|
response.tool_calls.append(ToolCall(call_id=call_id, func_name=state.function_name, args=arguments))
|
||||||
|
|
||||||
response.raw_data = {"model": self.model_name} if self.model_name else None
|
response.raw_data = {"model": self.model_name} if self.model_name else None
|
||||||
|
|||||||
@@ -4,16 +4,18 @@ import json
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
|
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel, select
|
||||||
|
|
||||||
from src.chat.message_receive.chat_manager import BotChatSession
|
|
||||||
from src.common.database.database import get_db_session
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import ToolRecord
|
from src.common.database.database_model import ToolRecord
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_manager import BotChatSession
|
||||||
|
|
||||||
logger = get_logger("database_service")
|
logger = get_logger("database_service")
|
||||||
|
|
||||||
|
|
||||||
@@ -158,7 +160,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
|||||||
|
|
||||||
|
|
||||||
async def store_tool_info(
|
async def store_tool_info(
|
||||||
chat_stream: BotChatSession,
|
chat_stream: "BotChatSession",
|
||||||
builtin_prompt: Optional[str] = None,
|
builtin_prompt: Optional[str] = None,
|
||||||
display_prompt: str = "",
|
display_prompt: str = "",
|
||||||
tool_id: str = "",
|
tool_id: str = "",
|
||||||
@@ -191,7 +193,7 @@ async def store_tool_info(
|
|||||||
|
|
||||||
|
|
||||||
async def store_action_info(
|
async def store_action_info(
|
||||||
chat_stream: BotChatSession,
|
chat_stream: "BotChatSession",
|
||||||
builtin_prompt: Optional[str] = None,
|
builtin_prompt: Optional[str] = None,
|
||||||
display_prompt: str = "",
|
display_prompt: str = "",
|
||||||
thinking_id: str = "",
|
thinking_id: str = "",
|
||||||
|
|||||||
@@ -227,11 +227,6 @@ class MemoryAutomationService:
|
|||||||
await self.fact_writeback.shutdown()
|
await self.fact_writeback.shutdown()
|
||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
async def on_incoming_message(self, message: Any) -> None:
|
|
||||||
del message
|
|
||||||
if not self._started:
|
|
||||||
await self.start()
|
|
||||||
|
|
||||||
async def on_message_sent(self, message: Any) -> None:
|
async def on_message_sent(self, message: Any) -> None:
|
||||||
if not self._started:
|
if not self._started:
|
||||||
await self.start()
|
await self.start()
|
||||||
|
|||||||
Reference in New Issue
Block a user