diff --git a/pytests/test_openai_client_toolless_request.py b/pytests/test_openai_client_toolless_request.py index 2e1748b7..eb6f6ae9 100644 --- a/pytests/test_openai_client_toolless_request.py +++ b/pytests/test_openai_client_toolless_request.py @@ -1,4 +1,12 @@ -from src.llm_models.model_client.openai_client import _sanitize_messages_for_toolless_request +from types import SimpleNamespace + +from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode +from src.llm_models.model_client.openai_client import ( + _OpenAIStreamAccumulator, + _build_reasoning_key, + _default_normal_response_parser, + _sanitize_messages_for_toolless_request, +) from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart from src.llm_models.payload_content.tool_option import ToolCall @@ -25,3 +33,89 @@ def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_withou assert len(sanitized_messages) == 1 assert sanitized_messages[0].role == RoleType.User + + +def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None: + response = SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace( + content="正式回复", + reasoning="推理内容", + tool_calls=None, + ), + ) + ], + usage=None, + model="openrouter/test-model", + ) + + api_response, usage_record = _default_normal_response_parser( + response, + reasoning_parse_mode=ReasoningParseMode.AUTO, + tool_argument_parse_mode=ToolArgumentParseMode.AUTO, + reasoning_key=_build_reasoning_key( + APIProvider(name="test", base_url="https://openrouter.ai.example.com/api/v1", api_key="test") + ), + ) + + assert api_response.content == "正式回复" + assert api_response.reasoning_content is None + assert usage_record is None + + +def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None: + provider_urls = [ + "https://openrouter.ai/compatible-api", + "https://api.groq.com/openai/v1", + ] + + for provider_url in provider_urls: + response = SimpleNamespace( + choices=[ + SimpleNamespace( + finish_reason="stop", + message=SimpleNamespace( + content="正式回复", + reasoning="推理内容", + tool_calls=None, + ), + ) + ], + usage=None, + model="test-model", + ) + + api_response, usage_record = _default_normal_response_parser( + response, + reasoning_parse_mode=ReasoningParseMode.AUTO, + tool_argument_parse_mode=ToolArgumentParseMode.AUTO, + reasoning_key=_build_reasoning_key( + APIProvider(name="reasoning-provider", base_url=provider_url, api_key="test") + ), + ) + + assert api_response.content == "正式回复" + assert api_response.reasoning_content == "推理内容" + assert usage_record is None + + +def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None: + accumulator = _OpenAIStreamAccumulator( + reasoning_parse_mode=ReasoningParseMode.AUTO, + tool_argument_parse_mode=ToolArgumentParseMode.AUTO, + reasoning_key=_build_reasoning_key( + APIProvider(name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test") + ), + ) + try: + accumulator.process_delta(SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None)) + accumulator.process_delta(SimpleNamespace(content="正式回复", tool_calls=None)) + + api_response = accumulator.build_response() + finally: + accumulator.close() + + assert api_response.content == "正式回复" + assert api_response.reasoning_content == "流式推理" diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index d7307101..02611907 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast +from urllib.parse import urlparse from uuid import uuid4 from json_repair import repair_json @@ -123,6 +124,12 @@ OpenAIStreamResponseHandler = Callable[ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]] """OpenAI 非流式响应解析函数类型。""" +PROVIDER_REASONING_KEYS_BY_DOMAIN: Dict[str, str] = { + "api.groq.com": "reasoning", + "openrouter.ai": "reasoning", +} +"""按 provider 域名指定的原生推理字段名。""" + def _build_debug_provider_request_filename(model_name: str) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") @@ -156,6 +163,28 @@ def _build_fallback_tool_call_id(prefix: str) -> str: return f"{normalized_prefix}_{uuid4().hex}" +def _build_reasoning_key(api_provider: APIProvider) -> str: + """根据 provider 构建可读取的原生推理字段名。""" + normalized_base_url = api_provider.base_url.strip() + if not normalized_base_url: + return "reasoning_content" + parsed_url = urlparse(normalized_base_url if "://" in normalized_base_url else f"//{normalized_base_url}") + provider_hostname = (parsed_url.hostname or "").lower() + return PROVIDER_REASONING_KEYS_BY_DOMAIN.get(provider_hostname, "reasoning_content") + + +def _extract_reasoning_content(message_part: Any, reasoning_key: str) -> str | None: + """从 OpenAI 兼容响应对象中读取原生推理内容。 + + 不同兼容服务商对推理字段命名并不完全一致。这里集中处理字段访问, + 避免解析路径里散落 provider 特判;具体字段名由 provider 决定。 + """ + native_reasoning = getattr(message_part, reasoning_key, None) + if isinstance(native_reasoning, str) and native_reasoning: + return native_reasoning + return None + + def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode: """将配置中的推理解析模式收敛为枚举值。 @@ -781,15 +810,18 @@ class _OpenAIStreamAccumulator: self, reasoning_parse_mode: ReasoningParseMode, tool_argument_parse_mode: ToolArgumentParseMode, + reasoning_key: str, ) -> None: """初始化累积器。 Args: reasoning_parse_mode: 推理内容解析模式。 tool_argument_parse_mode: 工具参数解析模式。 + reasoning_key: 允许读取的原生推理字段名。 """ self.reasoning_parse_mode = reasoning_parse_mode self.tool_argument_parse_mode = tool_argument_parse_mode + self.reasoning_key = reasoning_key self.reasoning_buffer = io.StringIO() self.content_buffer = io.StringIO() self.tool_call_states: Dict[int, _StreamedToolCallState] = {} @@ -825,8 +857,8 @@ class _OpenAIStreamAccumulator: Args: delta: 当前增量对象。 """ - native_reasoning = getattr(delta, "reasoning_content", None) - if isinstance(native_reasoning, str) and native_reasoning: + native_reasoning = _extract_reasoning_content(delta, self.reasoning_key) + if native_reasoning is not None: self._using_native_reasoning = True if self.reasoning_parse_mode != ReasoningParseMode.NONE: self.reasoning_buffer.write(native_reasoning) @@ -929,6 +961,7 @@ async def _default_stream_response_handler( *, reasoning_parse_mode: ReasoningParseMode, tool_argument_parse_mode: ToolArgumentParseMode, + reasoning_key: str, ) -> Tuple[APIResponse, UsageTuple | None]: """处理 OpenAI 兼容流式响应。 @@ -937,6 +970,7 @@ async def _default_stream_response_handler( interrupt_flag: 外部中断标记。 reasoning_parse_mode: 推理内容解析模式。 tool_argument_parse_mode: 工具参数解析模式。 + reasoning_key: 允许读取的原生推理字段名。 Returns: Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 @@ -944,6 +978,7 @@ async def _default_stream_response_handler( accumulator = _OpenAIStreamAccumulator( reasoning_parse_mode=reasoning_parse_mode, tool_argument_parse_mode=tool_argument_parse_mode, + reasoning_key=reasoning_key, ) usage_record: UsageTuple | None = None @@ -977,6 +1012,7 @@ def _default_normal_response_parser( *, reasoning_parse_mode: ReasoningParseMode, tool_argument_parse_mode: ToolArgumentParseMode, + reasoning_key: str, ) -> Tuple[APIResponse, UsageTuple | None]: """解析 OpenAI 兼容的非流式响应。 @@ -984,6 +1020,7 @@ def _default_normal_response_parser( resp: OpenAI SDK 返回的聊天补全响应。 reasoning_parse_mode: 推理内容解析模式。 tool_argument_parse_mode: 工具参数解析模式。 + reasoning_key: 允许读取的原生推理字段名。 Returns: Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 @@ -997,10 +1034,10 @@ def _default_normal_response_parser( api_response = APIResponse() message_part = choices[0].message - native_reasoning = getattr(message_part, "reasoning_content", None) + native_reasoning = _extract_reasoning_content(message_part, reasoning_key) message_content = message_part.content if isinstance(message_part.content, str) else None - if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE: + if native_reasoning is not None and reasoning_parse_mode != ReasoningParseMode.NONE: api_response.reasoning_content = native_reasoning api_response.content = message_content elif isinstance(message_content, str) and message_content: @@ -1046,6 +1083,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio client: AsyncOpenAI reasoning_parse_mode: ReasoningParseMode + reasoning_key: str tool_argument_parse_mode: ToolArgumentParseMode def __init__(self, api_provider: APIProvider) -> None: @@ -1057,6 +1095,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio super().__init__(api_provider) client_config = build_openai_compatible_client_config(api_provider) self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode) + self.reasoning_key = _build_reasoning_key(api_provider) self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode) self.client = AsyncOpenAI( api_key=client_config.api_key, @@ -1093,6 +1132,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio flag, reasoning_parse_mode=self.reasoning_parse_mode, tool_argument_parse_mode=self.tool_argument_parse_mode, + reasoning_key=self.reasoning_key, ) return default_stream_handler @@ -1119,6 +1159,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio response, reasoning_parse_mode=self.reasoning_parse_mode, tool_argument_parse_mode=self.tool_argument_parse_mode, + reasoning_key=self.reasoning_key, ) return default_response_parser