Merge pull request #1604 from Soulter/dev

fix(llm): support reasoning field for OpenRouter and Groq
This commit is contained in:
SengokuCola
2026-04-27 10:10:20 +08:00
committed by GitHub
2 changed files with 140 additions and 5 deletions

View File

@@ -1,4 +1,12 @@
from src.llm_models.model_client.openai_client import _sanitize_messages_for_toolless_request
from types import SimpleNamespace
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
from src.llm_models.model_client.openai_client import (
_OpenAIStreamAccumulator,
_build_reasoning_key,
_default_normal_response_parser,
_sanitize_messages_for_toolless_request,
)
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
from src.llm_models.payload_content.tool_option import ToolCall
@@ -25,3 +33,89 @@ def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_withou
assert len(sanitized_messages) == 1
assert sanitized_messages[0].role == RoleType.User
def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None:
    """A lookalike host (openrouter.ai.example.com) must not activate the native reasoning field."""
    lookalike_provider = APIProvider(
        name="test", base_url="https://openrouter.ai.example.com/api/v1", api_key="test"
    )
    fake_message = SimpleNamespace(content="正式回复", reasoning="推理内容", tool_calls=None)
    fake_response = SimpleNamespace(
        choices=[SimpleNamespace(finish_reason="stop", message=fake_message)],
        usage=None,
        model="openrouter/test-model",
    )
    parsed, usage = _default_normal_response_parser(
        fake_response,
        reasoning_parse_mode=ReasoningParseMode.AUTO,
        tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
        reasoning_key=_build_reasoning_key(lookalike_provider),
    )
    # Domain matching is exact, so the foreign "reasoning" attribute is ignored.
    assert parsed.content == "正式回复"
    assert parsed.reasoning_content is None
    assert usage is None
def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None:
    """Providers on known reasoning domains surface the native ``reasoning`` attribute."""
    for base_url in ("https://openrouter.ai/compatible-api", "https://api.groq.com/openai/v1"):
        provider = APIProvider(name="reasoning-provider", base_url=base_url, api_key="test")
        fake_message = SimpleNamespace(content="正式回复", reasoning="推理内容", tool_calls=None)
        fake_response = SimpleNamespace(
            choices=[SimpleNamespace(finish_reason="stop", message=fake_message)],
            usage=None,
            model="test-model",
        )
        parsed, usage = _default_normal_response_parser(
            fake_response,
            reasoning_parse_mode=ReasoningParseMode.AUTO,
            tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
            reasoning_key=_build_reasoning_key(provider),
        )
        # Both domains map to the "reasoning" key, so the field is picked up.
        assert parsed.content == "正式回复"
        assert parsed.reasoning_content == "推理内容"
        assert usage is None
def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None:
    """Streaming deltas carrying OpenRouter's ``reasoning`` field feed the reasoning buffer."""
    openrouter_provider = APIProvider(
        name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test"
    )
    accumulator = _OpenAIStreamAccumulator(
        reasoning_parse_mode=ReasoningParseMode.AUTO,
        tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
        reasoning_key=_build_reasoning_key(openrouter_provider),
    )
    try:
        # First delta carries only reasoning, second only content.
        accumulator.process_delta(SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None))
        accumulator.process_delta(SimpleNamespace(content="正式回复", tool_calls=None))
        result = accumulator.build_response()
    finally:
        accumulator.close()
    assert result.content == "正式回复"
    assert result.reasoning_content == "流式推理"

View File

@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from urllib.parse import urlparse
from uuid import uuid4
from json_repair import repair_json
@@ -123,6 +124,12 @@ OpenAIStreamResponseHandler = Callable[
OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
"""OpenAI 非流式响应解析函数类型。"""
PROVIDER_REASONING_KEYS_BY_DOMAIN: Dict[str, str] = {
"api.groq.com": "reasoning",
"openrouter.ai": "reasoning",
}
"""按 provider 域名指定的原生推理字段名。"""
def _build_debug_provider_request_filename(model_name: str) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
@@ -156,6 +163,28 @@ def _build_fallback_tool_call_id(prefix: str) -> str:
return f"{normalized_prefix}_{uuid4().hex}"
def _build_reasoning_key(api_provider: APIProvider) -> str:
    """Resolve which native reasoning attribute to read for this provider.

    The hostname is extracted from the provider's base URL and looked up in
    PROVIDER_REASONING_KEYS_BY_DOMAIN; unknown hosts (and blank base URLs)
    fall back to the OpenAI-style "reasoning_content" attribute.
    """
    fallback_key = "reasoning_content"
    base_url = api_provider.base_url.strip()
    if not base_url:
        return fallback_key
    # urlparse only populates the netloc when a scheme separator is present,
    # so prefix scheme-less URLs with "//" before parsing.
    if "://" not in base_url:
        base_url = f"//{base_url}"
    hostname = urlparse(base_url).hostname or ""
    return PROVIDER_REASONING_KEYS_BY_DOMAIN.get(hostname.lower(), fallback_key)
def _extract_reasoning_content(message_part: Any, reasoning_key: str) -> str | None:
"""从 OpenAI 兼容响应对象中读取原生推理内容。
不同兼容服务商对推理字段命名并不完全一致。这里集中处理字段访问,
避免解析路径里散落 provider 特判;具体字段名由 provider 决定。
"""
native_reasoning = getattr(message_part, reasoning_key, None)
if isinstance(native_reasoning, str) and native_reasoning:
return native_reasoning
return None
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
"""将配置中的推理解析模式收敛为枚举值。
@@ -781,15 +810,18 @@ class _OpenAIStreamAccumulator:
self,
reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode,
reasoning_key: str,
) -> None:
"""初始化累积器。
Args:
reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。
reasoning_key: 允许读取的原生推理字段名。
"""
self.reasoning_parse_mode = reasoning_parse_mode
self.tool_argument_parse_mode = tool_argument_parse_mode
self.reasoning_key = reasoning_key
self.reasoning_buffer = io.StringIO()
self.content_buffer = io.StringIO()
self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
@@ -825,8 +857,8 @@ class _OpenAIStreamAccumulator:
Args:
delta: 当前增量对象。
"""
native_reasoning = getattr(delta, "reasoning_content", None)
if isinstance(native_reasoning, str) and native_reasoning:
native_reasoning = _extract_reasoning_content(delta, self.reasoning_key)
if native_reasoning is not None:
self._using_native_reasoning = True
if self.reasoning_parse_mode != ReasoningParseMode.NONE:
self.reasoning_buffer.write(native_reasoning)
@@ -929,6 +961,7 @@ async def _default_stream_response_handler(
*,
reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode,
reasoning_key: str,
) -> Tuple[APIResponse, UsageTuple | None]:
"""处理 OpenAI 兼容流式响应。
@@ -937,6 +970,7 @@ async def _default_stream_response_handler(
interrupt_flag: 外部中断标记。
reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。
reasoning_key: 允许读取的原生推理字段名。
Returns:
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
@@ -944,6 +978,7 @@ async def _default_stream_response_handler(
accumulator = _OpenAIStreamAccumulator(
reasoning_parse_mode=reasoning_parse_mode,
tool_argument_parse_mode=tool_argument_parse_mode,
reasoning_key=reasoning_key,
)
usage_record: UsageTuple | None = None
@@ -977,6 +1012,7 @@ def _default_normal_response_parser(
*,
reasoning_parse_mode: ReasoningParseMode,
tool_argument_parse_mode: ToolArgumentParseMode,
reasoning_key: str,
) -> Tuple[APIResponse, UsageTuple | None]:
"""解析 OpenAI 兼容的非流式响应。
@@ -984,6 +1020,7 @@ def _default_normal_response_parser(
resp: OpenAI SDK 返回的聊天补全响应。
reasoning_parse_mode: 推理内容解析模式。
tool_argument_parse_mode: 工具参数解析模式。
reasoning_key: 允许读取的原生推理字段名。
Returns:
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
@@ -997,10 +1034,10 @@ def _default_normal_response_parser(
api_response = APIResponse()
message_part = choices[0].message
native_reasoning = getattr(message_part, "reasoning_content", None)
native_reasoning = _extract_reasoning_content(message_part, reasoning_key)
message_content = message_part.content if isinstance(message_part.content, str) else None
if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE:
if native_reasoning is not None and reasoning_parse_mode != ReasoningParseMode.NONE:
api_response.reasoning_content = native_reasoning
api_response.content = message_content
elif isinstance(message_content, str) and message_content:
@@ -1046,6 +1083,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
client: AsyncOpenAI
reasoning_parse_mode: ReasoningParseMode
reasoning_key: str
tool_argument_parse_mode: ToolArgumentParseMode
def __init__(self, api_provider: APIProvider) -> None:
@@ -1057,6 +1095,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
super().__init__(api_provider)
client_config = build_openai_compatible_client_config(api_provider)
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
self.reasoning_key = _build_reasoning_key(api_provider)
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
self.client = AsyncOpenAI(
api_key=client_config.api_key,
@@ -1093,6 +1132,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
flag,
reasoning_parse_mode=self.reasoning_parse_mode,
tool_argument_parse_mode=self.tool_argument_parse_mode,
reasoning_key=self.reasoning_key,
)
return default_stream_handler
@@ -1119,6 +1159,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
response,
reasoning_parse_mode=self.reasoning_parse_mode,
tool_argument_parse_mode=self.tool_argument_parse_mode,
reasoning_key=self.reasoning_key,
)
return default_response_parser