fix(llm): support reasoning field for OpenRouter and Groq

fixes: #1600
This commit is contained in:
Soulter
2026-04-26 22:45:37 +08:00
parent be2248b283
commit 45cd00e343
2 changed files with 140 additions and 5 deletions

View File

@@ -1,4 +1,12 @@
from types import SimpleNamespace
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
from src.llm_models.model_client.openai_client import (
_OpenAIStreamAccumulator,
_build_reasoning_key,
_default_normal_response_parser,
_sanitize_messages_for_toolless_request,
)
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
from src.llm_models.payload_content.tool_option import ToolCall
@@ -25,3 +33,89 @@ def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_withou
assert len(sanitized_messages) == 1
assert sanitized_messages[0].role == RoleType.User
def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None:
    """A look-alike host (openrouter.ai.example.com) must not unlock the `reasoning` field."""
    fake_message = SimpleNamespace(
        content="正式回复",
        reasoning="推理内容",
        tool_calls=None,
    )
    fake_response = SimpleNamespace(
        choices=[SimpleNamespace(finish_reason="stop", message=fake_message)],
        usage=None,
        model="openrouter/test-model",
    )
    provider = APIProvider(
        name="test",
        base_url="https://openrouter.ai.example.com/api/v1",
        api_key="test",
    )
    parsed, usage = _default_normal_response_parser(
        fake_response,
        reasoning_parse_mode=ReasoningParseMode.AUTO,
        tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
        reasoning_key=_build_reasoning_key(provider),
    )
    assert parsed.content == "正式回复"
    assert parsed.reasoning_content is None
    assert usage is None
def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None:
    """OpenRouter and Groq hosts should surface the native `reasoning` field."""
    for base_url in (
        "https://openrouter.ai/compatible-api",
        "https://api.groq.com/openai/v1",
    ):
        reply = SimpleNamespace(
            content="正式回复",
            reasoning="推理内容",
            tool_calls=None,
        )
        fake_response = SimpleNamespace(
            choices=[SimpleNamespace(finish_reason="stop", message=reply)],
            usage=None,
            model="test-model",
        )
        reasoning_key = _build_reasoning_key(
            APIProvider(name="reasoning-provider", base_url=base_url, api_key="test")
        )
        parsed, usage = _default_normal_response_parser(
            fake_response,
            reasoning_parse_mode=ReasoningParseMode.AUTO,
            tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
            reasoning_key=reasoning_key,
        )
        assert parsed.content == "正式回复"
        assert parsed.reasoning_content == "推理内容"
        assert usage is None
def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None:
    """Streaming deltas carrying `reasoning` (OpenRouter naming) feed reasoning_content."""
    reasoning_key = _build_reasoning_key(
        APIProvider(name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test")
    )
    accumulator = _OpenAIStreamAccumulator(
        reasoning_parse_mode=ReasoningParseMode.AUTO,
        tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
        reasoning_key=reasoning_key,
    )
    try:
        for delta in (
            SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None),
            SimpleNamespace(content="正式回复", tool_calls=None),
        ):
            accumulator.process_delta(delta)
        result = accumulator.build_response()
    finally:
        accumulator.close()
    assert result.content == "正式回复"
    assert result.reasoning_content == "流式推理"

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from urllib.parse import urlparse
from uuid import uuid4
from json_repair import repair_json
@@ -123,6 +124,12 @@ OpenAIStreamResponseHandler = Callable[
OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]] OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
"""OpenAI 非流式响应解析函数类型。""" """OpenAI 非流式响应解析函数类型。"""
# Map from provider hostname to the attribute name carrying native reasoning
# text in that provider's OpenAI-compatible responses. Hostnames not listed
# here fall back to the conventional "reasoning_content" attribute (the
# fallback is applied in _build_reasoning_key).
PROVIDER_REASONING_KEYS_BY_DOMAIN: Dict[str, str] = {
    "api.groq.com": "reasoning",
    "openrouter.ai": "reasoning",
}
"""按 provider 域名指定的原生推理字段名。"""
def _build_debug_provider_request_filename(model_name: str) -> str: def _build_debug_provider_request_filename(model_name: str) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
@@ -156,6 +163,28 @@ def _build_fallback_tool_call_id(prefix: str) -> str:
return f"{normalized_prefix}_{uuid4().hex}" return f"{normalized_prefix}_{uuid4().hex}"
def _build_reasoning_key(api_provider: APIProvider) -> str:
    """Resolve which attribute name holds native reasoning text for a provider.

    Args:
        api_provider: Provider config whose ``base_url`` identifies the vendor.

    Returns:
        The provider-specific reasoning field name when the hostname is a
        known reasoning domain, otherwise the conventional "reasoning_content".
    """
    base_url = api_provider.base_url.strip()
    if not base_url:
        return "reasoning_content"
    # urlparse only recognizes a hostname when a scheme (or leading "//")
    # is present, so scheme-less base URLs are prefixed before parsing.
    parse_target = base_url if "://" in base_url else f"//{base_url}"
    hostname = (urlparse(parse_target).hostname or "").lower()
    return PROVIDER_REASONING_KEYS_BY_DOMAIN.get(hostname, "reasoning_content")
def _extract_reasoning_content(message_part: Any, reasoning_key: str) -> str | None:
"""从 OpenAI 兼容响应对象中读取原生推理内容。
不同兼容服务商对推理字段命名并不完全一致。这里集中处理字段访问,
避免解析路径里散落 provider 特判;具体字段名由 provider 决定。
"""
native_reasoning = getattr(message_part, reasoning_key, None)
if isinstance(native_reasoning, str) and native_reasoning:
return native_reasoning
return None
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode: def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
"""将配置中的推理解析模式收敛为枚举值。 """将配置中的推理解析模式收敛为枚举值。
@@ -781,15 +810,18 @@ class _OpenAIStreamAccumulator:
self, self,
reasoning_parse_mode: ReasoningParseMode, reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode, tool_argument_parse_mode: ToolArgumentParseMode,
reasoning_key: str,
) -> None: ) -> None:
"""初始化累积器。 """初始化累积器。
Args: Args:
reasoning_parse_mode: 推理内容解析模式。 reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。 tool_argument_parse_mode: 工具参数解析模式。
reasoning_key: 允许读取的原生推理字段名。
""" """
self.reasoning_parse_mode = reasoning_parse_mode self.reasoning_parse_mode = reasoning_parse_mode
self.tool_argument_parse_mode = tool_argument_parse_mode self.tool_argument_parse_mode = tool_argument_parse_mode
self.reasoning_key = reasoning_key
self.reasoning_buffer = io.StringIO() self.reasoning_buffer = io.StringIO()
self.content_buffer = io.StringIO() self.content_buffer = io.StringIO()
self.tool_call_states: Dict[int, _StreamedToolCallState] = {} self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
@@ -825,8 +857,8 @@ class _OpenAIStreamAccumulator:
Args: Args:
delta: 当前增量对象。 delta: 当前增量对象。
""" """
native_reasoning = getattr(delta, "reasoning_content", None) native_reasoning = _extract_reasoning_content(delta, self.reasoning_key)
if isinstance(native_reasoning, str) and native_reasoning: if native_reasoning is not None:
self._using_native_reasoning = True self._using_native_reasoning = True
if self.reasoning_parse_mode != ReasoningParseMode.NONE: if self.reasoning_parse_mode != ReasoningParseMode.NONE:
self.reasoning_buffer.write(native_reasoning) self.reasoning_buffer.write(native_reasoning)
@@ -929,6 +961,7 @@ async def _default_stream_response_handler(
*, *,
reasoning_parse_mode: ReasoningParseMode, reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode, tool_argument_parse_mode: ToolArgumentParseMode,
reasoning_key: str,
) -> Tuple[APIResponse, UsageTuple | None]: ) -> Tuple[APIResponse, UsageTuple | None]:
"""处理 OpenAI 兼容流式响应。 """处理 OpenAI 兼容流式响应。
@@ -937,6 +970,7 @@ async def _default_stream_response_handler(
interrupt_flag: 外部中断标记。 interrupt_flag: 外部中断标记。
reasoning_parse_mode: 推理内容解析模式。 reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。 tool_argument_parse_mode: 工具参数解析模式。
reasoning_key: 允许读取的原生推理字段名。
Returns: Returns:
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
@@ -944,6 +978,7 @@ async def _default_stream_response_handler(
accumulator = _OpenAIStreamAccumulator( accumulator = _OpenAIStreamAccumulator(
reasoning_parse_mode=reasoning_parse_mode, reasoning_parse_mode=reasoning_parse_mode,
tool_argument_parse_mode=tool_argument_parse_mode, tool_argument_parse_mode=tool_argument_parse_mode,
reasoning_key=reasoning_key,
) )
usage_record: UsageTuple | None = None usage_record: UsageTuple | None = None
@@ -977,6 +1012,7 @@ def _default_normal_response_parser(
*, *,
reasoning_parse_mode: ReasoningParseMode, reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode, tool_argument_parse_mode: ToolArgumentParseMode,
reasoning_key: str,
) -> Tuple[APIResponse, UsageTuple | None]: ) -> Tuple[APIResponse, UsageTuple | None]:
"""解析 OpenAI 兼容的非流式响应。 """解析 OpenAI 兼容的非流式响应。
@@ -984,6 +1020,7 @@ def _default_normal_response_parser(
resp: OpenAI SDK 返回的聊天补全响应。 resp: OpenAI SDK 返回的聊天补全响应。
reasoning_parse_mode: 推理内容解析模式。 reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。 tool_argument_parse_mode: 工具参数解析模式。
reasoning_key: 允许读取的原生推理字段名。
Returns: Returns:
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。 Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
@@ -997,10 +1034,10 @@ def _default_normal_response_parser(
api_response = APIResponse() api_response = APIResponse()
message_part = choices[0].message message_part = choices[0].message
native_reasoning = getattr(message_part, "reasoning_content", None) native_reasoning = _extract_reasoning_content(message_part, reasoning_key)
message_content = message_part.content if isinstance(message_part.content, str) else None message_content = message_part.content if isinstance(message_part.content, str) else None
if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE: if native_reasoning is not None and reasoning_parse_mode != ReasoningParseMode.NONE:
api_response.reasoning_content = native_reasoning api_response.reasoning_content = native_reasoning
api_response.content = message_content api_response.content = message_content
elif isinstance(message_content, str) and message_content: elif isinstance(message_content, str) and message_content:
@@ -1046,6 +1083,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
client: AsyncOpenAI client: AsyncOpenAI
reasoning_parse_mode: ReasoningParseMode reasoning_parse_mode: ReasoningParseMode
reasoning_key: str
tool_argument_parse_mode: ToolArgumentParseMode tool_argument_parse_mode: ToolArgumentParseMode
def __init__(self, api_provider: APIProvider) -> None: def __init__(self, api_provider: APIProvider) -> None:
@@ -1057,6 +1095,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
super().__init__(api_provider) super().__init__(api_provider)
client_config = build_openai_compatible_client_config(api_provider) client_config = build_openai_compatible_client_config(api_provider)
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode) self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
self.reasoning_key = _build_reasoning_key(api_provider)
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode) self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
self.client = AsyncOpenAI( self.client = AsyncOpenAI(
api_key=client_config.api_key, api_key=client_config.api_key,
@@ -1093,6 +1132,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
flag, flag,
reasoning_parse_mode=self.reasoning_parse_mode, reasoning_parse_mode=self.reasoning_parse_mode,
tool_argument_parse_mode=self.tool_argument_parse_mode, tool_argument_parse_mode=self.tool_argument_parse_mode,
reasoning_key=self.reasoning_key,
) )
return default_stream_handler return default_stream_handler
@@ -1119,6 +1159,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
response, response,
reasoning_parse_mode=self.reasoning_parse_mode, reasoning_parse_mode=self.reasoning_parse_mode,
tool_argument_parse_mode=self.tool_argument_parse_mode, tool_argument_parse_mode=self.tool_argument_parse_mode,
reasoning_key=self.reasoning_key,
) )
return default_response_parser return default_response_parser