diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index 07dd66a2..99efe8d9 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -33,12 +33,34 @@ from ..exceptions import (
     EmptyResponseException,
 )
 from ..payload_content.message import Message, RoleType
-from ..payload_content.resp_format import RespFormat
+from ..payload_content.resp_format import RespFormat, RespFormatType
 from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
 
 logger = get_logger("llm_models")
 
 
+def _convert_response_format(response_format: RespFormat | None) -> Any:
+    """
+    Convert the response format - map the internal RespFormat to the format required by the OpenAI API
+    """
+    if response_format is None:
+        return NOT_GIVEN
+
+    if response_format.format_type == RespFormatType.TEXT:
+        return NOT_GIVEN
+
+    if response_format.format_type == RespFormatType.JSON_OBJ:
+        return {"type": "json_object"}
+
+    if response_format.format_type == RespFormatType.JSON_SCHEMA:
+        return {
+            "type": "json_schema",
+            "json_schema": response_format.schema,
+        }
+
+    return NOT_GIVEN
+
+
 def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
     """
     Convert the message format - map messages to the format required by the OpenAI API
@@ -553,6 +575,7 @@ class OpenaiClient(BaseClient):
         messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
         # Convert tool_options to the format required by the OpenAI API
         tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN  # type: ignore
+        openai_response_format = _convert_response_format(response_format)
 
         try:
             if model_info.force_stream_mode:
@@ -564,7 +587,7 @@
                         temperature=temperature,
                         max_tokens=max_tokens,
                         stream=True,
-                        response_format=NOT_GIVEN,
+                        response_format=openai_response_format,
                         extra_body=extra_params,
                     )
                 )
@@ -587,7 +610,7 @@
                         temperature=temperature,
                         max_tokens=max_tokens,
                         stream=False,
-                        response_format=NOT_GIVEN,
+                        response_format=openai_response_format,
                         extra_body=extra_params,
                     )
                 )
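Note: a minimal sketch of the mapping _convert_response_format performs, for review convenience. The RespFormatType/RespFormat stand-ins below are hypothetical reconstructions of the internal types; only the mapping targets (NOT_GIVEN, json_object, json_schema) are taken from the hunk above.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any

class RespFormatType(Enum):
    TEXT = auto()
    JSON_OBJ = auto()
    JSON_SCHEMA = auto()

@dataclass
class RespFormat:
    format_type: RespFormatType
    schema: dict[str, Any] | None = None

# TEXT (and None) fall through to the SDK sentinel NOT_GIVEN, i.e. the
# parameter is omitted entirely; JSON_OBJ asks for any well-formed JSON
# object; JSON_SCHEMA forwards the caller-supplied schema payload verbatim.
fmt = RespFormat(
    format_type=RespFormatType.JSON_SCHEMA,
    schema={"name": "answer", "schema": {"type": "object"}, "strict": True},
)
# _convert_response_format(fmt) would return:
# {"type": "json_schema", "json_schema": {"name": "answer", "schema": {...}, "strict": True}}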
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index ad0460f3..457b12c6 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -2,6 +2,7 @@ import re
 import asyncio
 import time
 import random
+import json
 from enum import Enum
 
 from rich.traceback import install
@@ -12,7 +13,7 @@ from src.common.logger import get_logger
 from src.config.config import config_manager
 from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
 from .payload_content.message import MessageBuilder, Message
-from .payload_content.resp_format import RespFormat
+from .payload_content.resp_format import RespFormat, RespFormatType
 from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
 from .model_client.base_client import BaseClient, APIResponse, client_registry
 from .utils import compress_messages, llm_usage_recorder
@@ -170,6 +171,7 @@ class LLMRequest:
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
+        response_format: RespFormat | None = None,
         raise_when_empty: bool = True,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
@@ -179,6 +181,7 @@
             temperature (float, optional): Temperature parameter
             max_tokens (int, optional): Maximum number of tokens
             tools (Optional[List[Dict[str, Any]]]): List of tools
+            response_format (RespFormat | None): Response format
             raise_when_empty (bool): Whether to raise an exception when the response is empty
         Returns:
             (Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]): response content, (reasoning content, model name, tool call list)
@@ -199,6 +202,7 @@
             temperature=temperature,
             max_tokens=max_tokens,
             tool_options=tool_built,
+            response_format=response_format,
         )
 
         logger.debug(f"Total LLM request time: {time.time() - start_time}")
@@ -227,6 +231,7 @@
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
+        response_format: RespFormat | None = None,
         raise_when_empty: bool = True,
     ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
         """
@@ -236,6 +241,7 @@
             temperature (float, optional): Temperature parameter
             max_tokens (int, optional): Maximum number of tokens
             tools (Optional[List[Dict[str, Any]]]): List of tools
+            response_format (RespFormat | None): Response format
             raise_when_empty (bool): Whether to raise an exception when the response is empty
         Returns:
             (Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]): response content, (reasoning content, model name, tool call list)
@@ -251,6 +257,7 @@
             temperature=temperature,
             max_tokens=max_tokens,
             tool_options=tool_built,
+            response_format=response_format,
         )
 
         time_cost = time.time() - start_time
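Note: with the pass-through above, callers can opt into JSON output per request. A hedged usage sketch; it assumes the first positional parameter of generate_response_async is the prompt string (the hunk starts below it in the signature) and that RespFormat's schema argument is optional.

from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType

async def list_colors(llm: LLMRequest) -> str:
    # JSON_OBJ maps to {"type": "json_object"} in openai_client, so the
    # provider constrains the completion to a single well-formed JSON object.
    content, (_reasoning, model_name, _tool_calls) = await llm.generate_response_async(
        'Reply with JSON of the form {"colors": ["...", "...", "..."]}',
        response_format=RespFormat(format_type=RespFormatType.JSON_OBJ),
    )
    return content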
@@ -275,6 +282,100 @@
         )
         return content or "", (reasoning_content, model_info.name, tool_calls)
 
+    async def generate_structured_response_async(
+        self,
+        prompt: str,
+        schema: type | dict[str, Any],
+        fallback_result: dict[str, Any] | None = None,
+        temperature: Optional[float] = 0.0,
+        max_tokens: Optional[int] = None,
+    ) -> Tuple[dict[str, Any], Tuple[str, str, Optional[List[ToolCall]]], bool]:
+        """
+        Fast-path interface for structured output:
+        - Enables strict JSON_SCHEMA mode by default
+        - Single attempt on a single model (no retries, no model switching)
+        - Returns fallback_result immediately on failure
+
+        Returns:
+            (structured result, (reasoning content, model name, tool calls), success flag)
+        """
+        self._refresh_task_config()
+        start_time = time.time()
+
+        message_builder = MessageBuilder()
+        message_builder.add_text_content(prompt)
+        message_list = [message_builder.build()]
+
+        response_format = RespFormat(schema=schema, format_type=RespFormatType.JSON_SCHEMA)
+        # schema may be a class; only a dict payload can carry the strict flag
+        if isinstance(response_format.schema, dict):
+            response_format.schema["strict"] = True
+
+        model_info, api_provider, client = self._select_model()
+        fallback_data = fallback_result or {}
+
+        try:
+            response = await self._attempt_request_on_model(
+                model_info=model_info,
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.RESPONSE,
+                message_list=message_list,
+                tool_options=None,
+                response_format=response_format,
+                stream_response_handler=None,
+                async_response_parser=None,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                embedding_input=None,
+                audio_base64=None,
+                retry_limit=1,
+            )
+
+            time_cost = time.time() - start_time
+            self._check_slow_request(time_cost, model_info.name)
+
+            reasoning_content = response.reasoning_content or ""
+            tool_calls = response.tool_calls
+
+            parsed_result: dict[str, Any] | None = None
+            if response.content:
+                try:
+                    parsed = json.loads(response.content)
+                    if isinstance(parsed, dict):
+                        parsed_result = parsed
+                except json.JSONDecodeError:
+                    parsed_result = None
+
+            if parsed_result is None:
+                logger.warning(f"Failed to parse structured output; using fallback result. Model: {model_info.name}")
+                total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+                self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0))
+                return fallback_data, (reasoning_content, model_info.name, tool_calls), False
+
+            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+            if response_usage := response.usage:
+                total_tokens += response_usage.total_tokens
+                llm_usage_recorder.record_usage_to_database(
+                    model_info=model_info,
+                    model_usage=response_usage,
+                    user_id="system",
+                    request_type=self.request_type,
+                    endpoint="/chat/completions",
+                    time_cost=time_cost,
+                )
+            self.model_usage[model_info.name] = (total_tokens, penalty, max(usage_penalty - 1, 0))
+            return parsed_result, (reasoning_content, model_info.name, tool_calls), True
+
+        except Exception as e:
+            time_cost = time.time() - start_time
+            self._check_slow_request(time_cost, model_info.name)
+            logger.warning(f"Structured output request failed; falling back. Model: {model_info.name}, error: {e}")
+
+            total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
+            self.model_usage[model_info.name] = (total_tokens, penalty + 1, max(usage_penalty - 1, 0))
+
+            return fallback_data, ("", model_info.name, None), False
+
     async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
         """
         Get the embedding vector
@@ -359,12 +460,14 @@
         max_tokens: Optional[int],
         embedding_input: str | None,
         audio_base64: str | None,
+        retry_limit: Optional[int] = None,
     ) -> APIResponse:
         """
         Execute the request on a single model, with retry logic for transient errors.
         On success, returns an APIResponse. On failure (retries exhausted or a hard error), raises ModelAttemptFailed.
         """
-        retry_remain = api_provider.max_retry
+        retry_remain = retry_limit if retry_limit is not None else api_provider.max_retry
+        retry_remain = max(1, retry_remain)
         compressed_messages: Optional[List[Message]] = None
 
         while retry_remain > 0:
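Note: a usage sketch for the new fast path. The schema payload follows OpenAI's json_schema shape (name/schema/strict); the classify helper and the schema itself are illustrative, not part of this patch.

from src.llm_models.utils_model import LLMRequest

SENTIMENT_SCHEMA = {
    "name": "sentiment",
    "schema": {
        "type": "object",
        "properties": {
            "label": {"type": "string", "enum": ["positive", "negative", "neutral"]},
            "confidence": {"type": "number"},
        },
        "required": ["label", "confidence"],
        "additionalProperties": False,
    },
}

async def classify(llm: LLMRequest, text: str) -> dict:
    result, (_reasoning, model_name, _tools), ok = await llm.generate_structured_response_async(
        prompt=f"Classify the sentiment of: {text}",
        schema=SENTIMENT_SCHEMA,
        fallback_result={"label": "neutral", "confidence": 0.0},
    )
    # ok is False when the single attempt failed or the content did not parse
    # as a JSON object; result is then exactly the fallback dict.
    return result

The single-attempt contract trades reliability for latency: retry_limit=1 pins _attempt_request_on_model to one try, failures bump the model's penalty counter, and the caller always receives a dict, so downstream code never has to re-parse model text.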