@@ -1,4 +1,12 @@
|
|||||||
from src.llm_models.model_client.openai_client import _sanitize_messages_for_toolless_request
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from src.config.model_configs import APIProvider, ReasoningParseMode, ToolArgumentParseMode
|
||||||
|
from src.llm_models.model_client.openai_client import (
|
||||||
|
_OpenAIStreamAccumulator,
|
||||||
|
_build_reasoning_key,
|
||||||
|
_default_normal_response_parser,
|
||||||
|
_sanitize_messages_for_toolless_request,
|
||||||
|
)
|
||||||
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
|
from src.llm_models.payload_content.message import Message, RoleType, TextMessagePart
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
|
|
||||||
@@ -25,3 +33,89 @@ def test_sanitize_messages_for_toolless_request_drops_assistant_tool_call_withou
|
|||||||
|
|
||||||
assert len(sanitized_messages) == 1
|
assert len(sanitized_messages) == 1
|
||||||
assert sanitized_messages[0].role == RoleType.User
|
assert sanitized_messages[0].role == RoleType.User
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_response_parser_ignores_reasoning_field_for_non_openrouter_provider() -> None:
    """Check that a look-alike OpenRouter hostname does not enable ``reasoning``.

    The provider base URL's hostname only starts with ``openrouter.ai``; the
    parsed response is asserted to carry the message content, no reasoning
    content, and no usage record.
    """
    message = SimpleNamespace(content="正式回复", reasoning="推理内容", tool_calls=None)
    choice = SimpleNamespace(finish_reason="stop", message=message)
    response = SimpleNamespace(choices=[choice], usage=None, model="openrouter/test-model")

    lookalike_provider = APIProvider(
        name="test",
        base_url="https://openrouter.ai.example.com/api/v1",
        api_key="test",
    )
    parsed = _default_normal_response_parser(
        response,
        reasoning_parse_mode=ReasoningParseMode.AUTO,
        tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
        reasoning_key=_build_reasoning_key(lookalike_provider),
    )
    api_response, usage_record = parsed

    assert api_response.content == "正式回复"
    assert api_response.reasoning_content is None
    assert usage_record is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_response_parser_reads_provider_reasoning_field_for_reasoning_domains() -> None:
    """Check that listed reasoning domains surface the ``reasoning`` field.

    For each provider URL the same fake response is parsed, and the result is
    asserted to expose the message content, the ``reasoning`` text as
    reasoning content, and no usage record.
    """
    for provider_url in (
        "https://openrouter.ai/compatible-api",
        "https://api.groq.com/openai/v1",
    ):
        message = SimpleNamespace(content="正式回复", reasoning="推理内容", tool_calls=None)
        choice = SimpleNamespace(finish_reason="stop", message=message)
        response = SimpleNamespace(choices=[choice], usage=None, model="test-model")

        provider = APIProvider(name="reasoning-provider", base_url=provider_url, api_key="test")
        parsed = _default_normal_response_parser(
            response,
            reasoning_parse_mode=ReasoningParseMode.AUTO,
            tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
            reasoning_key=_build_reasoning_key(provider),
        )
        api_response, usage_record = parsed

        assert api_response.content == "正式回复"
        assert api_response.reasoning_content == "推理内容"
        assert usage_record is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_accumulator_reads_openrouter_reasoning_delta_field() -> None:
    """Check that streamed ``reasoning`` deltas are accumulated for OpenRouter.

    Two deltas are processed: one carrying only ``reasoning`` text and one
    carrying only ``content``. The built response is asserted to keep them
    apart as reasoning content and content respectively.
    """
    openrouter_provider = APIProvider(
        name="openrouter",
        base_url="https://openrouter.ai/compatible-api",
        api_key="test",
    )
    accumulator = _OpenAIStreamAccumulator(
        reasoning_parse_mode=ReasoningParseMode.AUTO,
        tool_argument_parse_mode=ToolArgumentParseMode.AUTO,
        reasoning_key=_build_reasoning_key(openrouter_provider),
    )
    deltas = (
        SimpleNamespace(reasoning="流式推理", content=None, tool_calls=None),
        SimpleNamespace(content="正式回复", tool_calls=None),
    )
    try:
        for delta in deltas:
            accumulator.process_delta(delta)

        api_response = accumulator.build_response()
    finally:
        # Close the accumulator even when processing or building fails.
        accumulator.close()

    assert api_response.content == "正式回复"
    assert api_response.reasoning_content == "流式推理"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||||
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
@@ -123,6 +124,12 @@ OpenAIStreamResponseHandler = Callable[
|
|||||||
OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
|
OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple | None]]
|
||||||
"""OpenAI 非流式响应解析函数类型。"""
|
"""OpenAI 非流式响应解析函数类型。"""
|
||||||
|
|
||||||
|
# Maps a provider hostname to the name of the response attribute that holds
# that provider's native reasoning text.
PROVIDER_REASONING_KEYS_BY_DOMAIN: Dict[str, str] = {
    "api.groq.com": "reasoning",
    "openrouter.ai": "reasoning",
}
"""按 provider 域名指定的原生推理字段名。"""
|
||||||
|
|
||||||
|
|
||||||
def _build_debug_provider_request_filename(model_name: str) -> str:
|
def _build_debug_provider_request_filename(model_name: str) -> str:
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
@@ -156,6 +163,28 @@ def _build_fallback_tool_call_id(prefix: str) -> str:
|
|||||||
return f"{normalized_prefix}_{uuid4().hex}"
|
return f"{normalized_prefix}_{uuid4().hex}"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_reasoning_key(api_provider: APIProvider) -> str:
    """Return the name of the native reasoning field to read for a provider.

    The hostname of ``api_provider.base_url`` is looked up in
    ``PROVIDER_REASONING_KEYS_BY_DOMAIN`` with an exact, case-insensitive
    match. Any hostname that is not listed, as well as a missing, blank, or
    unparseable base URL, yields ``"reasoning_content"``.

    Args:
        api_provider: Provider configuration whose ``base_url`` is inspected.

    Returns:
        str: Attribute name under which native reasoning text is read.
    """
    default_key = "reasoning_content"

    # Tolerate an unset base URL instead of failing on ``None.strip()``.
    base_url = (api_provider.base_url or "").strip()
    if not base_url:
        return default_key

    # A scheme-less value such as "api.groq.com/openai/v1" would be parsed as
    # a path; prefixing "//" makes urlparse treat the leading part as netloc.
    url_to_parse = base_url if "://" in base_url else f"//{base_url}"
    try:
        hostname = urlparse(url_to_parse).hostname or ""
    except ValueError:
        # urlparse rejects malformed netlocs (e.g. an unterminated IPv6
        # bracket). Such a URL cannot name a known provider domain.
        return default_key

    # A fully-qualified hostname may carry a trailing dot ("openrouter.ai.").
    normalized_hostname = hostname.lower().rstrip(".")
    return PROVIDER_REASONING_KEYS_BY_DOMAIN.get(normalized_hostname, default_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_reasoning_content(message_part: Any, reasoning_key: str) -> str | None:
|
||||||
|
"""从 OpenAI 兼容响应对象中读取原生推理内容。
|
||||||
|
|
||||||
|
不同兼容服务商对推理字段命名并不完全一致。这里集中处理字段访问,
|
||||||
|
避免解析路径里散落 provider 特判;具体字段名由 provider 决定。
|
||||||
|
"""
|
||||||
|
native_reasoning = getattr(message_part, reasoning_key, None)
|
||||||
|
if isinstance(native_reasoning, str) and native_reasoning:
|
||||||
|
return native_reasoning
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
|
def _normalize_reasoning_parse_mode(parse_mode: str | ReasoningParseMode) -> ReasoningParseMode:
|
||||||
"""将配置中的推理解析模式收敛为枚举值。
|
"""将配置中的推理解析模式收敛为枚举值。
|
||||||
|
|
||||||
@@ -781,15 +810,18 @@ class _OpenAIStreamAccumulator:
|
|||||||
self,
|
self,
|
||||||
reasoning_parse_mode: ReasoningParseMode,
|
reasoning_parse_mode: ReasoningParseMode,
|
||||||
tool_argument_parse_mode: ToolArgumentParseMode,
|
tool_argument_parse_mode: ToolArgumentParseMode,
|
||||||
|
reasoning_key: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化累积器。
|
"""初始化累积器。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reasoning_parse_mode: 推理内容解析模式。
|
reasoning_parse_mode: 推理内容解析模式。
|
||||||
tool_argument_parse_mode: 工具参数解析模式。
|
tool_argument_parse_mode: 工具参数解析模式。
|
||||||
|
reasoning_key: 允许读取的原生推理字段名。
|
||||||
"""
|
"""
|
||||||
self.reasoning_parse_mode = reasoning_parse_mode
|
self.reasoning_parse_mode = reasoning_parse_mode
|
||||||
self.tool_argument_parse_mode = tool_argument_parse_mode
|
self.tool_argument_parse_mode = tool_argument_parse_mode
|
||||||
|
self.reasoning_key = reasoning_key
|
||||||
self.reasoning_buffer = io.StringIO()
|
self.reasoning_buffer = io.StringIO()
|
||||||
self.content_buffer = io.StringIO()
|
self.content_buffer = io.StringIO()
|
||||||
self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
|
self.tool_call_states: Dict[int, _StreamedToolCallState] = {}
|
||||||
@@ -825,8 +857,8 @@ class _OpenAIStreamAccumulator:
|
|||||||
Args:
|
Args:
|
||||||
delta: 当前增量对象。
|
delta: 当前增量对象。
|
||||||
"""
|
"""
|
||||||
native_reasoning = getattr(delta, "reasoning_content", None)
|
native_reasoning = _extract_reasoning_content(delta, self.reasoning_key)
|
||||||
if isinstance(native_reasoning, str) and native_reasoning:
|
if native_reasoning is not None:
|
||||||
self._using_native_reasoning = True
|
self._using_native_reasoning = True
|
||||||
if self.reasoning_parse_mode != ReasoningParseMode.NONE:
|
if self.reasoning_parse_mode != ReasoningParseMode.NONE:
|
||||||
self.reasoning_buffer.write(native_reasoning)
|
self.reasoning_buffer.write(native_reasoning)
|
||||||
@@ -929,6 +961,7 @@ async def _default_stream_response_handler(
|
|||||||
*,
|
*,
|
||||||
reasoning_parse_mode: ReasoningParseMode,
|
reasoning_parse_mode: ReasoningParseMode,
|
||||||
tool_argument_parse_mode: ToolArgumentParseMode,
|
tool_argument_parse_mode: ToolArgumentParseMode,
|
||||||
|
reasoning_key: str,
|
||||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||||
"""处理 OpenAI 兼容流式响应。
|
"""处理 OpenAI 兼容流式响应。
|
||||||
|
|
||||||
@@ -937,6 +970,7 @@ async def _default_stream_response_handler(
|
|||||||
interrupt_flag: 外部中断标记。
|
interrupt_flag: 外部中断标记。
|
||||||
reasoning_parse_mode: 推理内容解析模式。
|
reasoning_parse_mode: 推理内容解析模式。
|
||||||
tool_argument_parse_mode: 工具参数解析模式。
|
tool_argument_parse_mode: 工具参数解析模式。
|
||||||
|
reasoning_key: 允许读取的原生推理字段名。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
|
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
|
||||||
@@ -944,6 +978,7 @@ async def _default_stream_response_handler(
|
|||||||
accumulator = _OpenAIStreamAccumulator(
|
accumulator = _OpenAIStreamAccumulator(
|
||||||
reasoning_parse_mode=reasoning_parse_mode,
|
reasoning_parse_mode=reasoning_parse_mode,
|
||||||
tool_argument_parse_mode=tool_argument_parse_mode,
|
tool_argument_parse_mode=tool_argument_parse_mode,
|
||||||
|
reasoning_key=reasoning_key,
|
||||||
)
|
)
|
||||||
usage_record: UsageTuple | None = None
|
usage_record: UsageTuple | None = None
|
||||||
|
|
||||||
@@ -977,6 +1012,7 @@ def _default_normal_response_parser(
|
|||||||
*,
|
*,
|
||||||
reasoning_parse_mode: ReasoningParseMode,
|
reasoning_parse_mode: ReasoningParseMode,
|
||||||
tool_argument_parse_mode: ToolArgumentParseMode,
|
tool_argument_parse_mode: ToolArgumentParseMode,
|
||||||
|
reasoning_key: str,
|
||||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||||
"""解析 OpenAI 兼容的非流式响应。
|
"""解析 OpenAI 兼容的非流式响应。
|
||||||
|
|
||||||
@@ -984,6 +1020,7 @@ def _default_normal_response_parser(
|
|||||||
resp: OpenAI SDK 返回的聊天补全响应。
|
resp: OpenAI SDK 返回的聊天补全响应。
|
||||||
reasoning_parse_mode: 推理内容解析模式。
|
reasoning_parse_mode: 推理内容解析模式。
|
||||||
tool_argument_parse_mode: 工具参数解析模式。
|
tool_argument_parse_mode: 工具参数解析模式。
|
||||||
|
reasoning_key: 允许读取的原生推理字段名。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
|
Tuple[APIResponse, UsageTuple | None]: 解析后的响应与 usage 统计。
|
||||||
@@ -997,10 +1034,10 @@ def _default_normal_response_parser(
|
|||||||
|
|
||||||
api_response = APIResponse()
|
api_response = APIResponse()
|
||||||
message_part = choices[0].message
|
message_part = choices[0].message
|
||||||
native_reasoning = getattr(message_part, "reasoning_content", None)
|
native_reasoning = _extract_reasoning_content(message_part, reasoning_key)
|
||||||
message_content = message_part.content if isinstance(message_part.content, str) else None
|
message_content = message_part.content if isinstance(message_part.content, str) else None
|
||||||
|
|
||||||
if isinstance(native_reasoning, str) and native_reasoning and reasoning_parse_mode != ReasoningParseMode.NONE:
|
if native_reasoning is not None and reasoning_parse_mode != ReasoningParseMode.NONE:
|
||||||
api_response.reasoning_content = native_reasoning
|
api_response.reasoning_content = native_reasoning
|
||||||
api_response.content = message_content
|
api_response.content = message_content
|
||||||
elif isinstance(message_content, str) and message_content:
|
elif isinstance(message_content, str) and message_content:
|
||||||
@@ -1046,6 +1083,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
|
|
||||||
client: AsyncOpenAI
|
client: AsyncOpenAI
|
||||||
reasoning_parse_mode: ReasoningParseMode
|
reasoning_parse_mode: ReasoningParseMode
|
||||||
|
reasoning_key: str
|
||||||
tool_argument_parse_mode: ToolArgumentParseMode
|
tool_argument_parse_mode: ToolArgumentParseMode
|
||||||
|
|
||||||
def __init__(self, api_provider: APIProvider) -> None:
|
def __init__(self, api_provider: APIProvider) -> None:
|
||||||
@@ -1057,6 +1095,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
super().__init__(api_provider)
|
super().__init__(api_provider)
|
||||||
client_config = build_openai_compatible_client_config(api_provider)
|
client_config = build_openai_compatible_client_config(api_provider)
|
||||||
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
|
self.reasoning_parse_mode = _normalize_reasoning_parse_mode(api_provider.reasoning_parse_mode)
|
||||||
|
self.reasoning_key = _build_reasoning_key(api_provider)
|
||||||
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
|
self.tool_argument_parse_mode = _normalize_tool_argument_parse_mode(api_provider.tool_argument_parse_mode)
|
||||||
self.client = AsyncOpenAI(
|
self.client = AsyncOpenAI(
|
||||||
api_key=client_config.api_key,
|
api_key=client_config.api_key,
|
||||||
@@ -1093,6 +1132,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
flag,
|
flag,
|
||||||
reasoning_parse_mode=self.reasoning_parse_mode,
|
reasoning_parse_mode=self.reasoning_parse_mode,
|
||||||
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
||||||
|
reasoning_key=self.reasoning_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return default_stream_handler
|
return default_stream_handler
|
||||||
@@ -1119,6 +1159,7 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
response,
|
response,
|
||||||
reasoning_parse_mode=self.reasoning_parse_mode,
|
reasoning_parse_mode=self.reasoning_parse_mode,
|
||||||
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
tool_argument_parse_mode=self.tool_argument_parse_mode,
|
||||||
|
reasoning_key=self.reasoning_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
return default_response_parser
|
return default_response_parser
|
||||||
|
|||||||
Reference in New Issue
Block a user