# NOTE(review): this hunk arrived with every angle-bracketed token stripped
# (HTML-mangled), which destroyed the XML tag literals and the regex named
# groups. The patterns below are reconstructed from the group names the code
# actually accesses ("body", "name", "arguments", "value") and from the
# common `<tool_call>/<function=...>/<parameter=...>` XML tool-call wire
# format — TODO confirm against the upstream model output before merging.

XML_TOOL_CALL_PATTERN = re.compile(
    r"<tool_call>\s*(?P<body>.*?)\s*</tool_call>",
    re.DOTALL | re.IGNORECASE,
)
"""Fallback matcher for tool calls the model returns as raw XML text.

This is a temporary compatibility shim for responses that embed tool calls
inside chain-of-thought content; it may be adjusted or removed once the
upstream reliably returns a standard ``tool_calls`` field.
"""

XML_FUNCTION_CALL_PATTERN = re.compile(
    r"<function=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<arguments>.*?)\s*</function>",
    re.DOTALL | re.IGNORECASE,
)
"""Extracts the function name and raw argument text from an XML tool-call body."""

XML_PARAMETER_PATTERN = re.compile(
    r"<parameter=(?P<name>[A-Za-z0-9_.-]+)>\s*(?P<value>.*?)\s*</parameter>",
    re.DOTALL | re.IGNORECASE,
)
"""Extracts individual ``<parameter=...>`` entries from an XML argument block."""


def _extract_xml_tool_calls(
    raw_text: str | None,
    parse_mode: ToolArgumentParseMode,
    response: Any,
) -> Tuple[str | None, List[ToolCall] | None]:
    """Best-effort extraction of XML-style tool calls embedded in free text.

    Scans *raw_text* for ``<tool_call>`` blocks, converts each into a
    :class:`ToolCall`, and removes the matched blocks from the text.

    Args:
        raw_text: Model output text that may contain XML tool-call blocks.
        parse_mode: Argument-parsing mode forwarded to ``_parse_tool_arguments``
            when per-parameter parsing yields nothing.
        response: Raw upstream response, passed through for error context.

    Returns:
        A ``(cleaned_text, tool_calls)`` pair. ``cleaned_text`` is the input
        with all recognized tool-call blocks removed (``None`` if nothing
        remains); ``tool_calls`` is ``None`` when no calls were found.
    """
    if not isinstance(raw_text, str) or not raw_text.strip():
        # Nothing to scan: pass the value through untouched.
        return raw_text, None

    tool_calls: List[ToolCall] = []

    def _coerce_xml_parameter_value(raw_value: str) -> Any:
        """Coerce an XML parameter's text into a Python value.

        Recognizes booleans and null literals; attempts JSON repair for
        object/array-looking values; otherwise returns the stripped string.
        """
        normalized_value = raw_value.strip()
        if not normalized_value:
            return ""
        lowered_value = normalized_value.lower()
        if lowered_value == "true":
            return True
        if lowered_value == "false":
            return False
        if lowered_value in {"null", "none"}:
            return None
        if normalized_value.startswith(("{", "[")):
            # Looks like embedded JSON — repair leniently, fall back to the
            # raw string rather than failing the whole tool call.
            try:
                return repair_json(normalized_value, return_objects=True, logging=False)
            except Exception:
                return normalized_value
        return normalized_value

    def _parse_xml_parameters(raw_arguments: str) -> Dict[str, Any] | None:
        """Collect ``<parameter=...>`` entries into a dict (None if none found)."""
        parameters = {
            match.group("name").strip(): _coerce_xml_parameter_value(match.group("value"))
            for match in XML_PARAMETER_PATTERN.finditer(raw_arguments)
        }
        return parameters or None

    def _replace_tool_call(match: re.Match[str]) -> str:
        """Record one matched tool-call block and erase it from the text."""
        body = match.group("body")
        function_match = XML_FUNCTION_CALL_PATTERN.search(body)
        if function_match is None:
            # Not a recognizable function block — leave the text as-is.
            return match.group(0)

        function_name = function_match.group("name").strip()
        raw_arguments = function_match.group("arguments").strip()
        arguments = _parse_xml_parameters(raw_arguments)
        if arguments is None:
            # No per-parameter tags: fall back to the generic argument parser
            # (e.g. a raw JSON payload), or an empty dict for an empty body.
            arguments = _parse_tool_arguments(raw_arguments, parse_mode, response) if raw_arguments else {}
        tool_calls.append(
            ToolCall(
                # Synthetic, locally-unique id since the upstream gave none.
                call_id=f"xml_tool_call_{len(tool_calls) + 1}",
                func_name=function_name,
                args=arguments,
            )
        )
        return ""

    cleaned_text = XML_TOOL_CALL_PATTERN.sub(_replace_tool_call, raw_text).strip() or None
    return cleaned_text, tool_calls or None


def _apply_xml_tool_call_fallback(
    response: APIResponse,
    parse_mode: ToolArgumentParseMode,
    raw_response: Any,
) -> None:
    """Recover tool calls from XML text when no standard ``tool_calls`` came back.

    Temporary workaround for responses that mix tool calls into the
    chain-of-thought; may change once the model/upstream API is normalized.
    Mutates *response* in place: checks ``reasoning_content`` first, then
    ``content``. If calls are found in the reasoning and ``content`` is empty,
    the remaining reasoning text is promoted to ``content``.
    """
    if response.tool_calls:
        # Standard field present — nothing to do.
        return

    reasoning_content, tool_calls = _extract_xml_tool_calls(response.reasoning_content, parse_mode, raw_response)
    if reasoning_content != response.reasoning_content:
        response.reasoning_content = reasoning_content
    if tool_calls:
        response.tool_calls = tool_calls
        if not response.content and reasoning_content:
            # Promote leftover reasoning so the response still has content.
            response.content = reasoning_content
            response.reasoning_content = None
        logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")
        return

    cleaned_content, tool_calls = _extract_xml_tool_calls(response.content, parse_mode, raw_response)
    if cleaned_content != response.content:
        response.content = cleaned_content
    if tool_calls:
        response.tool_calls = tool_calls
        logger.warning("OpenAI 兼容响应未返回标准 tool_calls,已从 XML 文本兜底解析工具调用")