fix: override plugin tool stream context and remove repository tests
- force plugin tools to use runtime stream_id/chat_id from execution context - remove repository test assets and vitest config - document that temporary test files must be deleted after use
This commit is contained in:
@@ -1,459 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import json
|
||||
import importlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
# 强制使用 utf-8,避免控制台编码报错
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 确保能导入 src.*
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
from src.common.logger import initialize_logging, get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import LLMUsage
|
||||
from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo
|
||||
|
||||
try:
|
||||
from maim_message import ChatStream, UserInfo, GroupInfo
|
||||
except Exception:
|
||||
@dataclass
|
||||
class ChatStream:
|
||||
stream_id: str
|
||||
platform: str
|
||||
user_info: UserInfo
|
||||
group_info: GroupInfo
|
||||
|
||||
logger = get_logger("test_memory_retrieval")
|
||||
|
||||
|
||||
# 使用 importlib 动态导入,避免循环导入问题
|
||||
def _import_memory_retrieval():
|
||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
||||
try:
|
||||
# 先导入 prompt_builder,检查 prompt 是否已经初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
# 检查 memory_retrieval 相关的 prompt 是否已经注册
|
||||
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
|
||||
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
|
||||
|
||||
module_name = "src.memory_system.memory_retrieval"
|
||||
|
||||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
||||
if prompt_already_init and module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
return (
|
||||
existing_module.init_memory_retrieval_prompt,
|
||||
existing_module._react_agent_solve_question,
|
||||
existing_module._process_single_question,
|
||||
)
|
||||
|
||||
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
||||
if module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
# 模块部分初始化,移除它
|
||||
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
||||
del sys.modules[module_name]
|
||||
# 清理可能相关的部分初始化模块
|
||||
keys_to_remove = []
|
||||
for key in sys.modules.keys():
|
||||
if key.startswith("src.memory_system.") and key != "src.memory_system":
|
||||
keys_to_remove.append(key)
|
||||
for key in keys_to_remove:
|
||||
try:
|
||||
del sys.modules[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
|
||||
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
|
||||
try:
|
||||
# 先导入可能触发循环导入的模块,让它们完成初始化
|
||||
import src.config.config
|
||||
import src.chat.utils.prompt_builder
|
||||
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
# 如果它们已经导入,就确保它们完全初始化
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
# 如果它们已经导入,就确保它们完全初始化
|
||||
try:
|
||||
import src.chat.replyer.group_generator # noqa: F401
|
||||
except (ImportError, AttributeError):
|
||||
pass # 如果导入失败,继续
|
||||
try:
|
||||
import src.chat.replyer.private_generator # noqa: F401
|
||||
except (ImportError, AttributeError):
|
||||
pass # 如果导入失败,继续
|
||||
except Exception as e:
|
||||
logger.warning(f"预加载依赖模块时出现警告: {e}")
|
||||
|
||||
# 现在尝试导入 memory_retrieval
|
||||
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
|
||||
memory_retrieval_module = importlib.import_module(module_name)
|
||||
|
||||
return (
|
||||
memory_retrieval_module.init_memory_retrieval_prompt,
|
||||
memory_retrieval_module._react_agent_solve_question,
|
||||
memory_retrieval_module._process_single_question,
|
||||
)
|
||||
except (ImportError, AttributeError) as e:
|
||||
logger.error(f"导入 memory_retrieval 模块失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStream:
|
||||
"""创建一个测试用的 ChatStream 对象"""
|
||||
user_info = UserInfo(
|
||||
platform="test",
|
||||
user_id="test_user",
|
||||
user_nickname="测试用户",
|
||||
)
|
||||
group_info = GroupInfo(
|
||||
platform="test",
|
||||
group_id="test_group",
|
||||
group_name="测试群组",
|
||||
)
|
||||
return ChatStream(
|
||||
stream_id=chat_id,
|
||||
platform="test",
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
||||
"""获取从指定时间开始的token使用情况
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
|
||||
Returns:
|
||||
包含token使用统计的字典
|
||||
"""
|
||||
try:
|
||||
start_datetime = datetime.fromtimestamp(start_time)
|
||||
|
||||
# 查询从开始时间到现在的所有memory相关的token使用记录
|
||||
records = (
|
||||
LLMUsage.select()
|
||||
.where(
|
||||
(LLMUsage.timestamp >= start_datetime)
|
||||
& (
|
||||
(LLMUsage.request_type.like("%memory%"))
|
||||
| (LLMUsage.request_type == "memory.question")
|
||||
| (LLMUsage.request_type == "memory.react")
|
||||
| (LLMUsage.request_type == "memory.react.final")
|
||||
)
|
||||
)
|
||||
.order_by(LLMUsage.timestamp.asc())
|
||||
)
|
||||
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
total_cost = 0.0
|
||||
request_count = 0
|
||||
model_usage = {} # 按模型统计
|
||||
|
||||
for record in records:
|
||||
total_prompt_tokens += record.prompt_tokens or 0
|
||||
total_completion_tokens += record.completion_tokens or 0
|
||||
total_tokens += record.total_tokens or 0
|
||||
total_cost += record.cost or 0.0
|
||||
request_count += 1
|
||||
|
||||
# 按模型统计
|
||||
model_name = record.model_name or "unknown"
|
||||
if model_name not in model_usage:
|
||||
model_usage[model_name] = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"cost": 0.0,
|
||||
"request_count": 0,
|
||||
}
|
||||
model_usage[model_name]["prompt_tokens"] += record.prompt_tokens or 0
|
||||
model_usage[model_name]["completion_tokens"] += record.completion_tokens or 0
|
||||
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
|
||||
model_usage[model_name]["cost"] += record.cost or 0.0
|
||||
model_usage[model_name]["request_count"] += 1
|
||||
|
||||
return {
|
||||
"total_prompt_tokens": total_prompt_tokens,
|
||||
"total_completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"total_cost": total_cost,
|
||||
"request_count": request_count,
|
||||
"model_usage": model_usage,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取token使用情况失败: {e}")
|
||||
return {
|
||||
"total_prompt_tokens": 0,
|
||||
"total_completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"total_cost": 0.0,
|
||||
"request_count": 0,
|
||||
"model_usage": {},
|
||||
}
|
||||
|
||||
|
||||
def format_thinking_steps(thinking_steps: list) -> str:
|
||||
"""格式化思考步骤为可读字符串"""
|
||||
if not thinking_steps:
|
||||
return "无思考步骤"
|
||||
|
||||
lines = []
|
||||
for step in thinking_steps:
|
||||
iteration = step.get("iteration", "?")
|
||||
thought = step.get("thought", "")
|
||||
actions = step.get("actions", [])
|
||||
observations = step.get("observations", [])
|
||||
|
||||
lines.append(f"\n--- 迭代 {iteration} ---")
|
||||
if thought:
|
||||
lines.append(f"思考: {thought[:200]}...")
|
||||
|
||||
if actions:
|
||||
lines.append("行动:")
|
||||
for action in actions:
|
||||
action_type = action.get("action_type", "unknown")
|
||||
action_params = action.get("action_params", {})
|
||||
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
|
||||
|
||||
if observations:
|
||||
lines.append("观察:")
|
||||
for obs in observations:
|
||||
obs_str = str(obs)[:200]
|
||||
if len(str(obs)) > 200:
|
||||
obs_str += "..."
|
||||
lines.append(f" - {obs_str}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def test_memory_retrieval(
|
||||
question: str,
|
||||
chat_id: str = "test_memory_retrieval",
|
||||
context: str = "",
|
||||
max_iterations: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""测试记忆检索功能
|
||||
|
||||
Args:
|
||||
question: 要查询的问题
|
||||
chat_id: 聊天ID
|
||||
context: 上下文信息
|
||||
max_iterations: 最大迭代次数
|
||||
|
||||
Returns:
|
||||
包含测试结果的字典
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print("[测试] 记忆检索测试")
|
||||
print(f"[问题] {question}")
|
||||
print("=" * 80)
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
# 延迟导入并初始化记忆检索prompt(这会自动加载 global_config)
|
||||
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
|
||||
try:
|
||||
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
|
||||
|
||||
# 检查 prompt 是否已经初始化,避免重复初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
||||
init_memory_retrieval_prompt()
|
||||
else:
|
||||
logger.debug("记忆检索 prompt 已经初始化,跳过重复初始化")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# 获取 global_config(此时应该已经加载)
|
||||
from src.config.config import global_config
|
||||
|
||||
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.memory.max_agent_iterations
|
||||
|
||||
timeout = global_config.memory.agent_timeout_seconds
|
||||
|
||||
print("\n[配置]")
|
||||
print(f" 最大迭代次数: {max_iterations}")
|
||||
print(f" 超时时间: {timeout}秒")
|
||||
print(f" 聊天ID: {chat_id}")
|
||||
|
||||
# 执行检索
|
||||
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
|
||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
max_iterations=max_iterations,
|
||||
timeout=timeout,
|
||||
initial_info="",
|
||||
)
|
||||
|
||||
# 记录结束时间
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
# 获取token使用情况
|
||||
token_usage = get_token_usage_since(start_time)
|
||||
|
||||
# 构建结果
|
||||
result = {
|
||||
"question": question,
|
||||
"found_answer": found_answer,
|
||||
"answer": answer,
|
||||
"is_timeout": is_timeout,
|
||||
"elapsed_time": elapsed_time,
|
||||
"thinking_steps": thinking_steps,
|
||||
"iteration_count": len(thinking_steps),
|
||||
"token_usage": token_usage,
|
||||
}
|
||||
|
||||
# 输出结果
|
||||
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
print("\n[结果]")
|
||||
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
||||
if found_answer and answer:
|
||||
print(f" 答案: {answer}")
|
||||
else:
|
||||
print(" 答案: (未找到答案)")
|
||||
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
||||
print(f" 迭代次数: {len(thinking_steps)}")
|
||||
print(f" 总耗时: {elapsed_time:.2f}秒")
|
||||
|
||||
print("\n[Token使用情况]")
|
||||
print(f" 总请求数: {token_usage['request_count']}")
|
||||
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
||||
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
||||
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
||||
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
||||
|
||||
if token_usage["model_usage"]:
|
||||
print("\n[按模型统计]")
|
||||
for model_name, usage in token_usage["model_usage"].items():
|
||||
print(f" {model_name}:")
|
||||
print(f" 请求数: {usage['request_count']}")
|
||||
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
||||
print(f" Completion Tokens: {usage['completion_tokens']:,}")
|
||||
print(f" 总Tokens: {usage['total_tokens']:,}")
|
||||
print(f" 成本: ${usage['cost']:.6f}")
|
||||
|
||||
print("\n[迭代详情]")
|
||||
print(format_thinking_steps(thinking_steps))
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="测试记忆检索功能。可以输入一个问题,脚本会使用记忆检索的逻辑进行检索,并记录迭代信息、时间和token总消耗。"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chat-id",
|
||||
default="test_memory_retrieval",
|
||||
help="测试用的聊天ID(默认: test_memory_retrieval)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context",
|
||||
default="",
|
||||
help="上下文信息(可选)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
help="将结果保存到JSON文件(可选)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 初始化日志(使用较低的详细程度,避免输出过多日志)
|
||||
initialize_logging(verbose=False)
|
||||
|
||||
# 交互式输入问题
|
||||
print("\n" + "=" * 80)
|
||||
print("记忆检索测试工具")
|
||||
print("=" * 80)
|
||||
question = input("\n请输入要查询的问题: ").strip()
|
||||
if not question:
|
||||
print("错误: 问题不能为空")
|
||||
return
|
||||
|
||||
# 交互式输入最大迭代次数
|
||||
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
|
||||
max_iterations = None
|
||||
if max_iterations_input:
|
||||
try:
|
||||
max_iterations = int(max_iterations_input)
|
||||
if max_iterations <= 0:
|
||||
print("警告: 迭代次数必须大于0,将使用配置默认值")
|
||||
max_iterations = None
|
||||
except ValueError:
|
||||
print("警告: 无效的迭代次数,将使用配置默认值")
|
||||
max_iterations = None
|
||||
|
||||
# 连接数据库
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
print(f"错误: 数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
# 运行测试
|
||||
try:
|
||||
result = asyncio.run(
|
||||
test_memory_retrieval(
|
||||
question=question,
|
||||
chat_id=args.chat_id,
|
||||
context=args.context,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
)
|
||||
|
||||
# 如果指定了输出文件,保存结果
|
||||
if args.output:
|
||||
# 将thinking_steps转换为可序列化的格式
|
||||
output_result = result.copy()
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(output_result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n[结果已保存] {args.output}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n[中断] 用户中断测试")
|
||||
except Exception as e:
|
||||
logger.error(f"测试失败: {e}", exc_info=True)
|
||||
print(f"\n[错误] 测试失败: {e}")
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,845 +0,0 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
||||
from src.config.config import config_manager # noqa: E402
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
||||
from src.llm_models.payload_content.tool_option import ToolCall # noqa: E402
|
||||
from src.services.llm_service import generate # noqa: E402
|
||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
||||
|
||||
|
||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCallCase:
|
||||
"""Tool call 参数测试用例。"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
tool_definition: Dict[str, Any]
|
||||
expected_arguments: Dict[str, Any]
|
||||
|
||||
@property
|
||||
def tool_name(self) -> str:
|
||||
"""返回工具名称。"""
|
||||
if self.tool_definition.get("type") == "function":
|
||||
function_definition = self.tool_definition.get("function", {})
|
||||
return str(function_definition.get("name", "") or "")
|
||||
return str(self.tool_definition.get("name", "") or "")
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> Dict[str, Any]:
|
||||
"""返回参数 Schema。"""
|
||||
if self.tool_definition.get("type") == "function":
|
||||
function_definition = self.tool_definition.get("function", {})
|
||||
parameters = function_definition.get("parameters", {})
|
||||
return parameters if isinstance(parameters, dict) else {}
|
||||
parameters = self.tool_definition.get("parameters", {})
|
||||
return parameters if isinstance(parameters, dict) else {}
|
||||
|
||||
def build_messages(self) -> List[Dict[str, Any]]:
|
||||
"""构造测试消息。"""
|
||||
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
|
||||
system_prompt = (
|
||||
"你正在执行严格的工具调用参数兼容性测试。"
|
||||
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
|
||||
)
|
||||
user_prompt = (
|
||||
f"请立刻调用工具 `{self.tool_name}`。\n"
|
||||
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
|
||||
"不要输出任何解释文本,只返回工具调用。\n"
|
||||
f"{expected_json}"
|
||||
)
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProbeTarget:
|
||||
"""单个待测试模型目标。"""
|
||||
|
||||
task_name: str
|
||||
model_name: str
|
||||
provider_name: str
|
||||
client_type: str
|
||||
tool_argument_parse_mode: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProbeResult:
|
||||
"""单次测试结果。"""
|
||||
|
||||
task_name: str
|
||||
target_model_name: str
|
||||
actual_model_name: str
|
||||
provider_name: str
|
||||
client_type: str
|
||||
tool_argument_parse_mode: str
|
||||
case_name: str
|
||||
attempt: int
|
||||
success: bool
|
||||
elapsed_seconds: float
|
||||
errors: List[str]
|
||||
warnings: List[str]
|
||||
response_text: str
|
||||
reasoning_text: str
|
||||
tool_calls: List[Dict[str, Any]]
|
||||
|
||||
|
||||
def _ensure_utf8_console() -> None:
|
||||
"""尽量将控制台编码切到 UTF-8。"""
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构造 OpenAI 风格 function tool 定义。"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_default_cases() -> List[ToolCallCase]:
|
||||
"""构造默认测试用例。"""
|
||||
simple_expected_arguments = {
|
||||
"request_id": "probe-simple-001",
|
||||
"count": 7,
|
||||
"enabled": True,
|
||||
"mode": "strict",
|
||||
"ratio": 2.5,
|
||||
}
|
||||
simple_parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"request_id": {"type": "string", "description": "请求 ID"},
|
||||
"count": {"type": "integer", "description": "数量"},
|
||||
"enabled": {"type": "boolean", "description": "是否启用"},
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"description": "模式",
|
||||
"enum": ["strict", "loose"],
|
||||
},
|
||||
"ratio": {"type": "number", "description": "比例"},
|
||||
},
|
||||
"required": ["request_id", "count", "enabled", "mode", "ratio"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
nested_expected_arguments = {
|
||||
"request_id": "probe-nested-001",
|
||||
"notify": False,
|
||||
"profile": {
|
||||
"channel": "stable",
|
||||
"priority": 2,
|
||||
},
|
||||
"tags": ["alpha", "beta", "gamma"],
|
||||
"items": [
|
||||
{"count": 2, "name": "apple"},
|
||||
{"count": 5, "name": "banana"},
|
||||
],
|
||||
}
|
||||
nested_parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"request_id": {"type": "string", "description": "请求 ID"},
|
||||
"notify": {"type": "boolean", "description": "是否通知"},
|
||||
"profile": {
|
||||
"type": "object",
|
||||
"description": "配置对象",
|
||||
"properties": {
|
||||
"channel": {"type": "string", "description": "渠道"},
|
||||
"priority": {"type": "integer", "description": "优先级"},
|
||||
},
|
||||
"required": ["channel", "priority"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"description": "标签列表",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"items": {
|
||||
"type": "array",
|
||||
"description": "条目列表",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {"type": "integer", "description": "数量"},
|
||||
"name": {"type": "string", "description": "名称"},
|
||||
},
|
||||
"required": ["count", "name"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["request_id", "notify", "profile", "tags", "items"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
return [
|
||||
ToolCallCase(
|
||||
name="simple",
|
||||
description="标量参数类型校验",
|
||||
tool_definition=_build_function_tool(
|
||||
name="record_simple_probe",
|
||||
description="记录简单参数探测结果",
|
||||
parameters=simple_parameters,
|
||||
),
|
||||
expected_arguments=simple_expected_arguments,
|
||||
),
|
||||
ToolCallCase(
|
||||
name="nested",
|
||||
description="嵌套对象与数组参数校验",
|
||||
tool_definition=_build_function_tool(
|
||||
name="record_nested_probe",
|
||||
description="记录嵌套参数探测结果",
|
||||
parameters=nested_parameters,
|
||||
),
|
||||
expected_arguments=nested_expected_arguments,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
||||
"""解析命令行中的多值参数。"""
|
||||
parsed_values: List[str] = []
|
||||
for raw_value in raw_values or []:
|
||||
for item in str(raw_value).split(","):
|
||||
normalized_item = item.strip()
|
||||
if normalized_item:
|
||||
parsed_values.append(normalized_item)
|
||||
return parsed_values
|
||||
|
||||
|
||||
def _build_model_map() -> Dict[str, ModelInfo]:
|
||||
"""构造模型名称到模型配置的映射。"""
|
||||
return {model.name: model for model in config_manager.get_model_config().models}
|
||||
|
||||
|
||||
def _build_provider_map() -> Dict[str, APIProvider]:
|
||||
"""构造 Provider 名称到配置的映射。"""
|
||||
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
|
||||
|
||||
|
||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
||||
"""选择默认任务名。"""
|
||||
if "utils" in task_names:
|
||||
return "utils"
|
||||
if not task_names:
|
||||
raise ValueError("当前没有可用的任务配置")
|
||||
return str(task_names[0])
|
||||
|
||||
|
||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
|
||||
"""根据命令行参数解析待测试目标。"""
|
||||
available_tasks = get_available_models()
|
||||
model_map = _build_model_map()
|
||||
provider_map = _build_provider_map()
|
||||
|
||||
if not available_tasks:
|
||||
raise ValueError("未找到任何可用的模型任务配置")
|
||||
|
||||
if task_filters:
|
||||
selected_task_names = []
|
||||
for task_name in task_filters:
|
||||
if task_name not in available_tasks:
|
||||
raise ValueError(f"未找到任务 `{task_name}`")
|
||||
selected_task_names.append(task_name)
|
||||
else:
|
||||
selected_task_names = [
|
||||
task_name
|
||||
for task_name in available_tasks
|
||||
if task_name not in DEFAULT_SKIP_TASKS
|
||||
]
|
||||
|
||||
if not selected_task_names:
|
||||
raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")
|
||||
|
||||
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
|
||||
resolved_targets: List[ProbeTarget] = []
|
||||
seen_models: set[str] = set()
|
||||
|
||||
if model_filters:
|
||||
model_names = list(model_filters)
|
||||
else:
|
||||
model_names = []
|
||||
for task_name in selected_task_names:
|
||||
task_config = available_tasks[task_name]
|
||||
for model_name in task_config.model_list:
|
||||
if model_name not in model_names:
|
||||
model_names.append(model_name)
|
||||
|
||||
for model_name in model_names:
|
||||
if model_name in seen_models:
|
||||
continue
|
||||
if model_name not in model_map:
|
||||
raise ValueError(f"未找到模型 `{model_name}`")
|
||||
|
||||
target_task_name = ""
|
||||
for task_name in selected_task_names:
|
||||
if model_name in available_tasks[task_name].model_list:
|
||||
target_task_name = task_name
|
||||
break
|
||||
if not target_task_name:
|
||||
target_task_name = default_task_name
|
||||
|
||||
model_info = model_map[model_name]
|
||||
provider_info = provider_map[model_info.api_provider]
|
||||
resolved_targets.append(
|
||||
ProbeTarget(
|
||||
task_name=target_task_name,
|
||||
model_name=model_name,
|
||||
provider_name=provider_info.name,
|
||||
client_type=provider_info.client_type,
|
||||
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
|
||||
)
|
||||
)
|
||||
seen_models.add(model_name)
|
||||
|
||||
return resolved_targets
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
|
||||
"""临时将某个任务锁定到单模型。"""
|
||||
model_task_config = config_manager.get_model_config().model_task_config
|
||||
task_config = getattr(model_task_config, task_name, None)
|
||||
if not isinstance(task_config, TaskConfig):
|
||||
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
|
||||
|
||||
original_model_list = list(task_config.model_list)
|
||||
original_selection_strategy = task_config.selection_strategy
|
||||
task_config.model_list = [model_name]
|
||||
task_config.selection_strategy = "balance"
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task_config.model_list = original_model_list
|
||||
task_config.selection_strategy = original_selection_strategy
|
||||
|
||||
|
||||
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
|
||||
"""序列化工具调用结果。"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
return [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"function": {
|
||||
"name": tool_call.func_name,
|
||||
"arguments": dict(tool_call.args or {}),
|
||||
},
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
|
||||
def _is_integer_value(value: Any) -> bool:
|
||||
"""判断是否为整数类型且排除布尔值。"""
|
||||
return isinstance(value, int) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _is_number_value(value: Any) -> bool:
|
||||
"""判断是否为数值类型且排除布尔值。"""
|
||||
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _schema_type(schema: Dict[str, Any]) -> str:
|
||||
"""解析 Schema 的类型。"""
|
||||
schema_type = str(schema.get("type", "") or "").strip()
|
||||
if schema_type:
|
||||
return schema_type
|
||||
if "properties" in schema or "required" in schema:
|
||||
return "object"
|
||||
return ""
|
||||
|
||||
|
||||
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
|
||||
"""按简化 JSON Schema 校验工具参数。"""
|
||||
errors: List[str] = []
|
||||
schema_type = _schema_type(schema)
|
||||
|
||||
if "enum" in schema and actual_value not in schema["enum"]:
|
||||
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
|
||||
|
||||
if schema_type == "string":
|
||||
if not isinstance(actual_value, str):
|
||||
errors.append(f"{path} 类型错误,期望 string,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "integer":
|
||||
if not _is_integer_value(actual_value):
|
||||
errors.append(f"{path} 类型错误,期望 integer,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "number":
|
||||
if not _is_number_value(actual_value):
|
||||
errors.append(f"{path} 类型错误,期望 number,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "boolean":
|
||||
if not isinstance(actual_value, bool):
|
||||
errors.append(f"{path} 类型错误,期望 boolean,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
if schema_type == "array":
|
||||
if not isinstance(actual_value, list):
|
||||
errors.append(f"{path} 类型错误,期望 array,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
item_schema = schema.get("items")
|
||||
if isinstance(item_schema, dict):
|
||||
for index, item in enumerate(actual_value):
|
||||
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
|
||||
return errors
|
||||
|
||||
if schema_type == "object":
|
||||
if not isinstance(actual_value, dict):
|
||||
errors.append(f"{path} 类型错误,期望 object,实际为 {type(actual_value).__name__}")
|
||||
return errors
|
||||
|
||||
properties = schema.get("properties", {})
|
||||
required_fields = [str(item) for item in schema.get("required", [])]
|
||||
for required_field in required_fields:
|
||||
if required_field not in actual_value:
|
||||
errors.append(f"{path}.{required_field} 缺少必填字段")
|
||||
|
||||
for field_name, field_value in actual_value.items():
|
||||
field_path = f"{path}.{field_name}"
|
||||
field_schema = properties.get(field_name)
|
||||
if isinstance(field_schema, dict):
|
||||
errors.extend(_validate_schema(field_schema, field_value, field_path))
|
||||
continue
|
||||
|
||||
additional_properties = schema.get("additionalProperties", True)
|
||||
if additional_properties is False:
|
||||
errors.append(f"{field_path} 是未定义字段")
|
||||
elif isinstance(additional_properties, dict):
|
||||
errors.extend(_validate_schema(additional_properties, field_value, field_path))
|
||||
return errors
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
|
||||
"""递归比较实际值与期望值是否完全一致。"""
|
||||
errors: List[str] = []
|
||||
|
||||
if isinstance(expected_value, dict):
|
||||
if not isinstance(actual_value, dict):
|
||||
return [f"{path} 值不一致,期望 object,实际为 {type(actual_value).__name__}"]
|
||||
|
||||
expected_keys = set(expected_value.keys())
|
||||
actual_keys = set(actual_value.keys())
|
||||
for missing_key in sorted(expected_keys - actual_keys):
|
||||
errors.append(f"{path}.{missing_key} 缺少期望字段")
|
||||
for extra_key in sorted(actual_keys - expected_keys):
|
||||
errors.append(f"{path}.{extra_key} 出现了额外字段")
|
||||
for shared_key in sorted(expected_keys & actual_keys):
|
||||
errors.extend(
|
||||
_compare_expected_values(
|
||||
expected_value[shared_key],
|
||||
actual_value[shared_key],
|
||||
f"{path}.{shared_key}",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
if isinstance(expected_value, list):
|
||||
if not isinstance(actual_value, list):
|
||||
return [f"{path} 值不一致,期望 array,实际为 {type(actual_value).__name__}"]
|
||||
|
||||
if len(expected_value) != len(actual_value):
|
||||
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
|
||||
for index, (expected_item, actual_item) in enumerate(
|
||||
zip(expected_value, actual_value, strict=False)
|
||||
):
|
||||
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
|
||||
return errors
|
||||
|
||||
if isinstance(expected_value, bool):
|
||||
if not isinstance(actual_value, bool) or actual_value is not expected_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
if _is_integer_value(expected_value):
|
||||
if not _is_integer_value(actual_value) or actual_value != expected_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
if isinstance(expected_value, float):
|
||||
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
if expected_value != actual_value:
|
||||
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
|
||||
return errors
|
||||
|
||||
|
||||
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
|
||||
"""优先选择同名工具调用,否则回退到第一条。"""
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.func_name == expected_tool_name:
|
||||
return tool_call
|
||||
return tool_calls[0]
|
||||
|
||||
|
||||
def _validate_service_result(
|
||||
service_result: LLMServiceResult,
|
||||
target: ProbeTarget,
|
||||
case: ToolCallCase,
|
||||
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
|
||||
"""校验服务层返回结果。"""
|
||||
errors: List[str] = []
|
||||
warnings: List[str] = []
|
||||
completion = service_result.completion
|
||||
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
|
||||
|
||||
if not service_result.success:
|
||||
errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
|
||||
return errors, warnings, serialized_tool_calls
|
||||
|
||||
if completion.model_name and completion.model_name != target.model_name:
|
||||
errors.append(
|
||||
f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
|
||||
)
|
||||
|
||||
tool_calls = completion.tool_calls or []
|
||||
if not tool_calls:
|
||||
errors.append("模型未返回 tool_calls")
|
||||
if completion.response.strip():
|
||||
warnings.append("模型返回了自然语言文本而不是工具调用")
|
||||
return errors, warnings, serialized_tool_calls
|
||||
|
||||
if len(tool_calls) != 1:
|
||||
errors.append(f"返回了 {len(tool_calls)} 个 tool_calls,预期为 1 个")
|
||||
|
||||
selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
|
||||
if selected_tool_call.func_name != case.tool_name:
|
||||
errors.append(
|
||||
f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
|
||||
)
|
||||
|
||||
actual_arguments = selected_tool_call.args
|
||||
if not isinstance(actual_arguments, dict):
|
||||
errors.append("工具参数未被解析为对象")
|
||||
return errors, warnings, serialized_tool_calls
|
||||
|
||||
errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
|
||||
errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))
|
||||
|
||||
if completion.response.strip():
|
||||
warnings.append("模型同时返回了自然语言文本")
|
||||
return errors, warnings, serialized_tool_calls
|
||||
|
||||
|
||||
async def _run_single_probe(
|
||||
target: ProbeTarget,
|
||||
case: ToolCallCase,
|
||||
attempt: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> ProbeResult:
|
||||
"""执行单次工具调用参数探测。"""
|
||||
request = LLMServiceRequest(
|
||||
task_name=target.task_name,
|
||||
request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
|
||||
prompt=case.build_messages(),
|
||||
tool_options=[case.tool_definition],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
started_at = time.perf_counter()
|
||||
with _pin_task_to_model(target.task_name, target.model_name):
|
||||
service_result = await generate(request)
|
||||
elapsed_seconds = time.perf_counter() - started_at
|
||||
|
||||
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
|
||||
completion = service_result.completion
|
||||
return ProbeResult(
|
||||
task_name=target.task_name,
|
||||
target_model_name=target.model_name,
|
||||
actual_model_name=completion.model_name,
|
||||
provider_name=target.provider_name,
|
||||
client_type=target.client_type,
|
||||
tool_argument_parse_mode=target.tool_argument_parse_mode,
|
||||
case_name=case.name,
|
||||
attempt=attempt,
|
||||
success=not errors,
|
||||
elapsed_seconds=elapsed_seconds,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
response_text=completion.response,
|
||||
reasoning_text=completion.reasoning,
|
||||
tool_calls=serialized_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
|
||||
"""打印待测试目标。"""
|
||||
print("待测试目标:")
|
||||
for index, target in enumerate(targets, start=1):
|
||||
print(
|
||||
f"{index}. model={target.model_name} | task={target.task_name} | "
|
||||
f"provider={target.provider_name} | client={target.client_type} | "
|
||||
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
|
||||
)
|
||||
|
||||
|
||||
def _print_available_targets() -> None:
|
||||
"""打印当前可用任务与模型。"""
|
||||
available_tasks = get_available_models()
|
||||
model_map = _build_model_map()
|
||||
task_names = list(available_tasks.keys())
|
||||
|
||||
print("当前可用任务:")
|
||||
for task_name in task_names:
|
||||
task_config = available_tasks[task_name]
|
||||
print(f"- {task_name}: {list(task_config.model_list)}")
|
||||
|
||||
referenced_models = {
|
||||
model_name
|
||||
for task_config in available_tasks.values()
|
||||
for model_name in task_config.model_list
|
||||
}
|
||||
|
||||
print("\n当前配置中的模型:")
|
||||
for model_name, model_info in model_map.items():
|
||||
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
|
||||
print(
|
||||
f"- {model_name}: provider={model_info.api_provider}, "
|
||||
f"identifier={model_info.model_identifier}, {referenced_mark}"
|
||||
)
|
||||
|
||||
|
||||
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
|
||||
"""根据参数筛选测试用例。"""
|
||||
all_cases = {case.name: case for case in _build_default_cases()}
|
||||
if not case_filters:
|
||||
return list(all_cases.values())
|
||||
|
||||
selected_cases: List[ToolCallCase] = []
|
||||
for case_name in case_filters:
|
||||
if case_name not in all_cases:
|
||||
raise ValueError(f"未知测试用例 `{case_name}`,可选值: {', '.join(sorted(all_cases))}")
|
||||
selected_cases.append(all_cases[case_name])
|
||||
return selected_cases
|
||||
|
||||
|
||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
|
||||
"""打印单次结果。"""
|
||||
status_text = "PASS" if result.success else "FAIL"
|
||||
print(
|
||||
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
|
||||
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
|
||||
)
|
||||
if result.errors:
|
||||
for error in result.errors:
|
||||
print(f" ERROR: {error}")
|
||||
if result.warnings:
|
||||
for warning in result.warnings:
|
||||
print(f" WARN: {warning}")
|
||||
if result.tool_calls:
|
||||
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
|
||||
if show_response and result.response_text.strip():
|
||||
print(f" response: {result.response_text}")
|
||||
|
||||
|
||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
|
||||
"""构造结果摘要。"""
|
||||
total_count = len(results)
|
||||
passed_count = sum(1 for result in results if result.success)
|
||||
failed_count = total_count - passed_count
|
||||
failed_items = [
|
||||
{
|
||||
"model_name": result.target_model_name,
|
||||
"case_name": result.case_name,
|
||||
"attempt": result.attempt,
|
||||
"errors": list(result.errors),
|
||||
}
|
||||
for result in results
|
||||
if not result.success
|
||||
]
|
||||
return {
|
||||
"total": total_count,
|
||||
"passed": passed_count,
|
||||
"failed": failed_count,
|
||||
"failed_items": failed_items,
|
||||
}
|
||||
|
||||
|
||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
|
||||
"""将测试结果写入 JSON 文件。"""
|
||||
output_path = Path(json_out).expanduser().resolve()
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"summary": _build_summary(results),
|
||||
"results": [asdict(result) for result in results],
|
||||
}
|
||||
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"\n结果已写入: {output_path}")
|
||||
|
||||
|
||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
|
||||
"""执行所有探测请求。"""
|
||||
task_filters = _parse_multi_value_args(args.task)
|
||||
model_filters = _parse_multi_value_args(args.model)
|
||||
case_filters = _parse_multi_value_args(args.case)
|
||||
|
||||
selected_cases = _select_cases(case_filters)
|
||||
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
|
||||
|
||||
_print_targets(targets)
|
||||
print("")
|
||||
|
||||
results: List[ProbeResult] = []
|
||||
for target in targets:
|
||||
for attempt in range(1, args.repeat + 1):
|
||||
for case in selected_cases:
|
||||
print(
|
||||
f"开始测试: model={target.model_name}, task={target.task_name}, "
|
||||
f"case={case.name}, attempt={attempt}"
|
||||
)
|
||||
result = await _run_single_probe(
|
||||
target=target,
|
||||
case=case,
|
||||
attempt=attempt,
|
||||
max_tokens=args.max_tokens,
|
||||
temperature=args.temperature,
|
||||
)
|
||||
_print_single_result(result, args.show_response)
|
||||
print("")
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
def _build_parser() -> ArgumentParser:
|
||||
"""构造命令行参数解析器。"""
|
||||
parser = ArgumentParser(
|
||||
description=(
|
||||
"测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
|
||||
"默认会测试所有非 voice / embedding 任务中引用到的模型。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
action="append",
|
||||
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
action="append",
|
||||
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--case",
|
||||
action="append",
|
||||
help="指定测试用例名,可选 simple、nested;不传则运行全部默认用例",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat",
|
||||
type=int,
|
||||
default=1,
|
||||
help="每个模型每个用例重复测试次数,默认 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="单次测试的最大输出 token 数,默认 512",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="单次测试温度,默认 0.0 以尽量提高稳定性",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fallback-task",
|
||||
default="utils",
|
||||
help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-out",
|
||||
help="可选,将结果写入指定 JSON 文件",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-targets",
|
||||
action="store_true",
|
||||
help="仅打印当前任务与模型映射,不发起网络请求",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-response",
|
||||
action="store_true",
|
||||
help="打印模型返回的自然语言文本内容",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""脚本入口。"""
|
||||
_ensure_utf8_console()
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.repeat < 1:
|
||||
parser.error("--repeat 必须大于等于 1")
|
||||
if args.max_tokens < 1:
|
||||
parser.error("--max-tokens 必须大于等于 1")
|
||||
|
||||
if args.list_targets:
|
||||
_print_available_targets()
|
||||
return 0
|
||||
|
||||
results = asyncio.run(_run_probes(args))
|
||||
summary = _build_summary(results)
|
||||
|
||||
print("测试摘要:")
|
||||
print(
|
||||
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
|
||||
)
|
||||
if summary["failed_items"]:
|
||||
print("失败明细:")
|
||||
for failed_item in summary["failed_items"]:
|
||||
print(
|
||||
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
|
||||
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
|
||||
)
|
||||
|
||||
if args.json_out:
|
||||
_write_json_report(args.json_out, results)
|
||||
|
||||
return 0 if summary["failed"] == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,777 +0,0 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
||||
from src.config.config import config_manager # noqa: E402
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
||||
from src.services.llm_service import generate # noqa: E402
|
||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
||||
|
||||
|
||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProbeTarget:
|
||||
"""单个待测试模型目标。"""
|
||||
|
||||
task_name: str
|
||||
model_name: str
|
||||
provider_name: str
|
||||
client_type: str
|
||||
tool_argument_parse_mode: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCallScenario:
|
||||
"""工具调用 API 场景定义。"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
prompt: List[Dict[str, Any]]
|
||||
tool_options: List[Dict[str, Any]] | None = None
|
||||
expect_tool_calls: bool | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ProbeResult:
|
||||
"""单次 API 探测结果。"""
|
||||
|
||||
task_name: str
|
||||
target_model_name: str
|
||||
actual_model_name: str
|
||||
provider_name: str
|
||||
client_type: str
|
||||
tool_argument_parse_mode: str
|
||||
case_name: str
|
||||
attempt: int
|
||||
success: bool
|
||||
elapsed_seconds: float
|
||||
errors: List[str] = field(default_factory=list)
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
response_text: str = ""
|
||||
reasoning_text: str = ""
|
||||
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _ensure_utf8_console() -> None:
|
||||
"""尽量将控制台编码切换为 UTF-8。"""
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构造 OpenAI 风格 function tool。"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_probe_tools() -> List[Dict[str, Any]]:
|
||||
"""构造通用测试工具。"""
|
||||
weather_tool = _build_function_tool(
|
||||
name="lookup_weather",
|
||||
description="查询指定城市天气。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "城市名"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "温度单位",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
"include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
|
||||
},
|
||||
"required": ["city", "unit", "include_forecast"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
search_tool = _build_function_tool(
|
||||
name="search_docs",
|
||||
description="搜索内部知识库。",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "搜索关键词"},
|
||||
"top_k": {"type": "integer", "description": "返回条数"},
|
||||
"filters": {
|
||||
"type": "object",
|
||||
"description": "过滤条件",
|
||||
"properties": {
|
||||
"scope": {"type": "string", "description": "搜索范围"},
|
||||
"tag": {"type": "string", "description": "标签"},
|
||||
},
|
||||
"required": ["scope", "tag"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
"required": ["query", "top_k", "filters"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
return [weather_tool, search_tool]
|
||||
|
||||
|
||||
def _build_default_scenarios() -> List[ToolCallScenario]:
|
||||
"""构造默认测试场景。"""
|
||||
tools = _build_probe_tools()
|
||||
weather_tool = tools[0]
|
||||
search_tool = tools[1]
|
||||
|
||||
history_tool_call = {
|
||||
"id": "call_hist_weather_001",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "lookup_weather",
|
||||
"arguments": {
|
||||
"city": "上海",
|
||||
"unit": "celsius",
|
||||
"include_forecast": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
nested_history_tool_call = {
|
||||
"id": "call_hist_search_001",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search_docs",
|
||||
"arguments": {
|
||||
"query": "工具调用兼容性",
|
||||
"top_k": 3,
|
||||
"filters": {
|
||||
"scope": "internal",
|
||||
"tag": "tool-call",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return [
|
||||
ToolCallScenario(
|
||||
name="fresh_tool_call",
|
||||
description="首轮普通工具调用请求。",
|
||||
prompt=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"你正在执行工具调用连通性测试。"
|
||||
"如果能调用工具,就优先调用最合适的工具。"
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "请查询上海天气,并使用工具给出参数。",
|
||||
},
|
||||
],
|
||||
tool_options=[weather_tool],
|
||||
expect_tool_calls=True,
|
||||
),
|
||||
ToolCallScenario(
|
||||
name="history_assistant_tool_calls_with_content",
|
||||
description="历史 assistant 同时包含文本和 tool_calls,当前轮不再提供 tools。",
|
||||
prompt=[
|
||||
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
|
||||
{"role": "user", "content": "先帮我查一下上海天气。"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "我先查询天气,再继续回答。",
|
||||
"tool_calls": [history_tool_call],
|
||||
},
|
||||
{"role": "user", "content": "继续说,别丢掉上下文。"},
|
||||
],
|
||||
tool_options=None,
|
||||
expect_tool_calls=None,
|
||||
),
|
||||
ToolCallScenario(
|
||||
name="history_assistant_tool_calls_without_content",
|
||||
description="历史 assistant 只有 tool_calls,没有文本内容。",
|
||||
prompt=[
|
||||
{"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
|
||||
{"role": "user", "content": "先帮我查一下上海天气。"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [history_tool_call],
|
||||
},
|
||||
{"role": "user", "content": "继续。"},
|
||||
],
|
||||
tool_options=None,
|
||||
expect_tool_calls=None,
|
||||
),
|
||||
ToolCallScenario(
|
||||
name="history_tool_result_followup",
|
||||
description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
|
||||
prompt=[
|
||||
{"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
|
||||
{"role": "user", "content": "先查上海天气。"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "我先查询天气。",
|
||||
"tool_calls": [history_tool_call],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_hist_weather_001",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"city": "上海",
|
||||
"condition": "多云",
|
||||
"temperature_c": 24,
|
||||
"forecast": ["晴", "小雨"],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": "结合上面的查询结果继续总结。"},
|
||||
],
|
||||
tool_options=None,
|
||||
expect_tool_calls=None,
|
||||
),
|
||||
ToolCallScenario(
|
||||
name="history_multiple_tool_calls_and_results",
|
||||
description="历史中包含多个 tool_calls 与多条 tool 结果。",
|
||||
prompt=[
|
||||
{"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
|
||||
{"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "我分两步查询。",
|
||||
"tool_calls": [history_tool_call, nested_history_tool_call],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_hist_weather_001",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"city": "上海",
|
||||
"condition": "阴",
|
||||
"temperature_c": 22,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_hist_search_001",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"items": [
|
||||
"OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
|
||||
"部分 provider 在历史消息回放时兼容性较弱",
|
||||
],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": "继续整合上面的两个结果。"},
|
||||
],
|
||||
tool_options=None,
|
||||
expect_tool_calls=None,
|
||||
),
|
||||
ToolCallScenario(
|
||||
name="history_tool_calls_with_current_tools",
|
||||
description="保留历史 tool_calls,同时当前轮仍然提供 tools。",
|
||||
prompt=[
|
||||
{"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
|
||||
{"role": "user", "content": "先查上海天气。"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "我先查天气。",
|
||||
"tool_calls": [history_tool_call],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_hist_weather_001",
|
||||
"content": json.dumps(
|
||||
{
|
||||
"city": "上海",
|
||||
"condition": "晴",
|
||||
"temperature_c": 26,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
|
||||
],
|
||||
tool_options=[search_tool],
|
||||
expect_tool_calls=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
||||
"""解析命令行中的多值参数。"""
|
||||
parsed_values: List[str] = []
|
||||
for raw_value in raw_values or []:
|
||||
for item in str(raw_value).split(","):
|
||||
normalized_item = item.strip()
|
||||
if normalized_item:
|
||||
parsed_values.append(normalized_item)
|
||||
return parsed_values
|
||||
|
||||
|
||||
def _build_model_map() -> Dict[str, ModelInfo]:
|
||||
"""构造模型名到模型配置的映射。"""
|
||||
return {model.name: model for model in config_manager.get_model_config().models}
|
||||
|
||||
|
||||
def _build_provider_map() -> Dict[str, APIProvider]:
|
||||
"""构造 Provider 名称到配置的映射。"""
|
||||
return {provider.name: provider for provider in config_manager.get_model_config().api_providers}
|
||||
|
||||
|
||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
||||
"""选择默认任务名。"""
|
||||
if "utils" in task_names:
|
||||
return "utils"
|
||||
if not task_names:
|
||||
raise ValueError("当前没有可用的任务配置")
|
||||
return str(task_names[0])
|
||||
|
||||
|
||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
|
||||
"""根据命令行参数解析待测试目标。"""
|
||||
available_tasks = get_available_models()
|
||||
model_map = _build_model_map()
|
||||
provider_map = _build_provider_map()
|
||||
|
||||
if not available_tasks:
|
||||
raise ValueError("未找到任何可用的模型任务配置")
|
||||
|
||||
if task_filters:
|
||||
selected_task_names = []
|
||||
for task_name in task_filters:
|
||||
if task_name not in available_tasks:
|
||||
raise ValueError(f"未找到任务 `{task_name}`")
|
||||
selected_task_names.append(task_name)
|
||||
else:
|
||||
selected_task_names = [
|
||||
task_name
|
||||
for task_name in available_tasks
|
||||
if task_name not in DEFAULT_SKIP_TASKS
|
||||
]
|
||||
|
||||
if not selected_task_names:
|
||||
raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")
|
||||
|
||||
default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
|
||||
resolved_targets: List[ProbeTarget] = []
|
||||
seen_models: set[str] = set()
|
||||
|
||||
if model_filters:
|
||||
model_names = list(model_filters)
|
||||
else:
|
||||
model_names = []
|
||||
for task_name in selected_task_names:
|
||||
task_config = available_tasks[task_name]
|
||||
for model_name in task_config.model_list:
|
||||
if model_name not in model_names:
|
||||
model_names.append(model_name)
|
||||
|
||||
for model_name in model_names:
|
||||
if model_name in seen_models:
|
||||
continue
|
||||
if model_name not in model_map:
|
||||
raise ValueError(f"未找到模型 `{model_name}`")
|
||||
|
||||
target_task_name = ""
|
||||
for task_name in selected_task_names:
|
||||
if model_name in available_tasks[task_name].model_list:
|
||||
target_task_name = task_name
|
||||
break
|
||||
if not target_task_name:
|
||||
target_task_name = default_task_name
|
||||
|
||||
model_info = model_map[model_name]
|
||||
provider_info = provider_map[model_info.api_provider]
|
||||
resolved_targets.append(
|
||||
ProbeTarget(
|
||||
task_name=target_task_name,
|
||||
model_name=model_name,
|
||||
provider_name=provider_info.name,
|
||||
client_type=provider_info.client_type,
|
||||
tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
|
||||
)
|
||||
)
|
||||
seen_models.add(model_name)
|
||||
|
||||
return resolved_targets
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
|
||||
"""临时将某个任务锁定到单模型。"""
|
||||
model_task_config = config_manager.get_model_config().model_task_config
|
||||
task_config = getattr(model_task_config, task_name, None)
|
||||
if not isinstance(task_config, TaskConfig):
|
||||
raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
|
||||
|
||||
original_model_list = list(task_config.model_list)
|
||||
original_selection_strategy = task_config.selection_strategy
|
||||
task_config.model_list = [model_name]
|
||||
task_config.selection_strategy = "balance"
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
task_config.model_list = original_model_list
|
||||
task_config.selection_strategy = original_selection_strategy
|
||||
|
||||
|
||||
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
|
||||
"""序列化返回中的工具调用。"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
serialized_items: List[Dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
serialized_items.append(
|
||||
{
|
||||
"id": getattr(tool_call, "call_id", ""),
|
||||
"function": {
|
||||
"name": getattr(tool_call, "func_name", ""),
|
||||
"arguments": dict(getattr(tool_call, "args", {}) or {}),
|
||||
},
|
||||
**(
|
||||
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
|
||||
if getattr(tool_call, "extra_content", None)
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return serialized_items
|
||||
|
||||
|
||||
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
|
||||
"""校验服务结果。"""
|
||||
errors: List[str] = []
|
||||
warnings: List[str] = []
|
||||
completion = service_result.completion
|
||||
serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
|
||||
|
||||
if not service_result.success:
|
||||
errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
|
||||
return errors, warnings, serialized_tool_calls
|
||||
|
||||
if scenario.expect_tool_calls is True and not serialized_tool_calls:
|
||||
warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
|
||||
if scenario.expect_tool_calls is False and serialized_tool_calls:
|
||||
warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
|
||||
if completion.response.strip():
|
||||
warnings.append("模型返回了可见文本")
|
||||
return errors, warnings, serialized_tool_calls
|
||||
|
||||
|
||||
async def _run_single_probe(
|
||||
target: ProbeTarget,
|
||||
scenario: ToolCallScenario,
|
||||
attempt: int,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
) -> ProbeResult:
|
||||
"""执行单次 API 探测。"""
|
||||
request = LLMServiceRequest(
|
||||
task_name=target.task_name,
|
||||
request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
|
||||
prompt=scenario.prompt,
|
||||
tool_options=scenario.tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
started_at = time.perf_counter()
|
||||
with _pin_task_to_model(target.task_name, target.model_name):
|
||||
service_result = await generate(request)
|
||||
elapsed_seconds = time.perf_counter() - started_at
|
||||
|
||||
errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
|
||||
completion = service_result.completion
|
||||
return ProbeResult(
|
||||
task_name=target.task_name,
|
||||
target_model_name=target.model_name,
|
||||
actual_model_name=completion.model_name,
|
||||
provider_name=target.provider_name,
|
||||
client_type=target.client_type,
|
||||
tool_argument_parse_mode=target.tool_argument_parse_mode,
|
||||
case_name=scenario.name,
|
||||
attempt=attempt,
|
||||
success=not errors,
|
||||
elapsed_seconds=elapsed_seconds,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
response_text=completion.response,
|
||||
reasoning_text=completion.reasoning,
|
||||
tool_calls=serialized_tool_calls,
|
||||
)
|
||||
|
||||
|
||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
|
||||
"""打印待测试目标。"""
|
||||
print("待测试目标:")
|
||||
for index, target in enumerate(targets, start=1):
|
||||
print(
|
||||
f"{index}. model={target.model_name} | task={target.task_name} | "
|
||||
f"provider={target.provider_name} | client={target.client_type} | "
|
||||
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
|
||||
)
|
||||
|
||||
|
||||
def _print_available_targets() -> None:
|
||||
"""打印当前可用任务与模型。"""
|
||||
available_tasks = get_available_models()
|
||||
model_map = _build_model_map()
|
||||
task_names = list(available_tasks.keys())
|
||||
|
||||
print("当前可用任务:")
|
||||
for task_name in task_names:
|
||||
task_config = available_tasks[task_name]
|
||||
print(f"- {task_name}: {list(task_config.model_list)}")
|
||||
|
||||
referenced_models = {
|
||||
model_name
|
||||
for task_config in available_tasks.values()
|
||||
for model_name in task_config.model_list
|
||||
}
|
||||
|
||||
print("\n当前配置中的模型:")
|
||||
for model_name, model_info in model_map.items():
|
||||
referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
|
||||
print(
|
||||
f"- {model_name}: provider={model_info.api_provider}, "
|
||||
f"identifier={model_info.model_identifier}, {referenced_mark}"
|
||||
)
|
||||
|
||||
|
||||
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
|
||||
"""按名称筛选测试场景。"""
|
||||
all_scenarios = {scenario.name: scenario for scenario in _build_default_scenarios()}
|
||||
if not case_filters:
|
||||
return list(all_scenarios.values())
|
||||
|
||||
selected_scenarios: List[ToolCallScenario] = []
|
||||
for case_name in case_filters:
|
||||
if case_name not in all_scenarios:
|
||||
raise ValueError(
|
||||
f"未知测试场景 `{case_name}`,可选值: {', '.join(sorted(all_scenarios))}"
|
||||
)
|
||||
selected_scenarios.append(all_scenarios[case_name])
|
||||
return selected_scenarios
|
||||
|
||||
|
||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
|
||||
"""打印单次结果。"""
|
||||
status_text = "PASS" if result.success else "FAIL"
|
||||
print(
|
||||
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
|
||||
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
|
||||
)
|
||||
if result.errors:
|
||||
for error in result.errors:
|
||||
print(f" ERROR: {error}")
|
||||
if result.warnings:
|
||||
for warning in result.warnings:
|
||||
print(f" WARN: {warning}")
|
||||
if result.tool_calls:
|
||||
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
|
||||
if show_response and result.response_text.strip():
|
||||
print(f" response: {result.response_text}")
|
||||
|
||||
|
||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
|
||||
"""构造结果摘要。"""
|
||||
total_count = len(results)
|
||||
passed_count = sum(1 for result in results if result.success)
|
||||
failed_count = total_count - passed_count
|
||||
failed_items = [
|
||||
{
|
||||
"model_name": result.target_model_name,
|
||||
"case_name": result.case_name,
|
||||
"attempt": result.attempt,
|
||||
"errors": list(result.errors),
|
||||
}
|
||||
for result in results
|
||||
if not result.success
|
||||
]
|
||||
return {
|
||||
"total": total_count,
|
||||
"passed": passed_count,
|
||||
"failed": failed_count,
|
||||
"failed_items": failed_items,
|
||||
}
|
||||
|
||||
|
||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
|
||||
"""将测试结果写入 JSON 文件。"""
|
||||
output_path = Path(json_out).expanduser().resolve()
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = {
|
||||
"generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"summary": _build_summary(results),
|
||||
"results": [asdict(result) for result in results],
|
||||
}
|
||||
output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f"\n结果已写入: {output_path}")
|
||||
|
||||
|
||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
|
||||
"""执行所有探测请求。"""
|
||||
task_filters = _parse_multi_value_args(args.task)
|
||||
model_filters = _parse_multi_value_args(args.model)
|
||||
case_filters = _parse_multi_value_args(args.case)
|
||||
|
||||
selected_scenarios = _select_scenarios(case_filters)
|
||||
targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
|
||||
|
||||
_print_targets(targets)
|
||||
print("")
|
||||
|
||||
results: List[ProbeResult] = []
|
||||
for target in targets:
|
||||
for attempt in range(1, args.repeat + 1):
|
||||
for scenario in selected_scenarios:
|
||||
print(
|
||||
f"开始测试: model={target.model_name}, task={target.task_name}, "
|
||||
f"case={scenario.name}, attempt={attempt}"
|
||||
)
|
||||
result = await _run_single_probe(
|
||||
target=target,
|
||||
scenario=scenario,
|
||||
attempt=attempt,
|
||||
max_tokens=args.max_tokens,
|
||||
temperature=args.temperature,
|
||||
)
|
||||
_print_single_result(result, args.show_response)
|
||||
print("")
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
|
||||
def _build_parser() -> ArgumentParser:
|
||||
"""构造命令行参数解析器。"""
|
||||
parser = ArgumentParser(
|
||||
description=(
|
||||
"测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
|
||||
"重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
action="append",
|
||||
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
action="append",
|
||||
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--case",
|
||||
action="append",
|
||||
help=(
|
||||
"指定测试场景名,可选值包括 "
|
||||
"fresh_tool_call、history_assistant_tool_calls_with_content、"
|
||||
"history_assistant_tool_calls_without_content、history_tool_result_followup、"
|
||||
"history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat",
|
||||
type=int,
|
||||
default=1,
|
||||
help="每个模型每个场景重复测试次数,默认 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
default=512,
|
||||
help="单次测试的最大输出 token 数,默认 512",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="单次测试温度,默认 0.0,以尽量提高稳定性",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fallback-task",
|
||||
default="utils",
|
||||
help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--json-out",
|
||||
help="可选,将结果写入指定 JSON 文件",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-targets",
|
||||
action="store_true",
|
||||
help="仅打印当前任务与模型映射,不发起网络请求",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-response",
|
||||
action="store_true",
|
||||
help="打印模型返回的可见文本内容",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""脚本入口。"""
|
||||
_ensure_utf8_console()
|
||||
config_manager.initialize()
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.repeat < 1:
|
||||
parser.error("--repeat 必须大于等于 1")
|
||||
if args.max_tokens < 1:
|
||||
parser.error("--max-tokens 必须大于等于 1")
|
||||
|
||||
if args.list_targets:
|
||||
_print_available_targets()
|
||||
return 0
|
||||
|
||||
results = asyncio.run(_run_probes(args))
|
||||
summary = _build_summary(results)
|
||||
|
||||
print("测试摘要:")
|
||||
print(
|
||||
f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
|
||||
)
|
||||
if summary["failed_items"]:
|
||||
print("失败明细:")
|
||||
for failed_item in summary["failed_items"]:
|
||||
print(
|
||||
f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
|
||||
f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
|
||||
)
|
||||
|
||||
if args.json_out:
|
||||
_write_json_report(args.json_out, results)
|
||||
|
||||
return 0 if summary["failed"] == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user