feat(llm): add response format conversion with JSON_SCHEMA output support

Author: DrSmoothl
Date:   2026-03-04 21:11:10 +08:00
parent 81bc25dba8
commit 5cccdf6715
2 changed files with 131 additions and 5 deletions


@@ -33,12 +33,34 @@ from ..exceptions import (
     EmptyResponseException,
 )
 from ..payload_content.message import Message, RoleType
-from ..payload_content.resp_format import RespFormat
+from ..payload_content.resp_format import RespFormat, RespFormatType
 from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall

 logger = get_logger("llm_models")

+
+def _convert_response_format(response_format: RespFormat | None) -> Any:
+    """
+    Convert the response format - map the internal RespFormat to the
+    format required by the OpenAI API.
+    """
+    if response_format is None:
+        return NOT_GIVEN
+    if response_format.format_type == RespFormatType.TEXT:
+        return NOT_GIVEN
+    if response_format.format_type == RespFormatType.JSON_OBJ:
+        return {"type": "json_object"}
+    if response_format.format_type == RespFormatType.JSON_SCHEMA:
+        return {
+            "type": "json_schema",
+            "json_schema": response_format.schema,
+        }
+    return NOT_GIVEN
+
+
 def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
     """
     Convert message format - convert messages to the format required by the OpenAI API
@@ -553,6 +575,7 @@ class OpenaiClient(BaseClient):
         messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
         # Convert tool_options to the format required by the OpenAI API
         tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN  # type: ignore
+        openai_response_format = _convert_response_format(response_format)

         try:
             if model_info.force_stream_mode:
@@ -564,7 +587,7 @@ class OpenaiClient(BaseClient):
                         temperature=temperature,
                         max_tokens=max_tokens,
                         stream=True,
-                        response_format=NOT_GIVEN,
+                        response_format=openai_response_format,
                         extra_body=extra_params,
                     )
                 )
@@ -587,7 +610,7 @@ class OpenaiClient(BaseClient):
                         temperature=temperature,
                         max_tokens=max_tokens,
                         stream=False,
-                        response_format=NOT_GIVEN,
+                        response_format=openai_response_format,
                         extra_body=extra_params,
                     )
                 )
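
Taken together, this file's changes mean the OpenAI client now derives its response_format argument from the caller's RespFormat instead of always passing NOT_GIVEN. A minimal sketch of the mapping, assuming simplified stand-ins for the real RespFormat/RespFormatType classes (the schema payload and its "reply" name below are illustrative, not from the commit):

    # Sketch only: RespFormatType/RespFormat are stand-ins for the project's classes.
    from dataclasses import dataclass
    from enum import Enum, auto
    from typing import Any

    class RespFormatType(Enum):
        TEXT = auto()
        JSON_OBJ = auto()
        JSON_SCHEMA = auto()

    @dataclass
    class RespFormat:
        format_type: RespFormatType
        schema: dict[str, Any] | None = None

    fmt = RespFormat(
        format_type=RespFormatType.JSON_SCHEMA,
        schema={
            "name": "reply",  # hypothetical schema name
            "strict": True,
            "schema": {"type": "object", "properties": {"text": {"type": "string"}}},
        },
    )
    # _convert_response_format(None) or TEXT -> NOT_GIVEN (parameter omitted)
    # _convert_response_format(JSON_OBJ)     -> {"type": "json_object"}
    # _convert_response_format(fmt)          -> {"type": "json_schema", "json_schema": fmt.schema}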


@@ -2,6 +2,7 @@ import re
 import asyncio
 import time
 import random
+import json
 from enum import Enum

 from rich.traceback import install
@@ -12,7 +13,7 @@ from src.common.logger import get_logger
 from src.config.config import config_manager
 from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
 from .payload_content.message import MessageBuilder, Message
-from .payload_content.resp_format import RespFormat
+from .payload_content.resp_format import RespFormat, RespFormatType
 from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
 from .model_client.base_client import BaseClient, APIResponse, client_registry
 from .utils import compress_messages, llm_usage_recorder
@@ -170,6 +171,7 @@ class LLMRequest:
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
+        response_format: RespFormat | None = None,
         raise_when_empty: bool = True,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
@@ -179,6 +181,7 @@ class LLMRequest:
             temperature (float, optional): temperature parameter
             max_tokens (int, optional): maximum number of tokens
             tools (Optional[List[Dict[str, Any]]]): list of tools
+            response_format (RespFormat | None): response format
             raise_when_empty (bool): whether to raise an exception when the response is empty
         Returns:
             (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, tool call list
@@ -199,6 +202,7 @@ class LLMRequest:
             temperature=temperature,
             max_tokens=max_tokens,
             tool_options=tool_built,
+            response_format=response_format,
         )

         logger.debug(f"Total LLM request time: {time.time() - start_time}")
@@ -227,6 +231,7 @@ class LLMRequest:
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
+        response_format: RespFormat | None = None,
         raise_when_empty: bool = True,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
@@ -236,6 +241,7 @@ class LLMRequest:
             temperature (float, optional): temperature parameter
             max_tokens (int, optional): maximum number of tokens
             tools (Optional[List[Dict[str, Any]]]): list of tools
+            response_format (RespFormat | None): response format
             raise_when_empty (bool): whether to raise an exception when the response is empty
         Returns:
             (Tuple[str, str, str, Optional[List[ToolCall]]]): response content, reasoning content, model name, tool call list
@@ -251,6 +257,7 @@ class LLMRequest:
             temperature=temperature,
             max_tokens=max_tokens,
             tool_options=tool_built,
+            response_format=response_format,
         )

         time_cost = time.time() - start_time
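
These hunks thread the new parameter through both public generation entry points down to the model attempt. A hedged call-site sketch; the method name generate_response_async and the bare prompt argument are assumptions for illustration, since the hunks do not show the enclosing method names:

    # Inside an async function; `llm` is assumed to be an LLMRequest instance.
    resp_fmt = RespFormat(
        format_type=RespFormatType.JSON_OBJ,  # plain JSON mode; JSON_SCHEMA would also carry a schema dict
        schema=None,
    )
    content, (reasoning, model_name, tool_calls) = await llm.generate_response_async(
        prompt,
        response_format=resp_fmt,
    )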
@@ -275,6 +282,100 @@ class LLMRequest:
         )
         return content or "", (reasoning_content, model_info.name, tool_calls)

+    async def generate_structured_response_async(
+        self,
+        prompt: str,
+        schema: type | dict[str, Any],
+        fallback_result: dict[str, Any] | None = None,
+        temperature: Optional[float] = 0.0,
+        max_tokens: Optional[int] = None,
+    ) -> Tuple[dict[str, Any], Tuple[str, str, Optional[List[ToolCall]]], bool]:
+        """
+        Fast path for structured output:
+        - Enables JSON_SCHEMA strict mode by default
+        - Single attempt on a single model (no retries, no model switching)
+        - Returns fallback_result immediately on failure
+
+        Returns:
+            (structured result, (reasoning content, model name, tool calls), success flag)
+        """
+        self._refresh_task_config()
+        start_time = time.time()
+
+        message_builder = MessageBuilder()
+        message_builder.add_text_content(prompt)
+        message_list = [message_builder.build()]
+
+        response_format = RespFormat(schema=schema, format_type=RespFormatType.JSON_SCHEMA)
+        if response_format.schema:
+            response_format.schema["strict"] = True
+
+        model_info, api_provider, client = self._select_model()
+        fallback_data = fallback_result or {}
+
+        try:
+            response = await self._attempt_request_on_model(
+                model_info=model_info,
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.RESPONSE,
+                message_list=message_list,
+                tool_options=None,
+                response_format=response_format,
+                stream_response_handler=None,
+                async_response_parser=None,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                embedding_input=None,
+                audio_base64=None,
+                retry_limit=1,
+            )
+            time_cost = time.time() - start_time
+            self._check_slow_request(time_cost, model_info.name)
+
+            reasoning_content = response.reasoning_content or ""
+            tool_calls = response.tool_calls
+
+            parsed_result: dict[str, Any] | None = None
+            if response.content:
+                try:
+                    parsed = json.loads(response.content)
+                    if isinstance(parsed, dict):
+                        parsed_result = parsed
+                except json.JSONDecodeError:
+                    parsed_result = None
+
+            if parsed_result is None:
+                logger.warning(f"Structured output parsing failed, using the fallback result. Model: {model_info.name}")
+                total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+                self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0))
+                return fallback_data, (reasoning_content, model_info.name, tool_calls), False
+
+            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+            if response_usage := response.usage:
+                total_tokens += response_usage.total_tokens
+                llm_usage_recorder.record_usage_to_database(
+                    model_info=model_info,
+                    model_usage=response_usage,
+                    user_id="system",
+                    request_type=self.request_type,
+                    endpoint="/chat/completions",
+                    time_cost=time_cost,
+                )
+            self.model_usage[model_info.name] = (total_tokens, penalty, max(usage_penalty - 1, 0))
+
+            return parsed_result, (reasoning_content, model_info.name, tool_calls), True
+
+        except Exception as e:
+            time_cost = time.time() - start_time
+            self._check_slow_request(time_cost, model_info.name)
+            logger.warning(f"Structured output request failed, falling back. Model: {model_info.name}, error: {e}")
+            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+            self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0))
+            return fallback_data, ("", model_info.name, None), False
+
     async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
         """
         Get embedding vector
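
A hedged usage sketch for the new fast path; the llm instance and the schema contents are illustrative, while the signature and return shape follow the diff above:

    # Inside an async function; `llm` is assumed to be an LLMRequest instance.
    schema = {
        "name": "sentiment",  # illustrative JSON-schema wrapper
        "schema": {
            "type": "object",
            "properties": {"label": {"type": "string"}},
            "required": ["label"],
            "additionalProperties": False,
        },
    }
    result, (reasoning, model_name, tool_calls), ok = await llm.generate_structured_response_async(
        prompt="Classify the sentiment of: 'great product!'",
        schema=schema,
        fallback_result={"label": "unknown"},
    )
    if not ok:
        # The single attempt failed or returned non-JSON content; `result` is the fallback dict.
        ...

Note the design choice: on any failure the method returns the fallback and bumps the model's penalty rather than retrying, keeping latency bounded for callers that need a quick structured answer.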
@@ -359,12 +460,14 @@ class LLMRequest:
         max_tokens: Optional[int],
         embedding_input: str | None,
         audio_base64: str | None,
+        retry_limit: Optional[int] = None,
     ) -> APIResponse:
         """
         Execute the request on a single model, with retry logic for transient errors.
         Returns an APIResponse on success; raises ModelAttemptFailed when retries are exhausted or on a hard error.
         """
-        retry_remain = api_provider.max_retry
+        retry_remain = retry_limit if retry_limit is not None else api_provider.max_retry
+        retry_remain = max(1, retry_remain)
         compressed_messages: Optional[List[Message]] = None

         while retry_remain > 0:
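
The clamp guarantees at least one attempt regardless of the caller-supplied limit; the structured fast path above passes retry_limit=1 to get exactly one try:

    # retry_limit=None -> retry_remain = api_provider.max_retry  (previous behavior preserved)
    # retry_limit=1    -> retry_remain = 1                       (single attempt, used by generate_structured_response_async)
    # retry_limit=0    -> retry_remain = max(1, 0) = 1           (clamped: every request is attempted at least once)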