feat(llm): 添加响应格式转换功能,支持JSON_SCHEMA输出
This commit is contained in:
@@ -33,12 +33,34 @@ from ..exceptions import (
|
|||||||
EmptyResponseException,
|
EmptyResponseException,
|
||||||
)
|
)
|
||||||
from ..payload_content.message import Message, RoleType
|
from ..payload_content.message import Message, RoleType
|
||||||
from ..payload_content.resp_format import RespFormat
|
from ..payload_content.resp_format import RespFormat, RespFormatType
|
||||||
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
|
||||||
|
|
||||||
logger = get_logger("llm_models")
|
logger = get_logger("llm_models")
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_response_format(response_format: RespFormat | None) -> Any:
|
||||||
|
"""
|
||||||
|
转换响应格式 - 将内部RespFormat转换为OpenAI API所需格式
|
||||||
|
"""
|
||||||
|
if response_format is None:
|
||||||
|
return NOT_GIVEN
|
||||||
|
|
||||||
|
if response_format.format_type == RespFormatType.TEXT:
|
||||||
|
return NOT_GIVEN
|
||||||
|
|
||||||
|
if response_format.format_type == RespFormatType.JSON_OBJ:
|
||||||
|
return {"type": "json_object"}
|
||||||
|
|
||||||
|
if response_format.format_type == RespFormatType.JSON_SCHEMA:
|
||||||
|
return {
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": response_format.schema,
|
||||||
|
}
|
||||||
|
|
||||||
|
return NOT_GIVEN
|
||||||
|
|
||||||
|
|
||||||
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
|
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
|
||||||
"""
|
"""
|
||||||
转换消息格式 - 将消息转换为OpenAI API所需的格式
|
转换消息格式 - 将消息转换为OpenAI API所需的格式
|
||||||
@@ -553,6 +575,7 @@ class OpenaiClient(BaseClient):
|
|||||||
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
|
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
|
||||||
# 将tool_options转换为OpenAI API所需的格式
|
# 将tool_options转换为OpenAI API所需的格式
|
||||||
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
|
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
|
||||||
|
openai_response_format = _convert_response_format(response_format)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if model_info.force_stream_mode:
|
if model_info.force_stream_mode:
|
||||||
@@ -564,7 +587,7 @@ class OpenaiClient(BaseClient):
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stream=True,
|
stream=True,
|
||||||
response_format=NOT_GIVEN,
|
response_format=openai_response_format,
|
||||||
extra_body=extra_params,
|
extra_body=extra_params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -587,7 +610,7 @@ class OpenaiClient(BaseClient):
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stream=False,
|
stream=False,
|
||||||
response_format=NOT_GIVEN,
|
response_format=openai_response_format,
|
||||||
extra_body=extra_params,
|
extra_body=extra_params,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import re
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
import json
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -12,7 +13,7 @@ from src.common.logger import get_logger
|
|||||||
from src.config.config import config_manager
|
from src.config.config import config_manager
|
||||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
|
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
|
||||||
from .payload_content.message import MessageBuilder, Message
|
from .payload_content.message import MessageBuilder, Message
|
||||||
from .payload_content.resp_format import RespFormat
|
from .payload_content.resp_format import RespFormat, RespFormatType
|
||||||
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
|
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
|
||||||
from .model_client.base_client import BaseClient, APIResponse, client_registry
|
from .model_client.base_client import BaseClient, APIResponse, client_registry
|
||||||
from .utils import compress_messages, llm_usage_recorder
|
from .utils import compress_messages, llm_usage_recorder
|
||||||
@@ -170,6 +171,7 @@ class LLMRequest:
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
response_format: RespFormat | None = None,
|
||||||
raise_when_empty: bool = True,
|
raise_when_empty: bool = True,
|
||||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||||
"""
|
"""
|
||||||
@@ -179,6 +181,7 @@ class LLMRequest:
|
|||||||
temperature (float, optional): 温度参数
|
temperature (float, optional): 温度参数
|
||||||
max_tokens (int, optional): 最大token数
|
max_tokens (int, optional): 最大token数
|
||||||
tools (Optional[List[Dict[str, Any]]]): 工具列表
|
tools (Optional[List[Dict[str, Any]]]): 工具列表
|
||||||
|
response_format (RespFormat | None): 响应格式
|
||||||
raise_when_empty (bool): 当响应为空时是否抛出异常
|
raise_when_empty (bool): 当响应为空时是否抛出异常
|
||||||
Returns:
|
Returns:
|
||||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||||
@@ -199,6 +202,7 @@ class LLMRequest:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
tool_options=tool_built,
|
tool_options=tool_built,
|
||||||
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
|
logger.debug(f"LLM请求总耗时: {time.time() - start_time}")
|
||||||
@@ -227,6 +231,7 @@ class LLMRequest:
|
|||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
response_format: RespFormat | None = None,
|
||||||
raise_when_empty: bool = True,
|
raise_when_empty: bool = True,
|
||||||
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
|
||||||
"""
|
"""
|
||||||
@@ -236,6 +241,7 @@ class LLMRequest:
|
|||||||
temperature (float, optional): 温度参数
|
temperature (float, optional): 温度参数
|
||||||
max_tokens (int, optional): 最大token数
|
max_tokens (int, optional): 最大token数
|
||||||
tools (Optional[List[Dict[str, Any]]]): 工具列表
|
tools (Optional[List[Dict[str, Any]]]): 工具列表
|
||||||
|
response_format (RespFormat | None): 响应格式
|
||||||
raise_when_empty (bool): 当响应为空时是否抛出异常
|
raise_when_empty (bool): 当响应为空时是否抛出异常
|
||||||
Returns:
|
Returns:
|
||||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||||
@@ -251,6 +257,7 @@ class LLMRequest:
|
|||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
tool_options=tool_built,
|
tool_options=tool_built,
|
||||||
|
response_format=response_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
time_cost = time.time() - start_time
|
time_cost = time.time() - start_time
|
||||||
@@ -275,6 +282,100 @@ class LLMRequest:
|
|||||||
)
|
)
|
||||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||||
|
|
||||||
|
async def generate_structured_response_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
schema: type | dict[str, Any],
|
||||||
|
fallback_result: dict[str, Any] | None = None,
|
||||||
|
temperature: Optional[float] = 0.0,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
) -> Tuple[dict[str, Any], Tuple[str, str, Optional[List[ToolCall]]], bool]:
|
||||||
|
"""
|
||||||
|
结构化输出快速接口:
|
||||||
|
- 默认启用 JSON_SCHEMA 严格模式
|
||||||
|
- 单模型单次尝试(不重试、不切换模型)
|
||||||
|
- 失败时立即返回 fallback_result
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(结构化结果, (推理内容, 模型名, 工具调用), 是否成功)
|
||||||
|
"""
|
||||||
|
self._refresh_task_config()
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
message_builder = MessageBuilder()
|
||||||
|
message_builder.add_text_content(prompt)
|
||||||
|
message_list = [message_builder.build()]
|
||||||
|
|
||||||
|
response_format = RespFormat(schema=schema, format_type=RespFormatType.JSON_SCHEMA)
|
||||||
|
if response_format.schema:
|
||||||
|
response_format.schema["strict"] = True
|
||||||
|
|
||||||
|
model_info, api_provider, client = self._select_model()
|
||||||
|
fallback_data = fallback_result or {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self._attempt_request_on_model(
|
||||||
|
model_info=model_info,
|
||||||
|
api_provider=api_provider,
|
||||||
|
client=client,
|
||||||
|
request_type=RequestType.RESPONSE,
|
||||||
|
message_list=message_list,
|
||||||
|
tool_options=None,
|
||||||
|
response_format=response_format,
|
||||||
|
stream_response_handler=None,
|
||||||
|
async_response_parser=None,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
embedding_input=None,
|
||||||
|
audio_base64=None,
|
||||||
|
retry_limit=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
time_cost = time.time() - start_time
|
||||||
|
self._check_slow_request(time_cost, model_info.name)
|
||||||
|
|
||||||
|
reasoning_content = response.reasoning_content or ""
|
||||||
|
tool_calls = response.tool_calls
|
||||||
|
|
||||||
|
parsed_result: dict[str, Any] | None = None
|
||||||
|
if response.content:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(response.content)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
parsed_result = parsed
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
parsed_result = None
|
||||||
|
|
||||||
|
if parsed_result is None:
|
||||||
|
logger.warning(f"结构化输出解析失败,使用降级结果。模型: {model_info.name}")
|
||||||
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
|
self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0))
|
||||||
|
return fallback_data, (reasoning_content, model_info.name, tool_calls), False
|
||||||
|
|
||||||
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
|
if response_usage := response.usage:
|
||||||
|
total_tokens += response_usage.total_tokens
|
||||||
|
llm_usage_recorder.record_usage_to_database(
|
||||||
|
model_info=model_info,
|
||||||
|
model_usage=response_usage,
|
||||||
|
user_id="system",
|
||||||
|
request_type=self.request_type,
|
||||||
|
endpoint="/chat/completions",
|
||||||
|
time_cost=time_cost,
|
||||||
|
)
|
||||||
|
self.model_usage[model_info.name] = (total_tokens, penalty, max(usage_penalty - 1, 0))
|
||||||
|
return parsed_result, (reasoning_content, model_info.name, tool_calls), True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
time_cost = time.time() - start_time
|
||||||
|
self._check_slow_request(time_cost, model_info.name)
|
||||||
|
logger.warning(f"结构化输出请求失败,直接降级。模型: {model_info.name}, 错误: {e}")
|
||||||
|
|
||||||
|
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||||
|
self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0))
|
||||||
|
|
||||||
|
return fallback_data, ("", model_info.name, None), False
|
||||||
|
|
||||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||||
"""
|
"""
|
||||||
获取嵌入向量
|
获取嵌入向量
|
||||||
@@ -359,12 +460,14 @@ class LLMRequest:
|
|||||||
max_tokens: Optional[int],
|
max_tokens: Optional[int],
|
||||||
embedding_input: str | None,
|
embedding_input: str | None,
|
||||||
audio_base64: str | None,
|
audio_base64: str | None,
|
||||||
|
retry_limit: Optional[int] = None,
|
||||||
) -> APIResponse:
|
) -> APIResponse:
|
||||||
"""
|
"""
|
||||||
在单个模型上执行请求,包含针对临时错误的重试逻辑。
|
在单个模型上执行请求,包含针对临时错误的重试逻辑。
|
||||||
如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。
|
如果成功,返回APIResponse。如果失败(重试耗尽或硬错误),则抛出ModelAttemptFailed异常。
|
||||||
"""
|
"""
|
||||||
retry_remain = api_provider.max_retry
|
retry_remain = retry_limit if retry_limit is not None else api_provider.max_retry
|
||||||
|
retry_remain = max(1, retry_remain)
|
||||||
compressed_messages: Optional[List[Message]] = None
|
compressed_messages: Optional[List[Message]] = None
|
||||||
|
|
||||||
while retry_remain > 0:
|
while retry_remain > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user