Merge pull request #1604 from Soulter/dev
fix(llm): support reasoning field for OpenRouter and Groq
This commit is contained in:
@@ -1,4 +1,12 @@
|
||||
from src.llm_models.model_client.openai_client import _sanitize_messages_for_toolless_request
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
|
||||
from src.llm_models.model_client.openai_client import (
|
||||
_OpenAIStreamAccumulator,
|
||||
_build_reasoning_key,
|
||||
_default_normal_response_parser,
|
||||
_sanitize_messages_for_toolless_request,
|
||||
)
|
||||
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
@@ -25,3 +33,89 @@ def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_withou
|
||||
|
||||
assert len(sanitized_messages) == 1
|
||||
assert sanitized_messages[0].role == RoleType.User
|
||||
|
||||
|
||||
def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None:
|
||||
response = SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="stop",
|
||||
message=SimpleNamespace(
|
||||
content="正式回复",
|
||||
reasoning="推理内容",
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=None,
|
||||
model="openrouter/test-model",
|
||||
)
|
||||
|
||||
api_response, usage_record = _default_normal_response_parser(
|
||||
response,
|
||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
||||
reasoning_key=_build_reasoning_key(
|
||||
APIProvider(name="test", base_url="https://openrouter.ai.example.com/api/v1", api_key="test")
|
||||
),
|
||||
)
|
||||
|
||||
assert api_response.content == "正式回复"
|
||||
assert api_response.reasoning_content is None
|
||||
assert usage_record is None
|
||||
|
||||
|
||||
def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None:
|
||||
provider_urls = [
|
||||
"https://openrouter.ai/compatible-api",
|
||||
"https://api.groq.com/openai/v1",
|
||||
]
|
||||
|
||||
for provider_url in provider_urls:
|
||||
response = SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="stop",
|
||||
message=SimpleNamespace(
|
||||
content="正式回复",
|
||||
reasoning="推理内容",
|
||||
tool_calls=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=None,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
api_response, usage_record = _default_normal_response_parser(
|
||||
response,
|
||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
||||
reasoning_key=_build_reasoning_key(
|
||||
APIProvider(name="reasoning-provider", base_url=provider_url, api_key="test")
|
||||
),
|
||||
)
|
||||
|
||||
assert api_response.content == "正式回复"
|
||||
assert api_response.reasoning_content == "推理内容"
|
||||
assert usage_record is None
|
||||
|
||||
|
||||
def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None:
|
||||
accumulator = _OpenAIStreamAccumulator(
|
||||
reasoning_parse_mode=ReasoningParseMode.AUTO,
|
||||
tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
|
||||
reasoning_key=_build_reasoning_key(
|
||||
APIProvider(name="openrouter", base_url="https://openrouter.ai/compatible-api", api_key="test")
|
||||
),
|
||||
)
|
||||
try:
|
||||
accumulator.process_delta(SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None))
|
||||
accumulator.process_delta(SimpleNamespace(content="正式回复", tool_calls=None))
|
||||
|
||||
api_response = accumulator.build_response()
|
||||
finally:
|
||||
accumulator.close()
|
||||
|
||||
assert api_response.content == "正式回复"
|
||||
assert api_response.reasoning_content == "流式推理"
|
||||
|
||||
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
from json_repair import repair_json
|
||||
@@ -123,6 +124,12 @@ OpenAIStreamResponseHandler = Callable[
|
||||
OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
|
||||
"""OpenAI 非流式响应解析函数类型。"""
|
||||
|
||||
PROVIDER_REASONING_KEYS_BY_DOMAIN: Dict[str, str] = {
|
||||
"api.groq.com": "reasoning",
|
||||
"openrouter.ai": "reasoning",
|
||||
}
|
||||
"""按 provider 域名指定的原生推理字段名。"""
|
||||
|
||||
|
||||
def _build_debug_provider_request_filename(model_name: str) -> str:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
@@ -156,6 +163,28 @@ def _build_fallback_tool_call_id(prefix: str) -> str:
|
||||
return f"{normalized_prefix}_{uuid4().hex}"
|
||||
|
||||
|
||||
def _build_reasoning_key(api_provider: APIProvider) -> str:
|
||||
"""根据 provider 构建可读取的原生推理字段名。"""
|
||||
normalized_base_url = api_provider.base_url.strip()
|
||||
if not normalized_base_url:
|
||||
return "reasoning_content"
|
||||
parsed_url = urlparse(normalized_base_url if "://" in normalized_base_url else f"//{normalized_base_url}")
|
||||
provider_hostname = (parsed_url.hostname or "").lower()
|
||||
return PROVIDER_REASONING_KEYS_BY_DOMAIN.get(provider_hostname, "reasoning_content")
|
||||
|
||||
|
||||
def _extract_reasoning_content(message_part: Any, reasoning_key: str) -> str | None:
|
||||
"""从 OpenAI 兼容响应对象中读取原生推理内容。
|
||||
|
||||
不同兼容服务商对推理字段命名并不完全一致。这里集中处理字段访问,
|
||||
避免解析路径里散落 provider 特判;具体字段名由 provider 决定。
|
||||
"""
|
||||
native_reasoning = getattr(message_part, reasoning_key, None)
|
||||
if isinstance(native_reasoning, str) and native_reasoning:
|
||||
return native_reasoning
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
|
||||
"""将配置中的推理解析模式收敛为枚举值。
|
||||
|
||||
@@ -781,15 +810,18 @@ class _OpenAIStreamAccumulator:
|
||||
self,
|
||||
reasoning_parse_mode: ReasoningParseMode,
|
||||
tool_argument_parse_mode: ToolArgumentParseMode,
|
||||
reasoning_key: str,
|
||||
) -> None:
|
||||
"""初始化累积器。
|
||||
|
||||
Args:
|
||||
reasoning_parse_mode: 推理内容解析模式。
|
||||
tool_argument_parse_mode: 工具参数解析模式。
|
||||
reasoning_key: 允许读取的原生推理字段名。
|
||||
"""
|
||||
self.reasoning_parse_mode = reasoning_parse_mode
|
||||
self.tool_argument_parse_mode = tool_argument_parse_mode
|
||||
self.reasoning_key = reasoning_key
|
||||
self.reasoning_buffer = io.StringIO()
|
||||
self.content_buffer = io.StringIO()
|
||||
self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
|
||||
@@ -825,8 +857,8 @@ class _OpenAIStreamAccumulator:
|
||||
Args:
|
||||
delta: 当前增量对象。
|
||||
"""
|
||||
native_reasoning = getattr(delta, "reasoning_content", None)
|
||||
if isinstance(native_reasoning, str) and native_reasoning:
|
||||
native_reasoning = _extract_reasoning_content(delta, self.reasoning_key)
|
||||
if native_reasoning is not None:
|
||||
self._using_native_reasoning = True
|
||||
if self.reasoning_parse_mode != ReasoningParseMode.NONE:
|
||||
self.reasoning_buffer.write(native_reasoning)
|
||||
@@ -929,6 +961,7 @@ async def _default_stream_response_handler(
|
||||
*,
|
||||
reasoning_parse_mode: ReasoningParseMode,
|
||||
tool_argument_parse_mode: ToolArgumentParseMode,
|
||||
reasoning_key: str,
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""处理 OpenAI 兼容流式响应。
|
||||
|
||||
@@ -937,6 +970,7 @@ async def _default_stream_response_handler(
|
||||
interrupt_flag: 外部中断标记。
|
||||
reasoning_parse_mode: 推理内容解析模式。
|
||||
tool_argument_parse_mode: 工具参数解析模式。
|
||||
reasoning_key: 允许读取的原生推理字段名。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
|
||||
@@ -944,6 +978,7 @@ async def _default_stream_response_handler(
|
||||
accumulator = _OpenAIStreamAccumulator(
|
||||
reasoning_parse_mode=reasoning_parse_mode,
|
||||
tool_argument_parse_mode=tool_argument_parse_mode,
|
||||
reasoning_key=reasoning_key,
|
||||
)
|
||||
usage_record: UsageTuple | None = None
|
||||
|
||||
@@ -977,6 +1012,7 @@ def _default_normal_response_parser(
|
||||
*,
|
||||
reasoning_parse_mode: ReasoningParseMode,
|
||||
tool_argument_parse_mode: ToolArgumentParseMode,
|
||||
reasoning_key: str,
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""解析 OpenAI 兼容的非流式响应。
|
||||
|
||||
@@ -984,6 +1020,7 @@ def _default_normal_response_parser(
|
||||
resp: OpenAI SDK 返回的聊天补全响应。
|
||||
reasoning_parse_mode: 推理内容解析模式。
|
||||
tool_argument_parse_mode: 工具参数解析模式。
|
||||
reasoning_key: 允许读取的原生推理字段名。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
|
||||
@@ -997,10 +1034,10 @@ def _default_normal_response_parser(
|
||||
|
||||
api_response = APIResponse()
|
||||
message_part = choices[0].message
|
||||
native_reasoning = getattr(message_part, "reasoning_content", None)
|
||||
native_reasoning = _extract_reasoning_content(message_part, reasoning_key)
|
||||
message_content = message_part.content if isinstance(message_part.content, str) else None
|
||||
|
||||
if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE:
|
||||
if native_reasoning is not None and reasoning_parse_mode != ReasoningParseMode.NONE:
|
||||
api_response.reasoning_content = native_reasoning
|
||||
api_response.content = message_content
|
||||
elif isinstance(message_content, str) and message_content:
|
||||
@@ -1046,6 +1083,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
|
||||
client: AsyncOpenAI
|
||||
reasoning_parse_mode: ReasoningParseMode
|
||||
reasoning_key: str
|
||||
tool_argument_parse_mode: ToolArgumentParseMode
|
||||
|
||||
def __init__(self, api_provider: APIProvider) -> None:
|
||||
@@ -1057,6 +1095,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
super().__init__(api_provider)
|
||||
client_config = build_openai_compatible_client_config(api_provider)
|
||||
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
|
||||
self.reasoning_key = _build_reasoning_key(api_provider)
|
||||
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=client_config.api_key,
|
||||
@@ -1093,6 +1132,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
flag,
|
||||
reasoning_parse_mode=self.reasoning_parse_mode,
|
||||
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
||||
reasoning_key=self.reasoning_key,
|
||||
)
|
||||
|
||||
return default_stream_handler
|
||||
@@ -1119,6 +1159,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
response,
|
||||
reasoning_parse_mode=self.reasoning_parse_mode,
|
||||
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
||||
reasoning_key=self.reasoning_key,
|
||||
)
|
||||
|
||||
return default_response_parser
|
||||
|
||||
Reference in New Issue
Block a user