fix:正常获取Bot自身发送的消息id

This commit is contained in:
SengokuCola
2026-04-05 13:05:18 +08:00
parent ead90cbdf3
commit 18d48e0145
8 changed files with 1147 additions and 46 deletions

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from types import SimpleNamespace
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence, ReplyComponent, TextComponent
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
def _build_sent_message() -> SessionMessage:
    """Assemble a fully-initialized SessionMessage fixture carrying a real platform id."""
    sender = UserInfo(
        user_id="bot-qq",
        user_nickname="MaiSaka",
        user_cardname=None,
    )
    outgoing_sequence = MessageSequence(
        [
            ReplyComponent(target_message_id="m123"),
            TextComponent(text="你好"),
        ]
    )
    built = SessionMessage(
        message_id="real-message-id",
        timestamp=datetime(2026, 4, 5, 12, 0, 0),
        platform="qq",
    )
    built.message_info = MessageInfo(
        user_info=sender,
        group_info=None,
        additional_config={},
    )
    built.raw_message = outgoing_sequence
    built.session_id = "test-session"
    built.initialized = True
    return built
def test_append_sent_message_to_chat_history_keeps_message_id() -> None:
    """The real platform message id must survive the round-trip into Maisaka chat history."""
    fake_runtime = SimpleNamespace(_chat_history=[])
    fake_engine = SimpleNamespace(_get_runtime_manager=lambda: None)
    context = BuiltinToolRuntimeContext(engine=fake_engine, runtime=fake_runtime)

    context.append_sent_message_to_chat_history(_build_sent_message())

    history = fake_runtime._chat_history
    assert len(history) == 1
    recorded = history[0]
    assert recorded.message_id == "real-message-id"
    assert "[msg_id]real-message-id\n" in recorded.raw_message.components[0].text
    assert "[msg_id:real-message-id]" in recorded.visible_text

View File

@@ -142,6 +142,46 @@ async def test_text_to_stream_delegates_to_platform_io(monkeypatch: pytest.Monke
]
@pytest.mark.asyncio
async def test_text_to_stream_with_message_returns_sent_message(monkeypatch: pytest.MonkeyPatch) -> None:
    """text_to_stream_with_message should return and persist the message with its real platform id."""
    # Fake platform IO layer reporting one successful delivery receipt.
    fake_manager = _FakePlatformIOManager(
        delivery_batch=SimpleNamespace(
            has_success=True,
            sent_receipts=[
                SimpleNamespace(
                    driver_id="plugin.qq.sender",
                    external_message_id="real-message-id",
                    metadata={},
                )
            ],
            failed_receipts=[],
            route_key=SimpleNamespace(platform="qq"),
        )
    )
    stored_messages: List[Any] = []
    # Route every collaborator of send_service through test doubles.
    monkeypatch.setattr(send_service, "get_platform_io_manager", lambda: fake_manager)
    monkeypatch.setattr(send_service, "get_bot_account", lambda platform: "bot-qq")
    monkeypatch.setattr(
        send_service._chat_manager,
        "get_session_by_session_id",
        lambda stream_id: _build_private_stream() if stream_id == "test-session" else None,
    )
    # Capture DB writes instead of touching real storage.
    monkeypatch.setattr(
        send_service.MessageUtils,
        "store_message_to_db",
        lambda message: stored_messages.append(message),
    )
    sent_message = await send_service.text_to_stream_with_message(text="你好", stream_id="test-session")
    assert sent_message is not None
    # Both the returned and the persisted message must carry the external id.
    assert sent_message.message_id == "real-message-id"
    assert fake_manager.ensure_calls == 1
    assert len(stored_messages) == 1
    assert stored_messages[0].message_id == "real-message-id"
@pytest.mark.asyncio
async def test_text_to_stream_returns_false_when_platform_io_fails(monkeypatch: pytest.MonkeyPatch) -> None:
fake_manager = _FakePlatformIOManager(

View File

@@ -0,0 +1,845 @@
from argparse import ArgumentParser, Namespace
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Sequence
import asyncio
import json
import sys
import time
# Make the project root importable when the script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Project imports must come after the sys.path tweak above.
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult  # noqa: E402
from src.config.config import config_manager  # noqa: E402
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig  # noqa: E402
from src.llm_models.payload_content.tool_option import ToolCall  # noqa: E402
from src.services.llm_service import generate  # noqa: E402
from src.services.service_task_resolver import get_available_models  # noqa: E402

# Tasks that never make sense for tool-call probing.
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
@dataclass(slots=True)
class ToolCallCase:
"""Tool call 参数测试用例。"""
name: str
description: str
tool_definition: Dict[str, Any]
expected_arguments: Dict[str, Any]
@property
def tool_name(self) -> str:
"""返回工具名称。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
return str(function_definition.get("name", "") or "")
return str(self.tool_definition.get("name", "") or "")
@property
def parameters_schema(self) -> Dict[str, Any]:
"""返回参数 Schema。"""
if self.tool_definition.get("type") == "function":
function_definition = self.tool_definition.get("function", {})
parameters = function_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
parameters = self.tool_definition.get("parameters", {})
return parameters if isinstance(parameters, dict) else {}
def build_messages(self) -> List[Dict[str, Any]]:
"""构造测试消息。"""
expected_json = json.dumps(self.expected_arguments, ensure_ascii=False, indent=2)
system_prompt = (
"你正在执行严格的工具调用参数兼容性测试。"
"你必须通过工具调用响应,不能输出自然语言,不能解释,不能补充额外字段。"
)
user_prompt = (
f"请立刻调用工具 `{self.tool_name}`。\n"
"参数必须与下面 JSON 完全一致,键名、值、布尔类型、整数类型、浮点数、数组顺序和对象结构都不能改变。\n"
"不要输出任何解释文本,只返回工具调用。\n"
f"{expected_json}"
)
return [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
@dataclass(slots=True)
class ProbeTarget:
    """A single model target to probe."""

    task_name: str  # task the probe request is issued under
    model_name: str  # configured model name (config key, not provider identifier)
    provider_name: str  # API provider the model is bound to
    client_type: str  # provider client implementation type
    tool_argument_parse_mode: str  # how the provider parses tool-call arguments
@dataclass(slots=True)
class ProbeResult:
    """Outcome of a single probe attempt."""

    task_name: str  # task the request ran under
    target_model_name: str  # model the probe intended to hit
    actual_model_name: str  # model reported back by the completion
    provider_name: str  # provider of the target model
    client_type: str  # provider client implementation type
    tool_argument_parse_mode: str  # provider's tool-argument parse mode
    case_name: str  # probe case identifier (e.g. "simple", "nested")
    attempt: int  # 1-based attempt number
    success: bool  # True when no validation errors were found
    elapsed_seconds: float  # wall-clock duration of the request
    errors: List[str]  # validation errors (empty on success)
    warnings: List[str]  # non-fatal observations
    response_text: str  # natural-language text returned, if any
    reasoning_text: str  # reasoning/thinking text returned, if any
    tool_calls: List[Dict[str, Any]]  # serialized tool calls from the completion
def _ensure_utf8_console() -> None:
"""尽量将控制台编码切到 UTF-8。"""
try:
if hasattr(sys.stdout, "reconfigure"):
sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
sys.stderr.reconfigure(encoding="utf-8")
except Exception:
pass
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
"""构造 OpenAI 风格 function tool 定义。"""
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters,
},
}
def _build_default_cases() -> List[ToolCallCase]:
    """Build the default probe cases.

    `simple` covers flat scalars (string / integer / boolean / enum / number);
    `nested` covers nested objects, string arrays and arrays of objects.
    """
    # Case 1: flat scalar arguments only.
    simple_expected_arguments = {
        "request_id": "probe-simple-001",
        "count": 7,
        "enabled": True,
        "mode": "strict",
        "ratio": 2.5,
    }
    simple_parameters = {
        "type": "object",
        "properties": {
            "request_id": {"type": "string", "description": "请求 ID"},
            "count": {"type": "integer", "description": "数量"},
            "enabled": {"type": "boolean", "description": "是否启用"},
            "mode": {
                "type": "string",
                "description": "模式",
                "enum": ["strict", "loose"],
            },
            "ratio": {"type": "number", "description": "比例"},
        },
        "required": ["request_id", "count", "enabled", "mode", "ratio"],
        "additionalProperties": False,
    }
    # Case 2: nested object, string array and array-of-objects arguments.
    nested_expected_arguments = {
        "request_id": "probe-nested-001",
        "notify": False,
        "profile": {
            "channel": "stable",
            "priority": 2,
        },
        "tags": ["alpha", "beta", "gamma"],
        "items": [
            {"count": 2, "name": "apple"},
            {"count": 5, "name": "banana"},
        ],
    }
    nested_parameters = {
        "type": "object",
        "properties": {
            "request_id": {"type": "string", "description": "请求 ID"},
            "notify": {"type": "boolean", "description": "是否通知"},
            "profile": {
                "type": "object",
                "description": "配置对象",
                "properties": {
                    "channel": {"type": "string", "description": "渠道"},
                    "priority": {"type": "integer", "description": "优先级"},
                },
                "required": ["channel", "priority"],
                "additionalProperties": False,
            },
            "tags": {
                "type": "array",
                "description": "标签列表",
                "items": {"type": "string"},
            },
            "items": {
                "type": "array",
                "description": "条目列表",
                "items": {
                    "type": "object",
                    "properties": {
                        "count": {"type": "integer", "description": "数量"},
                        "name": {"type": "string", "description": "名称"},
                    },
                    "required": ["count", "name"],
                    "additionalProperties": False,
                },
            },
        },
        "required": ["request_id", "notify", "profile", "tags", "items"],
        "additionalProperties": False,
    }
    return [
        ToolCallCase(
            name="simple",
            description="标量参数类型校验",
            tool_definition=_build_function_tool(
                name="record_simple_probe",
                description="记录简单参数探测结果",
                parameters=simple_parameters,
            ),
            expected_arguments=simple_expected_arguments,
        ),
        ToolCallCase(
            name="nested",
            description="嵌套对象与数组参数校验",
            tool_definition=_build_function_tool(
                name="record_nested_probe",
                description="记录嵌套参数探测结果",
                parameters=nested_parameters,
            ),
            expected_arguments=nested_expected_arguments,
        ),
    ]
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
"""解析命令行中的多值参数。"""
parsed_values: List[str] = []
for raw_value in raw_values or []:
for item in str(raw_value).split(","):
normalized_item = item.strip()
if normalized_item:
parsed_values.append(normalized_item)
return parsed_values
def _build_model_map() -> Dict[str, ModelInfo]:
    """Map each configured model name to its model config entry."""
    mapping: Dict[str, ModelInfo] = {}
    for model_info in config_manager.get_model_config().models:
        mapping[model_info.name] = model_info
    return mapping
def _build_provider_map() -> Dict[str, APIProvider]:
    """Map each API provider name to its provider config entry."""
    mapping: Dict[str, APIProvider] = {}
    for provider_info in config_manager.get_model_config().api_providers:
        mapping[provider_info.name] = provider_info
    return mapping
def _pick_default_task_name(task_names: Sequence[str]) -> str:
"""选择默认任务名。"""
if "utils" in task_names:
return "utils"
if not task_names:
raise ValueError("当前没有可用的任务配置")
return str(task_names[0])
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
    """Resolve CLI task/model filters into a deduplicated, ordered list of probe targets.

    Raises:
        ValueError: when no task config exists, a filter names an unknown
            task or model, or nothing remains to test after filtering.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    provider_map = _build_provider_map()
    if not available_tasks:
        raise ValueError("未找到任何可用的模型任务配置")
    if task_filters:
        # Explicit filters must all resolve; fail fast on the first unknown task.
        selected_task_names = []
        for task_name in task_filters:
            if task_name not in available_tasks:
                raise ValueError(f"未找到任务 `{task_name}`")
            selected_task_names.append(task_name)
    else:
        # No filter: take every task except the skip list (embedding/voice).
        selected_task_names = [
            task_name
            for task_name in available_tasks
            if task_name not in DEFAULT_SKIP_TASKS
        ]
    if not selected_task_names:
        raise ValueError("没有可用于 tool call 测试的任务,请显式通过 --task 指定")
    default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
    resolved_targets: List[ProbeTarget] = []
    seen_models: set[str] = set()
    if model_filters:
        model_names = list(model_filters)
    else:
        # Collect every model referenced by the selected tasks, preserving order.
        model_names = []
        for task_name in selected_task_names:
            task_config = available_tasks[task_name]
            for model_name in task_config.model_list:
                if model_name not in model_names:
                    model_names.append(model_name)
    for model_name in model_names:
        if model_name in seen_models:
            continue
        if model_name not in model_map:
            raise ValueError(f"未找到模型 `{model_name}`")
        # Prefer the first selected task referencing this model; else fall back.
        target_task_name = ""
        for task_name in selected_task_names:
            if model_name in available_tasks[task_name].model_list:
                target_task_name = task_name
                break
        if not target_task_name:
            target_task_name = default_task_name
        model_info = model_map[model_name]
        provider_info = provider_map[model_info.api_provider]
        resolved_targets.append(
            ProbeTarget(
                task_name=target_task_name,
                model_name=model_name,
                provider_name=provider_info.name,
                client_type=provider_info.client_type,
                tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
            )
        )
        seen_models.add(model_name)
    return resolved_targets
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
    """Temporarily pin a task to a single model, restoring the original config on exit.

    Raises:
        ValueError: when the task has no TaskConfig in the model task config.
    """
    model_task_config = config_manager.get_model_config().model_task_config
    task_config = getattr(model_task_config, task_name, None)
    if not isinstance(task_config, TaskConfig):
        raise ValueError(f"未找到任务 `{task_name}` 对应的配置")
    # Snapshot the original routing so it can be restored even on error.
    original_model_list = list(task_config.model_list)
    original_selection_strategy = task_config.selection_strategy
    task_config.model_list = [model_name]
    task_config.selection_strategy = "balance"
    try:
        yield
    finally:
        # Always restore, regardless of what happened inside the with-block.
        task_config.model_list = original_model_list
        task_config.selection_strategy = original_selection_strategy
def _serialize_tool_calls(tool_calls: List[ToolCall] | None) -> List[Dict[str, Any]]:
    """Serialize tool calls into plain JSON-friendly dicts (empty list for None)."""
    serialized: List[Dict[str, Any]] = []
    for tool_call in tool_calls or []:
        serialized.append(
            {
                "id": tool_call.call_id,
                "function": {
                    "name": tool_call.func_name,
                    "arguments": dict(tool_call.args or {}),
                },
            }
        )
    return serialized
def _is_integer_value(value: Any) -> bool:
"""判断是否为整数类型且排除布尔值。"""
return isinstance(value, int) and not isinstance(value, bool)
def _is_number_value(value: Any) -> bool:
"""判断是否为数值类型且排除布尔值。"""
return (isinstance(value, int) or isinstance(value, float)) and not isinstance(value, bool)
def _schema_type(schema: Dict[str, Any]) -> str:
"""解析 Schema 的类型。"""
schema_type = str(schema.get("type", "") or "").strip()
if schema_type:
return schema_type
if "properties" in schema or "required" in schema:
return "object"
return ""
def _validate_schema(schema: Dict[str, Any], actual_value: Any, path: str = "args") -> List[str]:
"""按简化 JSON Schema 校验工具参数。"""
errors: List[str] = []
schema_type = _schema_type(schema)
if "enum" in schema and actual_value not in schema["enum"]:
errors.append(f"{path} 枚举值不合法,期望属于 {schema['enum']},实际为 {actual_value!r}")
if schema_type == "string":
if not isinstance(actual_value, str):
errors.append(f"{path} 类型错误,期望 string实际为 {type(actual_value).__name__}")
return errors
if schema_type == "integer":
if not _is_integer_value(actual_value):
errors.append(f"{path} 类型错误,期望 integer实际为 {type(actual_value).__name__}")
return errors
if schema_type == "number":
if not _is_number_value(actual_value):
errors.append(f"{path} 类型错误,期望 number实际为 {type(actual_value).__name__}")
return errors
if schema_type == "boolean":
if not isinstance(actual_value, bool):
errors.append(f"{path} 类型错误,期望 boolean实际为 {type(actual_value).__name__}")
return errors
if schema_type == "array":
if not isinstance(actual_value, list):
errors.append(f"{path} 类型错误,期望 array实际为 {type(actual_value).__name__}")
return errors
item_schema = schema.get("items")
if isinstance(item_schema, dict):
for index, item in enumerate(actual_value):
errors.extend(_validate_schema(item_schema, item, f"{path}[{index}]"))
return errors
if schema_type == "object":
if not isinstance(actual_value, dict):
errors.append(f"{path} 类型错误,期望 object实际为 {type(actual_value).__name__}")
return errors
properties = schema.get("properties", {})
required_fields = [str(item) for item in schema.get("required", [])]
for required_field in required_fields:
if required_field not in actual_value:
errors.append(f"{path}.{required_field} 缺少必填字段")
for field_name, field_value in actual_value.items():
field_path = f"{path}.{field_name}"
field_schema = properties.get(field_name)
if isinstance(field_schema, dict):
errors.extend(_validate_schema(field_schema, field_value, field_path))
continue
additional_properties = schema.get("additionalProperties", True)
if additional_properties is False:
errors.append(f"{field_path} 是未定义字段")
elif isinstance(additional_properties, dict):
errors.extend(_validate_schema(additional_properties, field_value, field_path))
return errors
return errors
def _compare_expected_values(expected_value: Any, actual_value: Any, path: str = "args") -> List[str]:
"""递归比较实际值与期望值是否完全一致。"""
errors: List[str] = []
if isinstance(expected_value, dict):
if not isinstance(actual_value, dict):
return [f"{path} 值不一致,期望 object实际为 {type(actual_value).__name__}"]
expected_keys = set(expected_value.keys())
actual_keys = set(actual_value.keys())
for missing_key in sorted(expected_keys - actual_keys):
errors.append(f"{path}.{missing_key} 缺少期望字段")
for extra_key in sorted(actual_keys - expected_keys):
errors.append(f"{path}.{extra_key} 出现了额外字段")
for shared_key in sorted(expected_keys & actual_keys):
errors.extend(
_compare_expected_values(
expected_value[shared_key],
actual_value[shared_key],
f"{path}.{shared_key}",
)
)
return errors
if isinstance(expected_value, list):
if not isinstance(actual_value, list):
return [f"{path} 值不一致,期望 array实际为 {type(actual_value).__name__}"]
if len(expected_value) != len(actual_value):
errors.append(f"{path} 列表长度不一致,期望 {len(expected_value)},实际 {len(actual_value)}")
for index, (expected_item, actual_item) in enumerate(
zip(expected_value, actual_value, strict=False)
):
errors.extend(_compare_expected_values(expected_item, actual_item, f"{path}[{index}]"))
return errors
if isinstance(expected_value, bool):
if not isinstance(actual_value, bool) or actual_value is not expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if _is_integer_value(expected_value):
if not _is_integer_value(actual_value) or actual_value != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if isinstance(expected_value, float):
if not _is_number_value(actual_value) or float(actual_value) != expected_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
if expected_value != actual_value:
errors.append(f"{path} 值不一致,期望 {expected_value!r},实际 {actual_value!r}")
return errors
def _pick_tool_call(tool_calls: List[ToolCall], expected_tool_name: str) -> ToolCall:
    """Return the first call with the expected tool name, else the first call overall."""
    matching = next(
        (candidate for candidate in tool_calls if candidate.func_name == expected_tool_name),
        None,
    )
    return matching if matching is not None else tool_calls[0]
def _validate_service_result(
    service_result: LLMServiceResult,
    target: ProbeTarget,
    case: ToolCallCase,
) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Validate a service-layer result against the probe case.

    Returns:
        (errors, warnings, serialized_tool_calls); an empty error list
        means the probe passed.
    """
    errors: List[str] = []
    warnings: List[str] = []
    completion = service_result.completion
    serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)
    if not service_result.success:
        errors.append(service_result.error or completion.response or "请求失败但未返回错误信息")
        return errors, warnings, serialized_tool_calls
    # The pinned task should have routed to exactly the target model.
    if completion.model_name and completion.model_name != target.model_name:
        errors.append(
            f"实际命中的模型为 `{completion.model_name}`,与目标模型 `{target.model_name}` 不一致"
        )
    tool_calls = completion.tool_calls or []
    if not tool_calls:
        errors.append("模型未返回 tool_calls")
        if completion.response.strip():
            warnings.append("模型返回了自然语言文本而不是工具调用")
        return errors, warnings, serialized_tool_calls
    if len(tool_calls) != 1:
        # FIX: restored the missing "," separator between the count and 预期.
        errors.append(f"返回了 {len(tool_calls)} 个 tool_calls,预期为 1 个")
    selected_tool_call = _pick_tool_call(tool_calls, case.tool_name)
    if selected_tool_call.func_name != case.tool_name:
        errors.append(
            f"工具名不一致,期望 `{case.tool_name}`,实际 `{selected_tool_call.func_name}`"
        )
    actual_arguments = selected_tool_call.args
    if not isinstance(actual_arguments, dict):
        errors.append("工具参数未被解析为对象")
        return errors, warnings, serialized_tool_calls
    # Structural validation first, then strict value equality.
    errors.extend(_validate_schema(case.parameters_schema, actual_arguments))
    errors.extend(_compare_expected_values(case.expected_arguments, actual_arguments))
    if completion.response.strip():
        warnings.append("模型同时返回了自然语言文本")
    return errors, warnings, serialized_tool_calls
async def _run_single_probe(
    target: ProbeTarget,
    case: ToolCallCase,
    attempt: int,
    max_tokens: int,
    temperature: float,
) -> ProbeResult:
    """Run one tool-call probe against one model and validate the response."""
    request = LLMServiceRequest(
        task_name=target.task_name,
        request_type=f"tool_call_param_probe.{case.name}.attempt_{attempt}",
        prompt=case.build_messages(),
        tool_options=[case.tool_definition],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    started_at = time.perf_counter()
    # Pin the task to exactly this model for the duration of the request.
    with _pin_task_to_model(target.task_name, target.model_name):
        service_result = await generate(request)
    elapsed_seconds = time.perf_counter() - started_at
    errors, warnings, serialized_tool_calls = _validate_service_result(service_result, target, case)
    completion = service_result.completion
    return ProbeResult(
        task_name=target.task_name,
        target_model_name=target.model_name,
        actual_model_name=completion.model_name,
        provider_name=target.provider_name,
        client_type=target.client_type,
        tool_argument_parse_mode=target.tool_argument_parse_mode,
        case_name=case.name,
        attempt=attempt,
        # A probe passes only when validation produced no errors.
        success=not errors,
        elapsed_seconds=elapsed_seconds,
        errors=errors,
        warnings=warnings,
        response_text=completion.response,
        reasoning_text=completion.reasoning,
        tool_calls=serialized_tool_calls,
    )
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
    """Print the resolved probe targets, one numbered line per target."""
    print("待测试目标:")
    for index, target in enumerate(targets, start=1):
        target_line = (
            f"{index}. model={target.model_name} | task={target.task_name} | "
            f"provider={target.provider_name} | client={target.client_type} | "
            f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
        )
        print(target_line)
def _print_available_targets() -> None:
    """Print the current task and model mapping without issuing any requests."""
    available_tasks = get_available_models()
    model_map = _build_model_map()
    task_names = list(available_tasks.keys())
    print("当前可用任务:")
    for task_name in task_names:
        task_config = available_tasks[task_name]
        print(f"- {task_name}: {list(task_config.model_list)}")
    # Models referenced by at least one task configuration.
    referenced_models = {
        model_name
        for task_config in available_tasks.values()
        for model_name in task_config.model_list
    }
    print("\n当前配置中的模型:")
    for model_name, model_info in model_map.items():
        referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
        print(
            f"- {model_name}: provider={model_info.api_provider}, "
            f"identifier={model_info.model_identifier}, {referenced_mark}"
        )
def _select_cases(case_filters: Sequence[str]) -> List[ToolCallCase]:
    """Resolve requested case names into case objects (all cases when no filter given).

    Raises:
        ValueError: when a filter names an unknown case.
    """
    all_cases = {case.name: case for case in _build_default_cases()}
    if not case_filters:
        return list(all_cases.values())
    chosen: List[ToolCallCase] = []
    for case_name in case_filters:
        if case_name not in all_cases:
            raise ValueError(f"未知测试用例 `{case_name}`,可选值: {', '.join(sorted(all_cases))}")
        chosen.append(all_cases[case_name])
    return chosen
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
    """Print one probe result with its errors, warnings and tool calls."""
    status_text = "PASS" if result.success else "FAIL"
    header = (
        f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
        f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
    )
    print(header)
    for error in result.errors:
        print(f" ERROR: {error}")
    for warning in result.warnings:
        print(f" WARN: {warning}")
    if result.tool_calls:
        print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
    if show_response and result.response_text.strip():
        print(f" response: {result.response_text}")
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
    """Aggregate pass/fail counts and collect details for every failed probe."""
    failed_items: List[Dict[str, Any]] = []
    for result in results:
        if result.success:
            continue
        failed_items.append(
            {
                "model_name": result.target_model_name,
                "case_name": result.case_name,
                "attempt": result.attempt,
                "errors": list(result.errors),
            }
        )
    total_count = len(results)
    failed_count = len(failed_items)
    return {
        "total": total_count,
        "passed": total_count - failed_count,
        "failed": failed_count,
        "failed_items": failed_items,
    }
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
    """Write all probe results plus a summary to a UTF-8 JSON file."""
    output_path = Path(json_out).expanduser().resolve()
    # Create intermediate directories so any target path works.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "summary": _build_summary(results),
        "results": [asdict(result) for result in results],
    }
    output_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n结果已写入: {output_path}")
async def _run_probes(args: Namespace) -> List[ProbeResult]:
    """Run every selected case against every resolved target, returning all results."""
    task_filters = _parse_multi_value_args(args.task)
    model_filters = _parse_multi_value_args(args.model)
    case_filters = _parse_multi_value_args(args.case)
    selected_cases = _select_cases(case_filters)
    targets = _resolve_targets(task_filters, model_filters, args.fallback_task)
    _print_targets(targets)
    print("")
    results: List[ProbeResult] = []
    # Iterate target -> attempt -> case so repeats of one model stay grouped.
    for target in targets:
        for attempt in range(1, args.repeat + 1):
            for case in selected_cases:
                print(
                    f"开始测试: model={target.model_name}, task={target.task_name}, "
                    f"case={case.name}, attempt={attempt}"
                )
                result = await _run_single_probe(
                    target=target,
                    case=case,
                    attempt=attempt,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _print_single_result(result, args.show_response)
                print("")
                results.append(result)
    return results
def _build_parser() -> ArgumentParser:
"""构造命令行参数解析器。"""
parser = ArgumentParser(
description=(
"测试 config/model_config.toml 中不同模型的 tool call 参数兼容性。\n"
"默认会测试所有非 voice / embedding 任务中引用到的模型。"
)
)
parser.add_argument(
"--task",
action="append",
help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
)
parser.add_argument(
"--model",
action="append",
help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.6-plus",
)
parser.add_argument(
"--case",
action="append",
help="指定测试用例名,可选 simple、nested不传则运行全部默认用例",
)
parser.add_argument(
"--repeat",
type=int,
default=1,
help="每个模型每个用例重复测试次数,默认 1",
)
parser.add_argument(
"--max-tokens",
type=int,
default=512,
help="单次测试的最大输出 token 数,默认 512",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="单次测试温度,默认 0.0 以尽量提高稳定性",
)
parser.add_argument(
"--fallback-task",
default="utils",
help="当指定模型未被任何已选任务引用时,用于挂载该模型的任务名,默认 utils",
)
parser.add_argument(
"--json-out",
help="可选,将结果写入指定 JSON 文件",
)
parser.add_argument(
"--list-targets",
action="store_true",
help="仅打印当前任务与模型映射,不发起网络请求",
)
parser.add_argument(
"--show-response",
action="store_true",
help="打印模型返回的自然语言文本内容",
)
return parser
def main() -> int:
    """Script entry point; returns the process exit code (0 = all probes passed)."""
    _ensure_utf8_console()
    parser = _build_parser()
    args = parser.parse_args()
    if args.repeat < 1:
        parser.error("--repeat 必须大于等于 1")
    if args.max_tokens < 1:
        parser.error("--max-tokens 必须大于等于 1")
    if args.list_targets:
        # Dry-run mode: only show the task/model mapping, no network requests.
        _print_available_targets()
        return 0
    results = asyncio.run(_run_probes(args))
    summary = _build_summary(results)
    print("测试摘要:")
    print(
        f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
    )
    if summary["failed_items"]:
        print("失败明细:")
        for failed_item in summary["failed_items"]:
            print(
                f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
                f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
            )
    if args.json_out:
        _write_json_report(args.json_out, results)
    # Non-zero exit when any probe failed, so CI can gate on it.
    return 0 if summary["failed"] == 0 else 1


if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -2,7 +2,7 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence
from typing import Any, Optional, Sequence, TYPE_CHECKING
import random
@@ -18,6 +18,9 @@ from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manag
logger = get_logger("emoji_maisaka_tool")
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
EmojiSelector = Callable[
[str, str, Sequence[str] | None, int],
Awaitable[tuple[MaiEmoji | None, str]],
@@ -35,6 +38,7 @@ class MaisakaEmojiSendResult:
emotions: list[str] = field(default_factory=list)
requested_emotion: str = ""
matched_emotion: str = ""
sent_message: Optional["SessionMessage"] = None
def _get_runtime_manager() -> Any:
@@ -309,6 +313,7 @@ async def send_emoji_for_maisaka(
try:
target_session = chat_manager.get_session_by_session_id(stream_id)
sent_message = None
if target_session is not None and target_session.platform == CLI_PLATFORM_NAME:
preview_message = (
f"已发送表情包:{selected_emoji.description.strip()}"
@@ -318,13 +323,14 @@ async def send_emoji_for_maisaka(
render_cli_message(preview_message)
sent = True
else:
sent = await send_service.emoji_to_stream(
sent_message = await send_service.emoji_to_stream_with_message(
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=True,
set_reply=False,
reply_message=None,
)
sent = sent_message is not None
except Exception as exc:
return MaisakaEmojiSendResult(
success=False,
@@ -361,4 +367,5 @@ async def send_emoji_for_maisaka(
emotions=emotions,
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
sent_message=sent_message,
)

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
from base64 import b64decode
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from src.chat.utils.utils import process_llm_response
from src.common.data_models.message_component_data_model import EmojiComponent, MessageSequence, TextComponent
@@ -12,10 +12,12 @@ from src.config.config import global_config
from src.core.tooling import ToolExecutionResult
from ..context_messages import SessionBackedMessage
from ..message_adapter import format_speaker_content
from ..message_adapter import build_visible_text_from_sequence, clone_message_sequence, format_speaker_content
from ..planner_message_utils import build_planner_prefix, build_session_backed_text_message
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from ..reasoning_engine import MaisakaReasoningEngine
from ..runtime import MaisakaHeartFlowChatting
@@ -136,6 +138,57 @@ class BuiltinToolRuntimeContext:
return self.engine._get_runtime_manager()
@staticmethod
def _build_visible_text_from_sent_message(message: "SessionMessage") -> str:
    """Convert an already-sent message into Maisaka-visible text."""
    user_info = message.message_info.user_info
    # Prefer group card, then nickname, then raw id for the speaker label.
    speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
    # Notify messages must not surface a message id.
    visible_message_id = None if message.is_notify else message.message_id
    legacy_sequence = MessageSequence([])
    legacy_sequence.text(
        format_speaker_content(
            speaker_name,
            "",
            message.timestamp,
            visible_message_id,
        )
    )
    # Append a cloned copy so the original outgoing sequence is not mutated.
    for component in clone_message_sequence(message.raw_message).components:
        legacy_sequence.components.append(component)
    return build_visible_text_from_sequence(legacy_sequence).strip()
def append_sent_message_to_chat_history(
    self,
    message: "SessionMessage",
    *,
    source_kind: str = "guided_reply",
) -> None:
    """Sync a really-sent message (with its real platform id) into Maisaka chat history."""
    user_info = message.message_info.user_info
    # Prefer group card, then nickname, then raw id for the speaker label.
    speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
    planner_prefix = build_planner_prefix(
        timestamp=message.timestamp,
        user_name=speaker_name,
        group_card=user_info.user_cardname or "",
        message_id=message.message_id,
        # Notify messages and empty ids must not surface a message id.
        include_message_id=not message.is_notify and bool(message.message_id),
    )
    # Work on a clone so the original outgoing sequence stays untouched.
    planner_components = clone_message_sequence(message.raw_message).components
    if planner_components and isinstance(planner_components[0], TextComponent):
        planner_components[0].text = f"{planner_prefix}{planner_components[0].text}"
    else:
        planner_components.insert(0, TextComponent(planner_prefix))
    history_message = SessionBackedMessage.from_session_message(
        message,
        raw_message=MessageSequence(planner_components),
        visible_text=self._build_visible_text_from_sent_message(message),
        source_kind=source_kind,
    )
    self.runtime._chat_history.append(history_message)
def append_guided_reply_to_chat_history(self, reply_text: str) -> None:
"""将引导回复写回 Maisaka 历史。"""

View File

@@ -144,13 +144,14 @@ async def handle_tool(
combined_reply_text = "".join(reply_segments)
try:
sent = False
sent_messages = []
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
for segment in reply_segments:
render_cli_message(segment)
sent = True
else:
for index, segment in enumerate(reply_segments):
sent = await send_service.text_to_stream(
sent_message = await send_service.text_to_stream_with_message(
text=segment,
stream_id=tool_ctx.runtime.session_id,
set_reply=set_quote if index == 0 else False,
@@ -158,8 +159,10 @@ async def handle_tool(
selected_expressions=reply_result.selected_expression_ids or None,
typing=index > 0,
)
sent = sent_message is not None
if not sent:
break
sent_messages.append(sent_message)
except Exception:
logger.exception(
f"{tool_ctx.runtime.log_prefix} 发送文字消息时发生异常,目标消息编号={target_message_id}"
@@ -183,7 +186,11 @@ async def handle_tool(
target_user_info = target_message.message_info.user_info
target_user_name = target_user_info.user_cardname or target_user_info.user_nickname or target_user_info.user_id
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
if tool_ctx.runtime.chat_stream.platform == CLI_PLATFORM_NAME:
tool_ctx.append_guided_reply_to_chat_history(combined_reply_text)
else:
for sent_message in sent_messages:
tool_ctx.append_sent_message_to_chat_history(sent_message)
return tool_ctx.build_success_result(
invocation.tool_name,
"回复已生成并发送。",

View File

@@ -346,10 +346,13 @@ async def handle_tool(
f"描述={send_result.description!r} 情绪标签={send_result.emotions} "
f"请求情绪={emotion!r} 命中情绪={send_result.matched_emotion!r}"
)
tool_ctx.append_sent_emoji_to_chat_history(
emoji_base64=send_result.emoji_base64,
success_message=_EMOJI_SUCCESS_MESSAGE,
)
if send_result.sent_message is not None:
tool_ctx.append_sent_message_to_chat_history(send_result.sent_message)
else:
tool_ctx.append_sent_emoji_to_chat_history(
emoji_base64=send_result.emoji_base64,
success_message=_EMOJI_SUCCESS_MESSAGE,
)
structured_result["success"] = True
return tool_ctx.build_success_result(
invocation.tool_name,

View File

@@ -1,4 +1,4 @@
"""
"""
发送服务模块。
统一封装内部模块的出站消息发送逻辑:
@@ -728,7 +728,7 @@ async def _send_via_platform_io(
reply_message_id: Optional[str],
storage_message: bool,
show_log: bool,
) -> bool:
) -> Optional[SessionMessage]:
"""通过 Platform IO 发送消息。
Args:
@@ -753,7 +753,7 @@ async def _send_via_platform_io(
)
if before_send_result.aborted:
logger.info(f"[SendService] 消息 {message.message_id} 在发送前被 Hook 中止")
return False
return None
before_kwargs = before_send_result.kwargs
typing = _coerce_bool(before_kwargs.get("typing"), typing)
@@ -769,13 +769,13 @@ async def _send_via_platform_io(
except Exception as exc:
logger.error(f"[SendService] 准备 Platform IO 发送管线失败: {exc}")
logger.debug(traceback.format_exc())
return False
return None
try:
route_key = platform_io_manager.build_route_key_from_message(message)
except Exception as exc:
logger.warning(f"[SendService] 根据消息构造 Platform IO 路由键失败: {exc}")
return False
return None
try:
await _prepare_message_for_platform_io(
@@ -792,7 +792,7 @@ async def _send_via_platform_io(
except Exception as exc:
logger.error(f"[SendService] Platform IO 发送异常: {exc}")
logger.debug(traceback.format_exc())
return False
return None
sent = bool(delivery_batch.has_success)
if sent:
@@ -823,10 +823,34 @@ async def _send_via_platform_io(
f"(drivers: {', '.join(successful_driver_ids)}) "
f"message={_build_outbound_log_preview(message)}"
)
return True
return message
_log_platform_io_failures(delivery_batch)
return False
return None
async def send_session_message_with_message(
    message: SessionMessage,
    *,
    typing: bool = False,
    set_reply: bool = False,
    reply_message_id: Optional[str] = None,
    storage_message: bool = True,
    show_log: bool = True,
) -> Optional[SessionMessage]:
    """Send one internal message and return the message object on success.

    Args:
        message: The session message to deliver; must carry a ``message_id``.
        typing: Whether to simulate a typing indicator before delivery.
        set_reply: Whether the outgoing message quotes another message.
        reply_message_id: Id of the message being quoted, if any.
        storage_message: Whether to persist the message after sending.
        show_log: Whether to emit a debug log for the send.

    Returns:
        The delivered ``SessionMessage``, or ``None`` when delivery failed
        or was aborted by a hook.

    Raises:
        ValueError: If ``message`` has no ``message_id``.
    """
    message_id = message.message_id
    if not message_id:
        # A missing id makes the message untraceable downstream; refuse to send.
        logger.error("[SendService] 消息缺少 message_id无法发送")
        raise ValueError("消息缺少 message_id无法发送")
    result = await _send_via_platform_io(
        message,
        typing=typing,
        set_reply=set_reply,
        reply_message_id=reply_message_id,
        storage_message=storage_message,
        show_log=show_log,
    )
    return result
async def send_session_message(
@@ -861,13 +885,16 @@ async def send_session_message(
logger.error("[SendService] 消息缺少 message_id无法发送")
raise ValueError("消息缺少 message_id无法发送")
return await _send_via_platform_io(
message,
typing=typing,
set_reply=set_reply,
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
return (
await send_session_message_with_message(
message,
typing=typing,
set_reply=set_reply,
reply_message_id=reply_message_id,
storage_message=storage_message,
show_log=show_log,
)
is not None
)
@@ -882,6 +909,34 @@ async def _send_to_target(
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定目标构建并发送消息,并返回是否发送成功。"""
return (
await _send_to_target_with_message(
message_sequence=message_sequence,
stream_id=stream_id,
display_message=display_message,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
show_log=show_log,
selected_expressions=selected_expressions,
)
is not None
)
async def _send_to_target_with_message(
message_sequence: MessageSequence,
stream_id: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[MaiMessage] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions: Optional[List[int]] = None,
) -> Optional[SessionMessage]:
"""向指定目标构建并发送消息。
Args:
@@ -901,7 +956,7 @@ async def _send_to_target(
try:
if set_reply and reply_message is None:
logger.warning("[SendService] 使用引用回复,但未提供回复消息")
return False
return None
if show_log:
logger.debug(f"[SendService] 发送{_describe_message_sequence(message_sequence)}消息到 {stream_id}")
@@ -914,7 +969,7 @@ async def _send_to_target(
selected_expressions=selected_expressions,
)
if outbound_message is None:
return False
return None
after_build_result, outbound_message = await _invoke_send_hook(
"send_service.after_build_message",
@@ -928,7 +983,7 @@ async def _send_to_target(
)
if after_build_result.aborted:
logger.info(f"[SendService] 消息 {outbound_message.message_id} 在构建后被 Hook 中止")
return False
return None
after_build_kwargs = after_build_result.kwargs
typing = _coerce_bool(after_build_kwargs.get("typing"), typing)
@@ -936,7 +991,7 @@ async def _send_to_target(
storage_message = _coerce_bool(after_build_kwargs.get("storage_message"), storage_message)
show_log = _coerce_bool(after_build_kwargs.get("show_log"), show_log)
sent = await send_session_message(
sent_message = await send_session_message_with_message(
outbound_message,
typing=typing,
set_reply=set_reply,
@@ -944,16 +999,38 @@ async def _send_to_target(
storage_message=storage_message,
show_log=show_log,
)
if sent:
if sent_message is not None:
logger.debug(f"[SendService] 成功发送消息到 {stream_id}")
return True
return sent_message
logger.error("[SendService] 发送消息失败")
return False
return None
except Exception as exc:
logger.error(f"[SendService] 发送消息时出错: {exc}")
traceback.print_exc()
return False
return None
async def text_to_stream_with_message(
    text: str,
    stream_id: str,
    typing: bool = False,
    set_reply: bool = False,
    reply_message: Optional[MaiMessage] = None,
    storage_message: bool = True,
    selected_expressions: Optional[List[int]] = None,
) -> Optional[SessionMessage]:
    """Send a text message to a stream and return the sent message object.

    Args:
        text: Plain text to send.
        stream_id: Target stream identifier.
        typing: Whether to simulate a typing indicator first.
        set_reply: Whether to quote ``reply_message``.
        reply_message: The message being quoted, if any.
        storage_message: Whether to persist the message after sending.
        selected_expressions: Optional expression ids applied to the reply.

    Returns:
        The delivered ``SessionMessage``, or ``None`` on failure.
    """
    # Wrap the plain text in a single-component sequence before delegating.
    sequence = MessageSequence(components=[TextComponent(text=text)])
    return await _send_to_target_with_message(
        message_sequence=sequence,
        stream_id=stream_id,
        display_message="",
        typing=typing,
        set_reply=set_reply,
        reply_message=reply_message,
        storage_message=storage_message,
        selected_expressions=selected_expressions,
    )
async def text_to_stream(
@@ -979,15 +1056,36 @@ async def text_to_stream(
Returns:
bool: 发送成功时返回 ``True``。
"""
return await _send_to_target(
message_sequence=MessageSequence(components=[TextComponent(text=text)]),
return (
await text_to_stream_with_message(
text=text,
stream_id=stream_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
selected_expressions=selected_expressions,
)
is not None
)
async def emoji_to_stream_with_message(
    emoji_base64: str,
    stream_id: str,
    storage_message: bool = True,
    set_reply: bool = False,
    reply_message: Optional[MaiMessage] = None,
) -> Optional[SessionMessage]:
    """Send an emoji message to a stream and return the sent message object.

    Fixes the previous body, which passed ``typing`` and ``storage_message``
    twice (duplicate keyword arguments are a SyntaxError) and forwarded
    ``selected_expressions``, a name that is not in this function's signature
    (NameError at call time).

    Args:
        emoji_base64: Base64-encoded emoji image payload.
        stream_id: Target stream identifier.
        storage_message: Whether to persist the message after sending.
        set_reply: Whether to quote ``reply_message``.
        reply_message: The message being quoted, if any.

    Returns:
        The delivered ``SessionMessage``, or ``None`` on failure.
    """
    return await _send_to_target_with_message(
        message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
        stream_id=stream_id,
        display_message="",
        typing=False,  # emoji sends never simulate a typing indicator
        storage_message=storage_message,
        set_reply=set_reply,
        reply_message=reply_message,
    )
@@ -1010,14 +1108,15 @@ async def emoji_to_stream(
Returns:
bool: 发送成功时返回 ``True``。
"""
return await _send_to_target(
message_sequence=_build_message_sequence_from_custom_message("emoji", emoji_base64),
stream_id=stream_id,
display_message="",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
return (
await emoji_to_stream_with_message(
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
is not None
)