fix: override plugin tool stream context and remove repository tests

- force plugin tools to use runtime stream_id/chat_id from execution context
- remove repository test assets and vitest config
- document that temporary test files must be deleted after use
This commit is contained in:
Losita
2026-05-12 22:36:32 +08:00
parent 702316ae57
commit 8d0f6d4401
98 changed files with 4 additions and 30458 deletions
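The 4 added lines are not shown on this page; the diffs below only cover the deleted files. As a rough illustration of the first bullet, here is a minimal sketch of the override, assuming hypothetical names (ToolExecutionContext, BasePluginTool) that stand in for the repository's actual classes:

# Minimal sketch of the stream-context override described above.
# ToolExecutionContext and BasePluginTool are illustrative assumptions,
# not this repository's actual API.
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class ToolExecutionContext:
    stream_id: str
    chat_id: str

class BasePluginTool:
    def __init__(self, name: str, stream_id: str = "", chat_id: str = ""):
        self.name = name
        # Captured at registration time; possibly stale by the time the tool runs.
        self.stream_id = stream_id
        self.chat_id = chat_id

    async def execute(self, params: Dict[str, Any], context: ToolExecutionContext) -> Dict[str, Any]:
        # Force the runtime identifiers from the execution context to win
        # over whatever was cached on the tool instance.
        self.stream_id = context.stream_id
        self.chat_id = context.chat_id
        return {"tool": self.name, "stream_id": self.stream_id, "chat_id": self.chat_id}

With such an override, a tool instance reused across chats reports the stream it is actually executing in rather than the one it was registered under.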

View File

@@ -1,459 +0,0 @@
import argparse
import asyncio
import os
import sys
import time
import json
import importlib
from dataclasses import dataclass
from typing import Optional, Dict, Any
from datetime import datetime
# Force UTF-8 to avoid console encoding errors
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
# Ensure src.* is importable
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo
try:
from maim_message import ChatStream, UserInfo, GroupInfo
except Exception:
@dataclass
class ChatStream:
stream_id: str
platform: str
user_info: UserInfo
group_info: GroupInfo
logger = get_logger("test_memory_retrieval")
# Import dynamically via importlib to avoid circular-import problems
def _import_memory_retrieval():
    """Dynamically import the memory_retrieval module via importlib, avoiding circular imports."""
    try:
        # Import prompt_builder first and check whether the prompt is already initialized
        from src.chat.utils.prompt_builder import global_prompt_manager
        # Check whether the memory_retrieval prompts are already registered;
        # if they are, the module was probably already initialized through another path
        prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
        module_name = "src.memory_system.memory_retrieval"
        # If the prompt is initialized, try to reuse the already-loaded module
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, "init_memory_retrieval_prompt"):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
existing_module._process_single_question,
)
        # If the module is in sys.modules but only partially initialized, remove it first
        if module_name in sys.modules:
            existing_module = sys.modules[module_name]
            if not hasattr(existing_module, "init_memory_retrieval_prompt"):
                # The module is partially initialized; remove it
                logger.warning(f"Detected partially initialized module {module_name}; retrying the import")
                del sys.modules[module_name]
        # Clean up related modules that may be partially initialized
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key)
for key in keys_to_remove:
try:
del sys.modules[key]
except KeyError:
pass
        # Before importing memory_retrieval, make sure every module that could trigger a
        # circular import has been fully loaded, since importing them may itself import memory_retrieval
        try:
            # Import the modules that may trigger circular imports first so they finish initializing
            import src.config.config
            import src.chat.utils.prompt_builder
            # Try importing the modules that may import memory_retrieval at module level;
            # if they are already imported, make sure they are fully initialized
            try:
                import src.chat.replyer.group_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue if the import fails
            try:
                import src.chat.replyer.private_generator  # noqa: F401
            except (ImportError, AttributeError):
                pass  # continue if the import fails
        except Exception as e:
            logger.warning(f"Warning while preloading dependency modules: {e}")
        # Now try to import memory_retrieval. If this still triggers a circular import,
        # some other module is importing memory_retrieval at module level.
memory_retrieval_module = importlib.import_module(module_name)
return (
memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question,
memory_retrieval_module._process_single_question,
)
except (ImportError, AttributeError) as e:
logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
raise
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
"""创建一个测试用的 ChatStream 对象"""
user_info = UserInfo(
platform="test",
user_id="test_user",
        user_nickname="Test User",
)
group_info = GroupInfo(
platform="test",
group_id="test_group",
        group_name="Test Group",
)
return ChatStream(
stream_id=chat_id,
platform="test",
user_info=user_info,
group_info=group_info,
)
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况
Args:
start_time: 开始时间戳
Returns:
包含token使用统计的字典
"""
try:
start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录
records = (
LLMUsage.select()
.where(
(LLMUsage.timestamp >= start_datetime)
& (
(LLMUsage.request_type.like("%memory%"))
| (LLMUsage.request_type == "memory.question")
| (LLMUsage.request_type == "memory.react")
| (LLMUsage.request_type == "memory.react.final")
)
)
.order_by(LLMUsage.timestamp.asc())
)
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
total_cost = 0.0
request_count = 0
        model_usage = {}  # per-model statistics
for record in records:
total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0
request_count += 1
            # Per-model statistics
model_name = record.model_name or "unknown"
if model_name not in model_usage:
model_usage[model_name] = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"cost": 0.0,
"request_count": 0,
}
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1
return {
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
"total_tokens": total_tokens,
"total_cost": total_cost,
"request_count": request_count,
"model_usage": model_usage,
}
except Exception as e:
logger.error(f"获取token使用情况失败: {e}")
return {
"total_prompt_tokens": 0,
"total_completion_tokens": 0,
"total_tokens": 0,
"total_cost": 0.0,
"request_count": 0,
"model_usage": {},
}
def format_thinking_steps(thinking_steps: list) -> str:
"""格式化思考步骤为可读字符串"""
if not thinking_steps:
return "无思考步骤"
lines = []
for step in thinking_steps:
iteration = step.get("iteration", "?")
thought = step.get("thought", "")
actions = step.get("actions", [])
observations = step.get("observations", [])
lines.append(f"\n--- 迭代 {iteration} ---")
if thought:
lines.append(f"思考: {thought[:200]}...")
if actions:
lines.append("行动:")
for action in actions:
action_type = action.get("action_type", "unknown")
action_params = action.get("action_params", {})
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
if observations:
lines.append("观察:")
for obs in observations:
obs_str = str(obs)[:200]
if len(str(obs)) > 200:
obs_str += "..."
lines.append(f" - {obs_str}")
return "\n".join(lines)
async def test_memory_retrieval(
question: str,
chat_id: str = "test_memory_retrieval",
context: str = "",
max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
"""测试记忆检索功能
Args:
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
max_iterations: 最大迭代次数
Returns:
包含测试结果的字典
"""
print("\n" + "=" * 80)
print("[测试] 记忆检索测试")
print(f"[问题] {question}")
print("=" * 80)
# 记录开始时间
start_time = time.time()
# 延迟导入并初始化记忆检索prompt这会自动加载 global_config
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
try:
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
        # Check whether the prompt is already initialized, to avoid double initialization
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt()
else:
logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
except Exception as e:
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
raise
# 获取 global_config此时应该已经加载
from src.config.config import global_config
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
timeout = global_config.memory.agent_timeout_seconds
print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}")
# 执行检索
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=max_iterations,
timeout=timeout,
initial_info="",
)
    # Record the end time
    end_time = time.time()
    elapsed_time = end_time - start_time
    # Get token usage
    token_usage = get_token_usage_since(start_time)
    # Build the result
result = {
"question": question,
"found_answer": found_answer,
"answer": answer,
"is_timeout": is_timeout,
"elapsed_time": elapsed_time,
"thinking_steps": thinking_steps,
"iteration_count": len(thinking_steps),
"token_usage": token_usage,
}
    # Print the results
    print(f"\n[Retrieval finished] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
    print("\n[Result]")
    print(f"  Found answer: {'yes' if found_answer else 'no'}")
    if found_answer and answer:
        print(f"  Answer: {answer}")
    else:
        print("  Answer: (no answer found)")
    print(f"  Timed out: {'yes' if is_timeout else 'no'}")
    print(f"  Iterations: {len(thinking_steps)}")
    print(f"  Total time: {elapsed_time:.2f}s")
    print("\n[Token usage]")
    print(f"  Total requests: {token_usage['request_count']}")
    print(f"  Total prompt tokens: {token_usage['total_prompt_tokens']:,}")
    print(f"  Total completion tokens: {token_usage['total_completion_tokens']:,}")
    print(f"  Total tokens: {token_usage['total_tokens']:,}")
    print(f"  Total cost: ${token_usage['total_cost']:.6f}")
    if token_usage["model_usage"]:
        print("\n[Per-model statistics]")
        for model_name, usage in token_usage["model_usage"].items():
            print(f"  {model_name}:")
            print(f"    Requests: {usage['request_count']}")
            print(f"    Prompt tokens: {usage['prompt_tokens']:,}")
            print(f"    Completion tokens: {usage['completion_tokens']:,}")
            print(f"    Total tokens: {usage['total_tokens']:,}")
            print(f"    Cost: ${usage['cost']:.6f}")
    print("\n[Iteration details]")
    print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80)
return result
def main() -> None:
parser = argparse.ArgumentParser(
description="测试记忆检索功能。可以输入一个问题脚本会使用记忆检索的逻辑进行检索并记录迭代信息、时间和token总消耗。"
)
parser.add_argument(
"--chat-id",
default="test_memory_retrieval",
help="测试用的聊天ID默认: test_memory_retrieval",
)
parser.add_argument(
"--context",
default="",
help="上下文信息(可选)",
)
parser.add_argument(
"--output",
"-o",
help="将结果保存到JSON文件可选",
)
args = parser.parse_args()
    # Initialize logging (low verbosity to avoid excessive log output)
    initialize_logging(verbose=False)
    # Read the question interactively
    print("\n" + "=" * 80)
    print("Memory Retrieval Test Tool")
    print("=" * 80)
    question = input("\nEnter the question to query: ").strip()
    if not question:
        print("Error: the question must not be empty")
return
    # Read the maximum iteration count interactively
    max_iterations_input = input("\nEnter the maximum number of iterations (press Enter for the configured default): ").strip()
    max_iterations = None
    if max_iterations_input:
        try:
            max_iterations = int(max_iterations_input)
            if max_iterations <= 0:
                print("Warning: the iteration count must be greater than 0; using the configured default")
                max_iterations = None
        except ValueError:
            print("Warning: invalid iteration count; using the configured default")
max_iterations = None
    # Connect to the database
    try:
        db.connect(reuse_if_open=True)
    except Exception as e:
        logger.error(f"Database connection failed: {e}")
        print(f"Error: database connection failed: {e}")
return
    # Run the test
try:
result = asyncio.run(
test_memory_retrieval(
question=question,
chat_id=args.chat_id,
context=args.context,
max_iterations=max_iterations,
)
)
        # Save the results if an output file was specified
        if args.output:
            # Convert thinking_steps into a serializable format
            output_result = result.copy()
            with open(args.output, "w", encoding="utf-8") as f:
                json.dump(output_result, f, ensure_ascii=False, indent=2)
            print(f"\n[Results saved] {args.output}")
    except KeyboardInterrupt:
        print("\n\n[Interrupted] Test interrupted by user")
    except Exception as e:
        logger.error(f"Test failed: {e}", exc_info=True)
        print(f"\n[Error] Test failed: {e}")
finally:
try:
db.close()
except Exception:
pass
if __name__ == "__main__":
main()

View File

@@ -1,845 +0,0 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
from src.config.config import config_manager # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402
from src.services.llm_service import generate # noqa: E402
from src.services.service_task_resolver import get_available_models # noqa: E402
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ToolCallCase:
"""Tool call 参数测试用例。"""
name: str
description: str
tool_definition: Dict[str, Any]
expected_arguments: Dict[str, Any]
@property
def tool_name(self) -> str:
"""返回工具名称。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
return str(function_definition.get("name", "") or "")
return str(self.tool_definition.get("name", "") or "")
@property
def parameters_schema(self) -> Dict[str, Any]:
"""返回参数 Schema。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
parameters = function_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
parameters = self.tool_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
def build_messages(self) -> List[Dict[str, Any]]:
"""构造测试消息。"""
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
system_prompt = (
"你正在执行严格的工具调用参数兼容性测试。"
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
)
user_prompt = (
f"请立刻调用工具 `{self.tool_name}`。\n"
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
"不要输出任何解释文本,只返回工具调用。\n"
f"{expected_json}"
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
@dataclass(slots=True)
class ProbeTarget:
"""单个待测试模型目标。"""
task_name: str
model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
@dataclass(slots=True)
class ProbeResult:
"""单次测试结果。"""
task_name: str
target_model_name: str
actual_model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
case_name: str
attempt: int
success: bool
elapsed_seconds: float
errors: List[str]
warnings: List[str]
response_text: str
reasoning_text: str
tool_calls: List[Dict[str, Any]]
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切到 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool 定义。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_default_cases() -> List[ToolCallCase]:
"""构造默认测试用例。"""
simple_expected_arguments = {
"request_id": "probe-simple-001",
"count": 7,
"enabled": True,
"mode": "strict",
"ratio": 2.5,
}
simple_parameters = {
"type": "object",
"properties": {
"request_id": {"type": "string", "description": "请求 ID"},
"count": {"type": "integer", "description": "数量"},
"enabled": {"type": "boolean", "description": "是否启用"},
"mode": {
"type": "string",
"description": "模式",
"enum": ["strict", "loose"],
},
"ratio": {"type": "number", "description": "比例"},
},
"required": ["request_id", "count", "enabled", "mode", "ratio"],
"additionalProperties": False,
}
nested_expected_arguments = {
"request_id": "probe-nested-001",
"notify": False,
"profile": {
"channel": "stable",
"priority": 2,
},
"tags": ["alpha", "beta", "gamma"],
"items": [
{"count": 2, "name": "apple"},
{"count": 5, "name": "banana"},
],
}
nested_parameters = {
"type": "object",
"properties": {
"request_id": {"type": "string", "description": "请求 ID"},
"notify": {"type": "boolean", "description": "是否通知"},
"profile": {
"type": "object",
"description": "配置对象",
"properties": {
"channel": {"type": "string", "description": "渠道"},
"priority": {"type": "integer", "description": "优先级"},
},
"required": ["channel", "priority"],
"additionalProperties": False,
},
"tags": {
"type": "array",
"description": "标签列表",
"items": {"type": "string"},
},
"items": {
"type": "array",
"description": "条目列表",
"items": {
"type": "object",
"properties": {
"count": {"type": "integer", "description": "数量"},
"name": {"type": "string", "description": "名称"},
},
"required": ["count", "name"],
"additionalProperties": False,
},
},
},
"required": ["request_id", "notify", "profile", "tags", "items"],
"additionalProperties": False,
}
return [
ToolCallCase(
name="simple",
description="标量参数类型校验",
tool_definition=_build_function_tool(
name="record_simple_probe",
description="记录简单参数探测结果",
parameters=simple_parameters,
),
expected_arguments=simple_expected_arguments,
),
ToolCallCase(
name="nested",
description="嵌套对象与数组参数校验",
tool_definition=_build_function_tool(
name="record_nested_probe",
description="记录嵌套参数探测结果",
parameters=nested_parameters,
),
expected_arguments=nested_expected_arguments,
),
]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
"""构造模型名称到模型配置的映射。"""
return {model.name: model for model in config_manager.get_model_config().models}
def _build_provider_map() -> Dict[str, APIProvider]:
"""构造 Provider 名称到配置的映射。"""
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
"""根据命令行参数解析待测试目标。"""
available_tasks = get_available_models()
model_map = _build_model_map()
provider_map = _build_provider_map()
if not available_tasks:
raise ValueError("未找到任何可用的模型任务配置")
if task_filters:
selected_task_names = []
for task_name in task_filters:
if task_name not in available_tasks:
raise ValueError(f"未找到任务 `{task_name}`")
selected_task_names.append(task_name)
else:
selected_task_names = [
task_name
for task_name in available_tasks
if task_name not in DEFAULT_SKIP_TASKS
]
if not selected_task_names:
raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
resolved_targets: List[ProbeTarget] = []
seen_models: set[str] = set()
if model_filters:
model_names = list(model_filters)
else:
model_names = []
for task_name in selected_task_names:
task_config = available_tasks[task_name]
for model_name in task_config.model_list:
if model_name not in model_names:
model_names.append(model_name)
for model_name in model_names:
if model_name in seen_models:
continue
if model_name not in model_map:
raise ValueError(f"未找到模型 `{model_name}`")
target_task_name = ""
for task_name in selected_task_names:
if model_name in available_tasks[task_name].model_list:
target_task_name = task_name
break
if not target_task_name:
target_task_name = default_task_name
model_info = model_map[model_name]
provider_info = provider_map[model_info.api_provider]
resolved_targets.append(
ProbeTarget(
task_name=target_task_name,
model_name=model_name,
provider_name=provider_info.name,
client_type=provider_info.client_type,
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
)
)
seen_models.add(model_name)
return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
"""临时将某个任务锁定到单模型。"""
model_task_config = config_manager.get_model_config().model_task_config
task_config = getattr(model_task_config, task_name, None)
if not isinstance(task_config, TaskConfig):
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
original_model_list = list(task_config.model_list)
original_selection_strategy = task_config.selection_strategy
task_config.model_list = [model_name]
task_config.selection_strategy = "balance"
try:
yield
finally:
task_config.model_list = original_model_list
task_config.selection_strategy = original_selection_strategy
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
"""序列化工具调用结果。"""
if not tool_calls:
return []
return [
{
"id": tool_call.call_id,
"function": {
"name": tool_call.func_name,
"arguments": dict(tool_call.args or {}),
},
}
for tool_call in tool_calls
]
def _is_integer_value(value: Any) -> bool:
"""判断是否为整数类型且排除布尔值。"""
return isinstance(value, int) and not isinstance(value, bool)
def _is_number_value(value: Any) -> bool:
"""判断是否为数值类型且排除布尔值。"""
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
def _schema_type(schema: Dict[str, Any]) -> str:
"""解析 Schema 的类型。"""
schema_type = str(schema.get("type", "") or "").strip()
if schema_type:
return schema_type
if "properties" in schema or "required" in schema:
return "object"
return ""
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
"""按简化 JSON Schema 校验工具参数。"""
errors: List[str] = []
schema_type = _schema_type(schema)
if "enum" in schema and actual_value not in schema["enum"]:
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
if schema_type == "string":
if not isinstance(actual_value, str):
errors.append(f"{path} 类型错误,期望 string实际为 {type(actual_value).__name__}")
return errors
if schema_type == "integer":
if not _is_integer_value(actual_value):
errors.append(f"{path} 类型错误,期望 integer实际为 {type(actual_value).__name__}")
return errors
if schema_type == "number":
if not _is_number_value(actual_value):
errors.append(f"{path} 类型错误,期望 number实际为 {type(actual_value).__name__}")
return errors
if schema_type == "boolean":
if not isinstance(actual_value, bool):
errors.append(f"{path} 类型错误,期望 boolean实际为 {type(actual_value).__name__}")
return errors
if schema_type == "array":
if not isinstance(actual_value, list):
errors.append(f"{path} 类型错误,期望 array实际为 {type(actual_value).__name__}")
return errors
item_schema = schema.get("items")
if isinstance(item_schema, dict):
for index, item in enumerate(actual_value):
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
return errors
if schema_type == "object":
if not isinstance(actual_value, dict):
errors.append(f"{path} 类型错误,期望 object实际为 {type(actual_value).__name__}")
return errors
properties = schema.get("properties", {})
required_fields = [str(item) for item in schema.get("required", [])]
for required_field in required_fields:
if required_field not in actual_value:
errors.append(f"{path}.{required_field} 缺少必填字段")
for field_name, field_value in actual_value.items():
field_path = f"{path}.{field_name}"
field_schema = properties.get(field_name)
if isinstance(field_schema, dict):
errors.extend(_validate_schema(field_schema, field_value, field_path))
continue
additional_properties = schema.get("additionalProperties", True)
if additional_properties is False:
errors.append(f"{field_path} 是未定义字段")
elif isinstance(additional_properties, dict):
errors.extend(_validate_schema(additional_properties, field_value, field_path))
return errors
return errors
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
"""递归比较实际值与期望值是否完全一致。"""
errors: List[str] = []
if isinstance(expected_value, dict):
if not isinstance(actual_value, dict):
return [f"{path} 值不一致,期望 object实际为 {type(actual_value).__name__}"]
expected_keys = set(expected_value.keys())
actual_keys = set(actual_value.keys())
for missing_key in sorted(expected_keys - actual_keys):
errors.append(f"{path}.{missing_key} 缺少期望字段")
for extra_key in sorted(actual_keys - expected_keys):
errors.append(f"{path}.{extra_key} 出现了额外字段")
for shared_key in sorted(expected_keys & actual_keys):
errors.extend(
_compare_expected_values(
expected_value[shared_key],
actual_value[shared_key],
f"{path}.{shared_key}",
)
)
return errors
if isinstance(expected_value, list):
if not isinstance(actual_value, list):
return [f"{path} 值不一致,期望 array实际为 {type(actual_value).__name__}"]
if len(expected_value) != len(actual_value):
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
for index, (expected_item, actual_item) in enumerate(
zip(expected_value, actual_value, strict=False)
):
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
return errors
if isinstance(expected_value, bool):
if not isinstance(actual_value, bool) or actual_value is not expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if _is_integer_value(expected_value):
if not _is_integer_value(actual_value) or actual_value != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if isinstance(expected_value, float):
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if expected_value != actual_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
"""优先选择同名工具调用,否则回退到第一条。"""
for tool_call in tool_calls:
if tool_call.func_name == expected_tool_name:
return tool_call
return tool_calls[0]
def _validate_service_result(
service_result: LLMServiceResult,
target: ProbeTarget,
case: ToolCallCase,
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
"""校验服务层返回结果。"""
errors: List[str] = []
warnings: List[str] = []
completion = service_result.completion
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
if not service_result.success:
errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
return errors, warnings, serialized_tool_calls
if completion.model_name and completion.model_name != target.model_name:
errors.append(
f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
)
tool_calls = completion.tool_calls or []
if not tool_calls:
errors.append("模型未返回 tool_calls")
if completion.response.strip():
warnings.append("模型返回了自然语言文本而不是工具调用")
return errors, warnings, serialized_tool_calls
if len(tool_calls) != 1:
errors.append(f"返回了 {len(tool_calls)} 个 tool_calls预期为 1 个")
selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
if selected_tool_call.func_name != case.tool_name:
errors.append(
f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
)
actual_arguments = selected_tool_call.args
if not isinstance(actual_arguments, dict):
errors.append("工具参数未被解析为对象")
return errors, warnings, serialized_tool_calls
errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))
if completion.response.strip():
warnings.append("模型同时返回了自然语言文本")
return errors, warnings, serialized_tool_calls
async def _run_single_probe(
target: ProbeTarget,
case: ToolCallCase,
attempt: int,
max_tokens: int,
temperature: float,
) -> ProbeResult:
"""执行单次工具调用参数探测。"""
request = LLMServiceRequest(
task_name=target.task_name,
request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
prompt=case.build_messages(),
tool_options=[case.tool_definition],
temperature=temperature,
max_tokens=max_tokens,
)
started_at = time.perf_counter()
with _pin_task_to_model(target.task_name, target.model_name):
service_result = await generate(request)
elapsed_seconds = time.perf_counter() - started_at
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
completion = service_result.completion
return ProbeResult(
task_name=target.task_name,
target_model_name=target.model_name,
actual_model_name=completion.model_name,
provider_name=target.provider_name,
client_type=target.client_type,
tool_argument_parse_mode=target.tool_argument_parse_mode,
case_name=case.name,
attempt=attempt,
success=not errors,
elapsed_seconds=elapsed_seconds,
errors=errors,
warnings=warnings,
response_text=completion.response,
reasoning_text=completion.reasoning,
tool_calls=serialized_tool_calls,
)
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
"""打印待测试目标。"""
print("待测试目标:")
for index, target in enumerate(targets, start=1):
print(
f"{index}. model={target.model_name} | task={target.task_name} | "
f"provider={target.provider_name} | client={target.client_type} | "
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
)
def _print_available_targets() -> None:
"""打印当前可用任务与模型。"""
available_tasks = get_available_models()
model_map = _build_model_map()
task_names = list(available_tasks.keys())
print("当前可用任务:")
for task_name in task_names:
task_config = available_tasks[task_name]
print(f"- {task_name}: {list(task_config.model_list)}")
referenced_models = {
model_name
for task_config in available_tasks.values()
for model_name in task_config.model_list
}
print("\n当前配置中的模型:")
for model_name, model_info in model_map.items():
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
print(
f"- {model_name}: provider={model_info.api_provider}, "
f"identifier={model_info.model_identifier}, {referenced_mark}"
)
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
"""根据参数筛选测试用例。"""
all_cases = {case.name: case for case in _build_default_cases()}
if not case_filters:
return list(all_cases.values())
selected_cases: List[ToolCallCase] = []
for case_name in case_filters:
if case_name not in all_cases:
raise ValueError(f"未知测试用例 `{case_name}`,可选值: {', '.join(sorted(all_cases))}")
selected_cases.append(all_cases[case_name])
return selected_cases
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
"""打印单次结果。"""
status_text = "PASS" if result.success else "FAIL"
print(
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
)
if result.errors:
for error in result.errors:
print(f" ERROR: {error}")
if result.warnings:
for warning in result.warnings:
print(f" WARN: {warning}")
if result.tool_calls:
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
if show_response and result.response_text.strip():
print(f" response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
"""构造结果摘要。"""
total_count = len(results)
passed_count = sum(1 for result in results if result.success)
failed_count = total_count - passed_count
failed_items = [
{
"model_name": result.target_model_name,
"case_name": result.case_name,
"attempt": result.attempt,
"errors": list(result.errors),
}
for result in results
if not result.success
]
return {
"total": total_count,
"passed": passed_count,
"failed": failed_count,
"failed_items": failed_items,
}
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
"""将测试结果写入 JSON 文件。"""
output_path = Path(json_out).expanduser().resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
"summary": _build_summary(results),
"results": [asdict(result) for result in results],
}
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"\n结果已写入: {output_path}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
"""执行所有探测请求。"""
task_filters = _parse_multi_value_args(args.task)
model_filters = _parse_multi_value_args(args.model)
case_filters = _parse_multi_value_args(args.case)
selected_cases = _select_cases(case_filters)
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
_print_targets(targets)
print("")
results: List[ProbeResult] = []
for target in targets:
for attempt in range(1, args.repeat + 1):
for case in selected_cases:
print(
f"开始测试: model={target.model_name}, task={target.task_name}, "
f"case={case.name}, attempt={attempt}"
)
result = await _run_single_probe(
target=target,
case=case,
attempt=attempt,
max_tokens=args.max_tokens,
temperature=args.temperature,
)
_print_single_result(result, args.show_response)
print("")
results.append(result)
return results
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
"默认会测试所有非 voice / embedding 任务中引用到的模型。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
)
parser.add_argument(
"--case",
action="append",
help="指定测试用例名,可选 simple、nested不传则运行全部默认用例",
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个用例重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0 以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的自然语言文本内容",
)
return parser
def main() -> int:
"""脚本入口。"""
_ensure_utf8_console()
parser = _build_parser()
args = parser.parse_args()
if args.repeat < 1:
parser.error("--repeat 必须大于等于 1")
if args.max_tokens < 1:
parser.error("--max-tokens 必须大于等于 1")
if args.list_targets:
_print_available_targets()
return 0
results = asyncio.run(_run_probes(args))
summary = _build_summary(results)
print("测试摘要:")
print(
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
)
if summary["failed_items"]:
print("失败明细:")
for failed_item in summary["failed_items"]:
print(
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
)
if args.json_out:
_write_json_report(args.json_out, results)
return 0 if summary["failed"] == 0 else 1
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -1,777 +0,0 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(0, str(PROJECT_ROOT))
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
from src.config.config import config_manager # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
from src.services.llm_service import generate # noqa: E402
from src.services.service_task_resolver import get_available_models # noqa: E402
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ProbeTarget:
"""单个待测试模型目标。"""
task_name: str
model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
@dataclass(slots=True)
class ToolCallScenario:
"""工具调用 API 场景定义。"""
name: str
description: str
prompt: List[Dict[str, Any]]
tool_options: List[Dict[str, Any]] | None = None
expect_tool_calls: bool | None = None
@dataclass(slots=True)
class ProbeResult:
"""单次 API 探测结果。"""
task_name: str
target_model_name: str
actual_model_name: str
provider_name: str
client_type: str
tool_argument_parse_mode: str
case_name: str
attempt: int
success: bool
elapsed_seconds: float
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
response_text: str = ""
reasoning_text: str = ""
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切换为 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_probe_tools() -> List[Dict[str, Any]]:
"""构造通用测试工具。"""
weather_tool = _build_function_tool(
name="lookup_weather",
description="查询指定城市天气。",
parameters={
"type": "object",
"properties": {
"city": {"type": "string", "description": "城市名"},
"unit": {
"type": "string",
"description": "温度单位",
"enum": ["celsius", "fahrenheit"],
},
"include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
},
"required": ["city", "unit", "include_forecast"],
"additionalProperties": False,
},
)
search_tool = _build_function_tool(
name="search_docs",
description="搜索内部知识库。",
parameters={
"type": "object",
"properties": {
"query": {"type": "string", "description": "搜索关键词"},
"top_k": {"type": "integer", "description": "返回条数"},
"filters": {
"type": "object",
"description": "过滤条件",
"properties": {
"scope": {"type": "string", "description": "搜索范围"},
"tag": {"type": "string", "description": "标签"},
},
"required": ["scope", "tag"],
"additionalProperties": False,
},
},
"required": ["query", "top_k", "filters"],
"additionalProperties": False,
},
)
return [weather_tool, search_tool]
def _build_default_scenarios() -> List[ToolCallScenario]:
"""构造默认测试场景。"""
tools = _build_probe_tools()
weather_tool = tools[0]
search_tool = tools[1]
history_tool_call = {
"id": "call_hist_weather_001",
"type": "function",
"function": {
"name": "lookup_weather",
"arguments": {
"city": "上海",
"unit": "celsius",
"include_forecast": True,
},
},
}
nested_history_tool_call = {
"id": "call_hist_search_001",
"type": "function",
"function": {
"name": "search_docs",
"arguments": {
"query": "工具调用兼容性",
"top_k": 3,
"filters": {
"scope": "internal",
"tag": "tool-call",
},
},
},
}
return [
ToolCallScenario(
name="fresh_tool_call",
description="首轮普通工具调用请求。",
prompt=[
{
"role": "system",
"content": (
"你正在执行工具调用连通性测试。"
"如果能调用工具,就优先调用最合适的工具。"
),
},
{
"role": "user",
"content": "请查询上海天气,并使用工具给出参数。",
},
],
tool_options=[weather_tool],
expect_tool_calls=True,
),
ToolCallScenario(
name="history_assistant_tool_calls_with_content",
description="历史 assistant 同时包含文本和 tool_calls当前轮不再提供 tools。",
prompt=[
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
{"role": "user", "content": "先帮我查一下上海天气。"},
{
"role": "assistant",
"content": "我先查询天气,再继续回答。",
"tool_calls": [history_tool_call],
},
{"role": "user", "content": "继续说,别丢掉上下文。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_assistant_tool_calls_without_content",
description="历史 assistant 只有 tool_calls没有文本内容。",
prompt=[
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
{"role": "user", "content": "先帮我查一下上海天气。"},
{
"role": "assistant",
"tool_calls": [history_tool_call],
},
{"role": "user", "content": "继续。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_tool_result_followup",
description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
prompt=[
{"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
{"role": "user", "content": "先查上海天气。"},
{
"role": "assistant",
"content": "我先查询天气。",
"tool_calls": [history_tool_call],
},
{
"role": "tool",
"tool_call_id": "call_hist_weather_001",
"content": json.dumps(
{
"city": "上海",
"condition": "多云",
"temperature_c": 24,
"forecast": ["", "小雨"],
},
ensure_ascii=False,
),
},
{"role": "user", "content": "结合上面的查询结果继续总结。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_multiple_tool_calls_and_results",
description="历史中包含多个 tool_calls 与多条 tool 结果。",
prompt=[
{"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
{"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
{
"role": "assistant",
"content": "我分两步查询。",
"tool_calls": [history_tool_call, nested_history_tool_call],
},
{
"role": "tool",
"tool_call_id": "call_hist_weather_001",
"content": json.dumps(
{
"city": "上海",
"condition": "",
"temperature_c": 22,
},
ensure_ascii=False,
),
},
{
"role": "tool",
"tool_call_id": "call_hist_search_001",
"content": json.dumps(
{
"items": [
"OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
"部分 provider 在历史消息回放时兼容性较弱",
],
},
ensure_ascii=False,
),
},
{"role": "user", "content": "继续整合上面的两个结果。"},
],
tool_options=None,
expect_tool_calls=None,
),
ToolCallScenario(
name="history_tool_calls_with_current_tools",
description="保留历史 tool_calls同时当前轮仍然提供 tools。",
prompt=[
{"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
{"role": "user", "content": "先查上海天气。"},
{
"role": "assistant",
"content": "我先查天气。",
"tool_calls": [history_tool_call],
},
{
"role": "tool",
"tool_call_id": "call_hist_weather_001",
"content": json.dumps(
{
"city": "上海",
"condition": "",
"temperature_c": 26,
},
ensure_ascii=False,
),
},
{"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
],
tool_options=[search_tool],
expect_tool_calls=True,
),
]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
"""构造模型名到模型配置的映射。"""
return {model.name: model for model in config_manager.get_model_config().models}
def _build_provider_map() -> Dict[str, APIProvider]:
"""构造 Provider 名称到配置的映射。"""
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
"""根据命令行参数解析待测试目标。"""
available_tasks = get_available_models()
model_map = _build_model_map()
provider_map = _build_provider_map()
if not available_tasks:
raise ValueError("未找到任何可用的模型任务配置")
if task_filters:
selected_task_names = []
for task_name in task_filters:
if task_name not in available_tasks:
raise ValueError(f"未找到任务 `{task_name}`")
selected_task_names.append(task_name)
else:
selected_task_names = [
task_name
for task_name in available_tasks
if task_name not in DEFAULT_SKIP_TASKS
]
if not selected_task_names:
raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
resolved_targets: List[ProbeTarget] = []
seen_models: set[str] = set()
if model_filters:
model_names = list(model_filters)
else:
model_names = []
for task_name in selected_task_names:
task_config = available_tasks[task_name]
for model_name in task_config.model_list:
if model_name not in model_names:
model_names.append(model_name)
for model_name in model_names:
if model_name in seen_models:
continue
if model_name not in model_map:
raise ValueError(f"未找到模型 `{model_name}`")
target_task_name = ""
for task_name in selected_task_names:
if model_name in available_tasks[task_name].model_list:
target_task_name = task_name
break
if not target_task_name:
target_task_name = default_task_name
model_info = model_map[model_name]
provider_info = provider_map[model_info.api_provider]
resolved_targets.append(
ProbeTarget(
task_name=target_task_name,
model_name=model_name,
provider_name=provider_info.name,
client_type=provider_info.client_type,
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
)
)
seen_models.add(model_name)
return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
"""临时将某个任务锁定到单模型。"""
model_task_config = config_manager.get_model_config().model_task_config
task_config = getattr(model_task_config, task_name, None)
if not isinstance(task_config, TaskConfig):
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
original_model_list = list(task_config.model_list)
original_selection_strategy = task_config.selection_strategy
task_config.model_list = [model_name]
task_config.selection_strategy = "balance"
try:
yield
finally:
task_config.model_list = original_model_list
task_config.selection_strategy = original_selection_strategy
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
"""序列化返回中的工具调用。"""
if not tool_calls:
return []
serialized_items: List[Dict[str, Any]] = []
for tool_call in tool_calls:
serialized_items.append(
{
"id": getattr(tool_call, "call_id", ""),
"function": {
"name": getattr(tool_call, "func_name", ""),
"arguments": dict(getattr(tool_call, "args", {}) or {}),
},
**(
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
if getattr(tool_call, "extra_content", None)
else {}
),
}
)
return serialized_items
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
"""校验服务结果。"""
errors: List[str] = []
warnings: List[str] = []
completion = service_result.completion
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
if not service_result.success:
errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
return errors, warnings, serialized_tool_calls
if scenario.expect_tool_calls is True and not serialized_tool_calls:
warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
if scenario.expect_tool_calls is False and serialized_tool_calls:
warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
if completion.response.strip():
warnings.append("模型返回了可见文本")
return errors, warnings, serialized_tool_calls
async def _run_single_probe(
target: ProbeTarget,
scenario: ToolCallScenario,
attempt: int,
max_tokens: int,
temperature: float,
) -> ProbeResult:
"""执行单次 API 探测。"""
request = LLMServiceRequest(
task_name=target.task_name,
request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
prompt=scenario.prompt,
tool_options=scenario.tool_options,
temperature=temperature,
max_tokens=max_tokens,
)
started_at = time.perf_counter()
with _pin_task_to_model(target.task_name, target.model_name):
service_result = await generate(request)
elapsed_seconds = time.perf_counter() - started_at
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
completion = service_result.completion
return ProbeResult(
task_name=target.task_name,
target_model_name=target.model_name,
actual_model_name=completion.model_name,
provider_name=target.provider_name,
client_type=target.client_type,
tool_argument_parse_mode=target.tool_argument_parse_mode,
case_name=scenario.name,
attempt=attempt,
success=not errors,
elapsed_seconds=elapsed_seconds,
errors=errors,
warnings=warnings,
response_text=completion.response,
reasoning_text=completion.reasoning,
tool_calls=serialized_tool_calls,
)
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
"""打印待测试目标。"""
print("待测试目标:")
for index, target in enumerate(targets, start=1):
print(
f"{index}. model={target.model_name} | task={target.task_name} | "
f"provider={target.provider_name} | client={target.client_type} | "
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
)
def _print_available_targets() -> None:
"""打印当前可用任务与模型。"""
available_tasks = get_available_models()
model_map = _build_model_map()
task_names = list(available_tasks.keys())
print("当前可用任务:")
for task_name in task_names:
task_config = available_tasks[task_name]
print(f"- {task_name}: {list(task_config.model_list)}")
referenced_models = {
model_name
for task_config in available_tasks.values()
for model_name in task_config.model_list
}
print("\n当前配置中的模型:")
for model_name, model_info in model_map.items():
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
print(
f"- {model_name}: provider={model_info.api_provider}, "
f"identifier={model_info.model_identifier}, {referenced_mark}"
)
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
"""按名称筛选测试场景。"""
all_scenarios = {scenario.name: scenario for scenario in _build_default_scenarios()}
if not case_filters:
return list(all_scenarios.values())
selected_scenarios: List[ToolCallScenario] = []
for case_name in case_filters:
if case_name not in all_scenarios:
raise ValueError(
f"未知测试场景 `{case_name}`,可选值: {', '.join(sorted(all_scenarios))}"
)
selected_scenarios.append(all_scenarios[case_name])
return selected_scenarios
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
"""打印单次结果。"""
status_text = "PASS" if result.success else "FAIL"
print(
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
)
if result.errors:
for error in result.errors:
print(f" ERROR: {error}")
if result.warnings:
for warning in result.warnings:
print(f" WARN: {warning}")
if result.tool_calls:
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
if show_response and result.response_text.strip():
print(f" response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
"""构造结果摘要。"""
total_count = len(results)
passed_count = sum(1 for result in results if result.success)
failed_count = total_count - passed_count
failed_items = [
{
"model_name": result.target_model_name,
"case_name": result.case_name,
"attempt": result.attempt,
"errors": list(result.errors),
}
for result in results
if not result.success
]
return {
"total": total_count,
"passed": passed_count,
"failed": failed_count,
"failed_items": failed_items,
}
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
"""将测试结果写入 JSON 文件。"""
output_path = Path(json_out).expanduser().resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
"summary": _build_summary(results),
"results": [asdict(result) for result in results],
}
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
print(f"\n结果已写入: {output_path}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
"""执行所有探测请求。"""
task_filters = _parse_multi_value_args(args.task)
model_filters = _parse_multi_value_args(args.model)
case_filters = _parse_multi_value_args(args.case)
selected_scenarios = _select_scenarios(case_filters)
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
_print_targets(targets)
print("")
results: List[ProbeResult] = []
for target in targets:
for attempt in range(1, args.repeat + 1):
for scenario in selected_scenarios:
print(
f"开始测试: model={target.model_name}, task={target.task_name}, "
f"case={scenario.name}, attempt={attempt}"
)
result = await _run_single_probe(
target=target,
scenario=scenario,
attempt=attempt,
max_tokens=args.max_tokens,
temperature=args.temperature,
)
_print_single_result(result, args.show_response)
print("")
results.append(result)
return results
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
"重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
)
parser.add_argument(
"--case",
action="append",
help=(
"指定测试场景名,可选值包括 "
"fresh_tool_call、history_assistant_tool_calls_with_content、"
"history_assistant_tool_calls_without_content、history_tool_result_followup、"
"history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
),
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个场景重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0,以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的可见文本内容",
)
return parser
def main() -> int:
"""脚本入口。"""
_ensure_utf8_console()
config_manager.initialize()
parser = _build_parser()
args = parser.parse_args()
if args.repeat < 1:
parser.error("--repeat 必须大于等于 1")
if args.max_tokens < 1:
parser.error("--max-tokens 必须大于等于 1")
if args.list_targets:
_print_available_targets()
return 0
results = asyncio.run(_run_probes(args))
summary = _build_summary(results)
print("测试摘要:")
print(
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
)
if summary["failed_items"]:
print("失败明细:")
for failed_item in summary["failed_items"]:
print(
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
)
if args.json_out:
_write_json_report(args.json_out, results)
return 0 if summary["failed"] == 0 else 1
if __name__ == "__main__":
raise SystemExit(main())