fix:表情包识别失败问题
This commit is contained in:
74
pytests/image_sys_test/test_image_data_model.py
Normal file
74
pytests/image_sys_test/test_image_data_model.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import io
|
||||
|
||||
from PIL import Image as PILImage
|
||||
import pytest
|
||||
|
||||
from src.common.data_models.image_data_model import MaiEmoji, MaiImage
|
||||
|
||||
|
||||
def _build_test_image_bytes(image_format: str) -> bytes:
    """Render a tiny 8x8 white RGB image and return its encoded bytes."""
    buffer = io.BytesIO()
    canvas = PILImage.new("RGB", (8, 8), color="white")
    canvas.save(buffer, format=image_format)
    return buffer.getvalue()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_calculate_hash_format_updates_runtime_path_metadata(tmp_path: Path) -> None:
    """calculate_hash_format should detect the format, rename the file and sync path metadata."""
    image_bytes = _build_test_image_bytes("JPEG")
    source_path = tmp_path / "emoji.tmp"
    source_path.write_bytes(image_bytes)

    emoji = MaiEmoji(full_path=source_path, image_bytes=image_bytes)

    assert await emoji.calculate_hash_format() is True

    # The detected format drives the new suffix and the derived path metadata.
    assert emoji.image_format == "jpeg"
    assert emoji.full_path.suffix == ".jpeg"
    assert emoji.file_name == emoji.full_path.name
    assert emoji.dir_path == tmp_path.resolve()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("model_cls", "extra_fields"),
    [
        (
            MaiEmoji,
            {
                "description": "",
                "last_used_time": None,
                "query_count": 0,
                "register_time": None,
            },
        ),
        (
            MaiImage,
            {
                "description": "",
                "vlm_processed": False,
            },
        ),
    ],
)
def test_from_db_instance_restores_image_format_from_path(
    tmp_path: Path,
    model_cls: type[MaiEmoji] | type[MaiImage],
    extra_fields: dict[str, object],
) -> None:
    """from_db_instance should rebuild path metadata and recover the format from the suffix."""
    image_path = tmp_path / "cached.png"
    image_path.write_bytes(_build_test_image_bytes("PNG"))

    # Minimal stand-in for the DB row; only the attributes read by from_db_instance.
    db_record = SimpleNamespace(
        no_file_flag=False,
        image_hash="hash",
        full_path=str(image_path),
        **extra_fields,
    )

    restored = model_cls.from_db_instance(db_record)

    assert restored.full_path == image_path.resolve()
    assert restored.file_name == image_path.name
    assert restored.image_format == "png"
|
||||
@@ -125,7 +125,7 @@ def setup_mocks(monkeypatch):
|
||||
db_model_mod = _stub_module("src.common.database.database_model")
|
||||
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager")
|
||||
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
|
||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any, Callable
|
||||
|
||||
import pytest
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
from src.chat.replyer import maisaka_generator as legacy_replyer_module
|
||||
from src.chat.replyer import maisaka_generator_multi as multimodal_replyer_module
|
||||
@@ -445,3 +446,33 @@ def test_runtime_build_tool_detail_panels_uses_emotion_prompt_access_panel(monke
|
||||
assert captured["content"] == "emotion prompt link"
|
||||
assert captured["kwargs"]["chat_id"] == "session-emotion"
|
||||
assert captured["kwargs"]["request_kind"] == "emotion"
|
||||
|
||||
|
||||
def test_runtime_render_context_usage_panel_merges_timing_and_planner(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
runtime = object.__new__(MaisakaHeartFlowChatting)
|
||||
runtime.session_id = "session-merged"
|
||||
runtime.session_name = "测试聊天流"
|
||||
runtime._max_context_size = 20
|
||||
|
||||
printed: list[Any] = []
|
||||
monkeypatch.setattr("src.maisaka.runtime.console.print", lambda renderable: printed.append(renderable))
|
||||
|
||||
runtime._render_context_usage_panel(
|
||||
cycle_id=12,
|
||||
timing_selected_history_count=3,
|
||||
timing_prompt_tokens=15,
|
||||
timing_action="continue",
|
||||
timing_response="继续执行",
|
||||
planner_selected_history_count=5,
|
||||
planner_prompt_tokens=42,
|
||||
planner_response="先查询再回复",
|
||||
)
|
||||
|
||||
assert len(printed) == 1
|
||||
outer_panel = printed[0]
|
||||
assert isinstance(outer_panel, Panel)
|
||||
renderables = list(outer_panel.renderable.renderables)
|
||||
assert isinstance(renderables[0], Text)
|
||||
assert "聊天流名称:测试聊天流" in renderables[0].plain
|
||||
assert "聊天流ID:session-merged" in renderables[0].plain
|
||||
assert len(renderables) == 3
|
||||
|
||||
@@ -64,7 +64,7 @@ def test_builtin_hook_catalog_includes_new_business_hooks(monkeypatch: pytest.Mo
|
||||
async def test_send_emoji_for_maisaka_can_be_aborted_by_hook(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""表情包系统应允许在选择前被 Hook 中止。"""
|
||||
|
||||
from src.chat.emoji_system import maisaka_tool
|
||||
from src.emoji_system import maisaka_tool
|
||||
|
||||
fake_manager = _FakeHookManager(
|
||||
{
|
||||
|
||||
@@ -149,7 +149,7 @@ def setup_mocks(monkeypatch):
|
||||
db_model_mod = _stub_module("src.common.database.database_model")
|
||||
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager")
|
||||
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
|
||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
||||
|
||||
777
scripts/test_tool_call_api_matrix.py
Normal file
777
scripts/test_tool_call_api_matrix.py
Normal file
@@ -0,0 +1,777 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Sequence
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.common.data_models.llm_service_data_models import LLMServiceRequest, LLMServiceResult # noqa: E402
|
||||
from src.config.config import config_manager # noqa: E402
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig # noqa: E402
|
||||
from src.services.llm_service import generate # noqa: E402
|
||||
from src.services.service_task_resolver import get_available_models # noqa: E402
|
||||
|
||||
|
||||
DEFAULT_SKIP_TASKS = {"embedding", "voice"}
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ProbeTarget:
    """A single model target selected for probing.

    Bundles the task the request will be routed through, the model to pin,
    and the provider metadata carried into the result report.
    """

    task_name: str  # task whose config is temporarily pinned to this model
    model_name: str  # configured model name (not the provider-side identifier)
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ToolCallScenario:
    """Definition of one tool-call API test scenario.

    `prompt` is the full message history sent to the model. `tool_options`
    is the tools payload for the current turn (None = no tools offered).
    `expect_tool_calls` is a soft expectation checked by the validator
    (None = no expectation either way).
    """

    name: str
    description: str
    prompt: List[Dict[str, Any]]
    tool_options: List[Dict[str, Any]] | None = None
    expect_tool_calls: bool | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
class ProbeResult:
    """Outcome of a single API probe (one model x one scenario x one attempt)."""

    task_name: str
    target_model_name: str  # the model the task was pinned to
    actual_model_name: str  # the model the service reports actually serving
    provider_name: str
    client_type: str
    tool_argument_parse_mode: str
    case_name: str
    attempt: int
    success: bool  # True when no validation errors were recorded
    elapsed_seconds: float
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    response_text: str = ""
    reasoning_text: str = ""
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def _ensure_utf8_console() -> None:
|
||||
"""尽量将控制台编码切换为 UTF-8。"""
|
||||
try:
|
||||
if hasattr(sys.stdout, "reconfigure"):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
if hasattr(sys.stderr, "reconfigure"):
|
||||
sys.stderr.reconfigure(encoding="utf-8")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _build_function_tool(name: str, description: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""构造 OpenAI 风格 function tool。"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _build_probe_tools() -> List[Dict[str, Any]]:
    """Build the two generic probe tools: weather lookup and doc search."""
    weather_parameters: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "城市名"},
            "unit": {
                "type": "string",
                "description": "温度单位",
                "enum": ["celsius", "fahrenheit"],
            },
            "include_forecast": {"type": "boolean", "description": "是否包含未来天气"},
        },
        "required": ["city", "unit", "include_forecast"],
        "additionalProperties": False,
    }
    search_parameters: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "搜索关键词"},
            "top_k": {"type": "integer", "description": "返回条数"},
            "filters": {
                "type": "object",
                "description": "过滤条件",
                "properties": {
                    "scope": {"type": "string", "description": "搜索范围"},
                    "tag": {"type": "string", "description": "标签"},
                },
                "required": ["scope", "tag"],
                "additionalProperties": False,
            },
        },
        "required": ["query", "top_k", "filters"],
        "additionalProperties": False,
    }
    return [
        _build_function_tool("lookup_weather", "查询指定城市天气。", weather_parameters),
        _build_function_tool("search_docs", "搜索内部知识库。", search_parameters),
    ]
|
||||
|
||||
|
||||
def _build_default_scenarios() -> List[ToolCallScenario]:
    """Build the default tool-call compatibility scenarios.

    Covers a fresh tool call, replayed assistant ``tool_calls`` with and
    without text content, tool-result follow-ups, multiple parallel calls,
    and historical tool_calls combined with a current-turn tools payload.
    """
    tools = _build_probe_tools()
    weather_tool = tools[0]
    search_tool = tools[1]

    # Reusable historical tool-call payloads replayed in the prompts below.
    history_tool_call = {
        "id": "call_hist_weather_001",
        "type": "function",
        "function": {
            "name": "lookup_weather",
            "arguments": {
                "city": "上海",
                "unit": "celsius",
                "include_forecast": True,
            },
        },
    }
    nested_history_tool_call = {
        "id": "call_hist_search_001",
        "type": "function",
        "function": {
            "name": "search_docs",
            "arguments": {
                "query": "工具调用兼容性",
                "top_k": 3,
                "filters": {
                    "scope": "internal",
                    "tag": "tool-call",
                },
            },
        },
    }

    return [
        # 1) Plain first-turn tool call: tools offered, a call is expected.
        ToolCallScenario(
            name="fresh_tool_call",
            description="首轮普通工具调用请求。",
            prompt=[
                {
                    "role": "system",
                    "content": (
                        "你正在执行工具调用连通性测试。"
                        "如果能调用工具,就优先调用最合适的工具。"
                    ),
                },
                {
                    "role": "user",
                    "content": "请查询上海天气,并使用工具给出参数。",
                },
            ],
            tool_options=[weather_tool],
            expect_tool_calls=True,
        ),
        # 2) History: assistant text + tool_calls; no tools this turn.
        ToolCallScenario(
            name="history_assistant_tool_calls_with_content",
            description="历史 assistant 同时包含文本和 tool_calls,当前轮不再提供 tools。",
            prompt=[
                {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
                {"role": "user", "content": "先帮我查一下上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查询天气,再继续回答。",
                    "tool_calls": [history_tool_call],
                },
                {"role": "user", "content": "继续说,别丢掉上下文。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 3) History: assistant tool_calls only, with no text content.
        ToolCallScenario(
            name="history_assistant_tool_calls_without_content",
            description="历史 assistant 只有 tool_calls,没有文本内容。",
            prompt=[
                {"role": "system", "content": "你正在执行多轮上下文兼容性测试。"},
                {"role": "user", "content": "先帮我查一下上海天气。"},
                {
                    "role": "assistant",
                    "tool_calls": [history_tool_call],
                },
                {"role": "user", "content": "继续。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 4) History includes the tool result message closing the call loop.
        ToolCallScenario(
            name="history_tool_result_followup",
            description="历史中包含 assistant.tool_calls 与对应 tool 结果消息。",
            prompt=[
                {"role": "system", "content": "你正在执行工具调用闭环兼容性测试。"},
                {"role": "user", "content": "先查上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查询天气。",
                    "tool_calls": [history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "多云",
                            "temperature_c": 24,
                            "forecast": ["晴", "小雨"],
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "结合上面的查询结果继续总结。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 5) History carries two parallel tool_calls and both tool results.
        ToolCallScenario(
            name="history_multiple_tool_calls_and_results",
            description="历史中包含多个 tool_calls 与多条 tool 结果。",
            prompt=[
                {"role": "system", "content": "你正在执行多工具上下文兼容性测试。"},
                {"role": "user", "content": "先查天气,再搜一下工具调用兼容性文档。"},
                {
                    "role": "assistant",
                    "content": "我分两步查询。",
                    "tool_calls": [history_tool_call, nested_history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "阴",
                            "temperature_c": 22,
                        },
                        ensure_ascii=False,
                    ),
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_search_001",
                    "content": json.dumps(
                        {
                            "items": [
                                "OpenAI 兼容接口的 arguments 常见为 JSON 字符串",
                                "部分 provider 在历史消息回放时兼容性较弱",
                            ],
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "继续整合上面的两个结果。"},
            ],
            tool_options=None,
            expect_tool_calls=None,
        ),
        # 6) Historical tool_calls coexist with a fresh tools payload this turn.
        ToolCallScenario(
            name="history_tool_calls_with_current_tools",
            description="保留历史 tool_calls,同时当前轮仍然提供 tools。",
            prompt=[
                {"role": "system", "content": "你正在执行历史 tool_calls 与当前 tools 共存测试。"},
                {"role": "user", "content": "先查上海天气。"},
                {
                    "role": "assistant",
                    "content": "我先查天气。",
                    "tool_calls": [history_tool_call],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_hist_weather_001",
                    "content": json.dumps(
                        {
                            "city": "上海",
                            "condition": "晴",
                            "temperature_c": 26,
                        },
                        ensure_ascii=False,
                    ),
                },
                {"role": "user", "content": "现在再搜一下工具调用兼容性文档。"},
            ],
            tool_options=[search_tool],
            expect_tool_calls=True,
        ),
    ]
|
||||
|
||||
|
||||
def _parse_multi_value_args(raw_values: Sequence[str] | None) -> List[str]:
|
||||
"""解析命令行中的多值参数。"""
|
||||
parsed_values: List[str] = []
|
||||
for raw_value in raw_values or []:
|
||||
for item in str(raw_value).split(","):
|
||||
normalized_item = item.strip()
|
||||
if normalized_item:
|
||||
parsed_values.append(normalized_item)
|
||||
return parsed_values
|
||||
|
||||
|
||||
def _build_model_map() -> Dict[str, ModelInfo]:
    """Map each configured model name to its ModelInfo entry."""
    configured_models = config_manager.get_model_config().models
    return {model_info.name: model_info for model_info in configured_models}
|
||||
|
||||
|
||||
def _build_provider_map() -> Dict[str, APIProvider]:
    """Map each configured provider name to its APIProvider entry."""
    configured_providers = config_manager.get_model_config().api_providers
    return {provider_info.name: provider_info for provider_info in configured_providers}
|
||||
|
||||
|
||||
def _pick_default_task_name(task_names: Sequence[str]) -> str:
|
||||
"""选择默认任务名。"""
|
||||
if "utils" in task_names:
|
||||
return "utils"
|
||||
if not task_names:
|
||||
raise ValueError("当前没有可用的任务配置")
|
||||
return str(task_names[0])
|
||||
|
||||
|
||||
def _resolve_targets(task_filters: Sequence[str], model_filters: Sequence[str], fallback_task: str) -> List[ProbeTarget]:
    """Resolve the de-duplicated list of probe targets from the CLI filters.

    Args:
        task_filters: Explicit task names; empty means "all tasks except DEFAULT_SKIP_TASKS".
        model_filters: Explicit model names; empty means "every model referenced by the selected tasks".
        fallback_task: Task used to host a model that none of the selected tasks reference.

    Returns:
        One ProbeTarget per unique model.

    Raises:
        ValueError: On unknown task/model names, a missing provider reference,
            or when the selection ends up empty.
    """
    available_tasks = get_available_models()
    model_map = _build_model_map()
    provider_map = _build_provider_map()

    if not available_tasks:
        raise ValueError("未找到任何可用的模型任务配置")

    if task_filters:
        selected_task_names = []
        for task_name in task_filters:
            if task_name not in available_tasks:
                raise ValueError(f"未找到任务 `{task_name}`")
            selected_task_names.append(task_name)
    else:
        # No explicit tasks: take everything except tasks unsuited to tool calls.
        selected_task_names = [
            task_name
            for task_name in available_tasks
            if task_name not in DEFAULT_SKIP_TASKS
        ]

    if not selected_task_names:
        raise ValueError("没有可用于工具调用 API 测试的任务,请显式通过 --task 指定")

    default_task_name = fallback_task if fallback_task in available_tasks else _pick_default_task_name(selected_task_names)
    resolved_targets: List[ProbeTarget] = []
    seen_models: set[str] = set()

    if model_filters:
        model_names = list(model_filters)
    else:
        # Collect every model referenced by the selected tasks, preserving order.
        model_names = []
        for task_name in selected_task_names:
            task_config = available_tasks[task_name]
            for model_name in task_config.model_list:
                if model_name not in model_names:
                    model_names.append(model_name)

    for model_name in model_names:
        if model_name in seen_models:
            continue
        if model_name not in model_map:
            raise ValueError(f"未找到模型 `{model_name}`")

        # Prefer a selected task that already references the model; otherwise
        # mount it on the fallback/default task.
        target_task_name = ""
        for task_name in selected_task_names:
            if model_name in available_tasks[task_name].model_list:
                target_task_name = task_name
                break
        if not target_task_name:
            target_task_name = default_task_name

        model_info = model_map[model_name]
        # Fix: previously a dangling provider reference surfaced as a bare
        # KeyError; raise the same descriptive ValueError style used above.
        if model_info.api_provider not in provider_map:
            raise ValueError(f"未找到模型 `{model_name}` 引用的 API Provider `{model_info.api_provider}`")
        provider_info = provider_map[model_info.api_provider]
        resolved_targets.append(
            ProbeTarget(
                task_name=target_task_name,
                model_name=model_name,
                provider_name=provider_info.name,
                client_type=provider_info.client_type,
                tool_argument_parse_mode=provider_info.tool_argument_parse_mode,
            )
        )
        seen_models.add(model_name)

    return resolved_targets
|
||||
|
||||
|
||||
@contextmanager
def _pin_task_to_model(task_name: str, model_name: str) -> Iterator[None]:
    """Temporarily restrict a task's model list to a single model.

    Mutates the shared task config in place; the original model list and
    selection strategy are restored on exit, even when the body raises.

    Raises:
        ValueError: When the named task has no TaskConfig entry.
    """
    model_task_config = config_manager.get_model_config().model_task_config
    task_config = getattr(model_task_config, task_name, None)
    if not isinstance(task_config, TaskConfig):
        raise ValueError(f"未找到任务 `{task_name}` 对应的配置")

    # Snapshot current state so it can be restored exactly.
    original_model_list = list(task_config.model_list)
    original_selection_strategy = task_config.selection_strategy
    task_config.model_list = [model_name]
    # "balance" over a single-item list deterministically selects that model.
    task_config.selection_strategy = "balance"
    try:
        yield
    finally:
        task_config.model_list = original_model_list
        task_config.selection_strategy = original_selection_strategy
|
||||
|
||||
|
||||
def _serialize_tool_calls(tool_calls: Any) -> List[Dict[str, Any]]:
|
||||
"""序列化返回中的工具调用。"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
serialized_items: List[Dict[str, Any]] = []
|
||||
for tool_call in tool_calls:
|
||||
serialized_items.append(
|
||||
{
|
||||
"id": getattr(tool_call, "call_id", ""),
|
||||
"function": {
|
||||
"name": getattr(tool_call, "func_name", ""),
|
||||
"arguments": dict(getattr(tool_call, "args", {}) or {}),
|
||||
},
|
||||
**(
|
||||
{"extra_content": dict(getattr(tool_call, "extra_content", {}) or {})}
|
||||
if getattr(tool_call, "extra_content", None)
|
||||
else {}
|
||||
),
|
||||
}
|
||||
)
|
||||
return serialized_items
|
||||
|
||||
|
||||
def _validate_service_result(service_result: LLMServiceResult, scenario: ToolCallScenario) -> tuple[List[str], List[str], List[Dict[str, Any]]]:
    """Validate a service result against the scenario's soft expectations.

    Returns:
        ``(errors, warnings, serialized_tool_calls)``. Errors mark the probe
        as failed; warnings are informational only.
    """
    errors: List[str] = []
    warnings: List[str] = []
    completion = service_result.completion
    serialized_tool_calls = _serialize_tool_calls(completion.tool_calls)

    if not service_result.success:
        # Prefer the explicit error; fall back to raw response text, then a stock message.
        errors.append(service_result.error or completion.response or "请求失败,但没有返回明确错误")
        return errors, warnings, serialized_tool_calls

    # Expectation mismatches are warnings, not failures: models may legally
    # decline to call (or keep calling) tools.
    if scenario.expect_tool_calls is True and not serialized_tool_calls:
        warnings.append("本场景期望模型倾向于调用工具,但未返回 tool_calls")
    if scenario.expect_tool_calls is False and serialized_tool_calls:
        warnings.append("本场景未期望继续调用工具,但模型返回了 tool_calls")
    if completion.response.strip():
        warnings.append("模型返回了可见文本")
    return errors, warnings, serialized_tool_calls
|
||||
|
||||
|
||||
async def _run_single_probe(
    target: ProbeTarget,
    scenario: ToolCallScenario,
    attempt: int,
    max_tokens: int,
    temperature: float,
) -> ProbeResult:
    """Run one API probe: build the request, pin the model, time the call.

    Returns:
        A ProbeResult whose ``success`` reflects whether validation recorded
        any errors.
    """
    request = LLMServiceRequest(
        task_name=target.task_name,
        # Encodes scenario and attempt so server-side logs trace back to this probe.
        request_type=f"tool_call_api_matrix.{scenario.name}.attempt_{attempt}",
        prompt=scenario.prompt,
        tool_options=scenario.tool_options,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    started_at = time.perf_counter()
    # Pin the task to exactly this model for the duration of the call.
    with _pin_task_to_model(target.task_name, target.model_name):
        service_result = await generate(request)
    elapsed_seconds = time.perf_counter() - started_at

    errors, warnings, serialized_tool_calls = _validate_service_result(service_result, scenario)
    completion = service_result.completion
    return ProbeResult(
        task_name=target.task_name,
        target_model_name=target.model_name,
        actual_model_name=completion.model_name,
        provider_name=target.provider_name,
        client_type=target.client_type,
        tool_argument_parse_mode=target.tool_argument_parse_mode,
        case_name=scenario.name,
        attempt=attempt,
        success=not errors,
        elapsed_seconds=elapsed_seconds,
        errors=errors,
        warnings=warnings,
        response_text=completion.response,
        reasoning_text=completion.reasoning,
        tool_calls=serialized_tool_calls,
    )
|
||||
|
||||
|
||||
def _print_targets(targets: Sequence[ProbeTarget]) -> None:
|
||||
"""打印待测试目标。"""
|
||||
print("待测试目标:")
|
||||
for index, target in enumerate(targets, start=1):
|
||||
print(
|
||||
f"{index}. model={target.model_name} | task={target.task_name} | "
|
||||
f"provider={target.provider_name} | client={target.client_type} | "
|
||||
f"tool_argument_parse_mode={target.tool_argument_parse_mode}"
|
||||
)
|
||||
|
||||
|
||||
def _print_available_targets() -> None:
    """Print the configured tasks and models without issuing any requests."""
    available_tasks = get_available_models()
    model_map = _build_model_map()
    task_names = list(available_tasks.keys())

    print("当前可用任务:")
    for task_name in task_names:
        task_config = available_tasks[task_name]
        print(f"- {task_name}: {list(task_config.model_list)}")

    # Models referenced by at least one task config.
    referenced_models = {
        model_name
        for task_config in available_tasks.values()
        for model_name in task_config.model_list
    }

    print("\n当前配置中的模型:")
    for model_name, model_info in model_map.items():
        referenced_mark = "已被任务引用" if model_name in referenced_models else "未被任务引用"
        print(
            f"- {model_name}: provider={model_info.api_provider}, "
            f"identifier={model_info.model_identifier}, {referenced_mark}"
        )
|
||||
|
||||
|
||||
def _select_scenarios(case_filters: Sequence[str]) -> List[ToolCallScenario]:
    """Filter the default scenarios by name; empty filters mean "all".

    Raises:
        ValueError: When a requested scenario name does not exist.
    """
    scenarios_by_name = {scenario.name: scenario for scenario in _build_default_scenarios()}
    if not case_filters:
        return list(scenarios_by_name.values())

    selected_scenarios: List[ToolCallScenario] = []
    for case_name in case_filters:
        scenario = scenarios_by_name.get(case_name)
        if scenario is None:
            valid_names = ", ".join(sorted(scenarios_by_name))
            raise ValueError(f"未知测试场景 `{case_name}`,可选值: {valid_names}")
        selected_scenarios.append(scenario)
    return selected_scenarios
|
||||
|
||||
|
||||
def _print_single_result(result: ProbeResult, show_response: bool) -> None:
|
||||
"""打印单次结果。"""
|
||||
status_text = "PASS" if result.success else "FAIL"
|
||||
print(
|
||||
f"[{status_text}] model={result.target_model_name} | task={result.task_name} | "
|
||||
f"case={result.case_name} | attempt={result.attempt} | elapsed={result.elapsed_seconds:.2f}s"
|
||||
)
|
||||
if result.errors:
|
||||
for error in result.errors:
|
||||
print(f" ERROR: {error}")
|
||||
if result.warnings:
|
||||
for warning in result.warnings:
|
||||
print(f" WARN: {warning}")
|
||||
if result.tool_calls:
|
||||
print(f" tool_calls: {json.dumps(result.tool_calls, ensure_ascii=False)}")
|
||||
if show_response and result.response_text.strip():
|
||||
print(f" response: {result.response_text}")
|
||||
|
||||
|
||||
def _build_summary(results: Sequence[ProbeResult]) -> Dict[str, Any]:
|
||||
"""构造结果摘要。"""
|
||||
total_count = len(results)
|
||||
passed_count = sum(1 for result in results if result.success)
|
||||
failed_count = total_count - passed_count
|
||||
failed_items = [
|
||||
{
|
||||
"model_name": result.target_model_name,
|
||||
"case_name": result.case_name,
|
||||
"attempt": result.attempt,
|
||||
"errors": list(result.errors),
|
||||
}
|
||||
for result in results
|
||||
if not result.success
|
||||
]
|
||||
return {
|
||||
"total": total_count,
|
||||
"passed": passed_count,
|
||||
"failed": failed_count,
|
||||
"failed_items": failed_items,
|
||||
}
|
||||
|
||||
|
||||
def _write_json_report(json_out: str, results: Sequence[ProbeResult]) -> None:
    """Serialize the probe results (with summary) to a UTF-8 JSON file."""
    output_path = Path(json_out).expanduser().resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "generated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "summary": _build_summary(results),
        "results": [asdict(result) for result in results],
    }
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    output_path.write_text(serialized, encoding="utf-8")
    print(f"\n结果已写入: {output_path}")
|
||||
|
||||
|
||||
async def _run_probes(args: Namespace) -> List[ProbeResult]:
    """Run every selected scenario against every resolved target.

    Iteration order is target -> attempt -> scenario; each result is
    printed as it completes and collected for the final summary.
    """
    task_filters = _parse_multi_value_args(args.task)
    model_filters = _parse_multi_value_args(args.model)
    case_filters = _parse_multi_value_args(args.case)

    selected_scenarios = _select_scenarios(case_filters)
    targets = _resolve_targets(task_filters, model_filters, args.fallback_task)

    _print_targets(targets)
    print("")

    results: List[ProbeResult] = []
    for target in targets:
        for attempt in range(1, args.repeat + 1):
            for scenario in selected_scenarios:
                print(
                    f"开始测试: model={target.model_name}, task={target.task_name}, "
                    f"case={scenario.name}, attempt={attempt}"
                )
                result = await _run_single_probe(
                    target=target,
                    scenario=scenario,
                    attempt=attempt,
                    max_tokens=args.max_tokens,
                    temperature=args.temperature,
                )
                _print_single_result(result, args.show_response)
                print("")
                results.append(result)
    return results
|
||||
|
||||
|
||||
def _build_parser() -> ArgumentParser:
    """Build the command-line argument parser for the probe script."""
    parser = ArgumentParser(
        description=(
            "测试不同模型在多种工具调用消息形态下的 API 兼容性。\n"
            "重点覆盖历史 assistant.tool_calls、tool 结果消息、多工具调用等场景。"
        )
    )
    # Repeatable multi-value filters (also accept comma-separated values);
    # see _parse_multi_value_args for the flattening rules.
    parser.add_argument(
        "--task",
        action="append",
        help="指定任务名,可重复传入,或使用逗号分隔多个值,例如 --task utils --task planner",
    )
    parser.add_argument(
        "--model",
        action="append",
        help="指定模型名,可重复传入,或使用逗号分隔多个值,例如 --model qwen3.5-35b-a3b",
    )
    parser.add_argument(
        "--case",
        action="append",
        help=(
            "指定测试场景名,可选值包括 "
            "fresh_tool_call、history_assistant_tool_calls_with_content、"
            "history_assistant_tool_calls_without_content、history_tool_result_followup、"
            "history_multiple_tool_calls_and_results、history_tool_calls_with_current_tools"
        ),
    )
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="每个模型每个场景重复测试次数,默认 1",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=512,
        help="单次测试的最大输出 token 数,默认 512",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="单次测试温度,默认 0.0,以尽量提高稳定性",
    )
    parser.add_argument(
        "--fallback-task",
        default="utils",
        help="当指定模型未被已选任务引用时,用于挂载该模型的任务名,默认 utils",
    )
    parser.add_argument(
        "--json-out",
        help="可选,将结果写入指定 JSON 文件",
    )
    parser.add_argument(
        "--list-targets",
        action="store_true",
        help="仅打印当前任务与模型映射,不发起网络请求",
    )
    parser.add_argument(
        "--show-response",
        action="store_true",
        help="打印模型返回的可见文本内容",
    )
    return parser
|
||||
|
||||
|
||||
def main() -> int:
    """Script entry point.

    Returns:
        0 when every probe passes (or only targets are listed), 1 otherwise.
    """
    _ensure_utf8_console()
    # NOTE(review): config is initialized before argument parsing, so even
    # `--help` / `--list-targets` require a loadable config — confirm intended.
    config_manager.initialize()
    parser = _build_parser()
    args = parser.parse_args()

    # parser.error exits with status 2 on invalid numeric options.
    if args.repeat < 1:
        parser.error("--repeat 必须大于等于 1")
    if args.max_tokens < 1:
        parser.error("--max-tokens 必须大于等于 1")

    if args.list_targets:
        _print_available_targets()
        return 0

    results = asyncio.run(_run_probes(args))
    summary = _build_summary(results)

    print("测试摘要:")
    print(
        f"total={summary['total']} | passed={summary['passed']} | failed={summary['failed']}"
    )
    if summary["failed_items"]:
        print("失败明细:")
        for failed_item in summary["failed_items"]:
            print(
                f"- model={failed_item['model_name']} | case={failed_item['case_name']} | "
                f"attempt={failed_item['attempt']} | errors={failed_item['errors']}"
            )

    if args.json_out:
        _write_json_report(args.json_out, results)

    return 0 if summary["failed"] == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -3,7 +3,7 @@ MaiBot模块系统
|
||||
包含聊天、情绪、记忆、日程等功能模块
|
||||
"""
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
|
||||
@@ -296,6 +296,8 @@ class ImageManager:
|
||||
async def build_image_description(self, image_bytes: bytes) -> MaiImage:
|
||||
"""在图片已保存的前提下生成或补齐图片描述。"""
|
||||
mai_image = await self.ensure_image_saved(image_bytes)
|
||||
if not mai_image.image_format:
|
||||
await mai_image.calculate_hash_format()
|
||||
if mai_image.vlm_processed and mai_image.description:
|
||||
return mai_image
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ class SessionMessage(MaiMessage):
|
||||
"""
|
||||
if component.content: # 先检查是否处理过
|
||||
return component.content
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
# 获取表情包描述
|
||||
try:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from PIL import Image as PILImage
|
||||
from rich.traceback import install
|
||||
@@ -28,15 +29,27 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
if Path(full_path).is_dir() or not Path(full_path).exists():
|
||||
raise FileNotFoundError(f"表情包路径无效: {full_path}")
|
||||
resolved_path = Path(full_path).absolute().resolve()
|
||||
self.full_path: Path = resolved_path
|
||||
self.dir_path: Path = resolved_path.parent.resolve()
|
||||
self.file_name: str = resolved_path.name
|
||||
self.full_path: Path
|
||||
self.dir_path: Path
|
||||
self.file_name: str
|
||||
self._set_full_path(resolved_path)
|
||||
self.file_hash: str = None # type: ignore
|
||||
|
||||
self.image_bytes: Optional[bytes] = image_bytes
|
||||
|
||||
self.image_format: str = "" # 图片格式
|
||||
|
||||
def _set_full_path(self, full_path: Path) -> None:
|
||||
"""同步更新文件路径相关的运行时元数据。"""
|
||||
resolved_path = full_path.absolute().resolve()
|
||||
self.full_path = resolved_path
|
||||
self.dir_path = resolved_path.parent.resolve()
|
||||
self.file_name = resolved_path.name
|
||||
|
||||
def _restore_image_format_from_path(self) -> None:
|
||||
"""根据文件扩展名恢复基础图片格式信息。"""
|
||||
self.image_format = self.full_path.suffix.removeprefix(".").lower()
|
||||
|
||||
def read_image_bytes(self, path: Path) -> bytes:
|
||||
"""
|
||||
同步读取图片文件的字节内容
|
||||
@@ -97,6 +110,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
image_bytes = await asyncio.to_thread(self.read_image_bytes, self.full_path)
|
||||
else:
|
||||
image_bytes = self.image_bytes
|
||||
self.image_bytes = image_bytes
|
||||
self.file_hash = hashlib.sha256(image_bytes).hexdigest()
|
||||
logger.debug(f"[初始化] {self.file_name} 计算哈希值成功: {self.file_hash}")
|
||||
|
||||
@@ -115,7 +129,7 @@ class BaseImageDataModel(BaseDatabaseDataModel[Images]):
|
||||
new_file_name = ".".join(self.file_name.split(".")[:-1] + [self.image_format])
|
||||
new_full_path = self.dir_path / new_file_name
|
||||
self.full_path.rename(new_full_path)
|
||||
self.full_path = new_full_path
|
||||
self._set_full_path(new_full_path)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -153,6 +167,7 @@ class MaiEmoji(BaseImageDataModel):
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj._restore_image_format_from_path()
|
||||
description = db_record.description or ""
|
||||
obj.description = description
|
||||
normalized_tags = [
|
||||
@@ -207,7 +222,8 @@ class MaiImage(BaseImageDataModel):
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiImage 对象")
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj.full_path = Path(db_record.full_path)
|
||||
obj._set_full_path(Path(db_record.full_path))
|
||||
obj._restore_image_format_from_path()
|
||||
obj.description = db_record.description
|
||||
obj.vlm_processed = db_record.vlm_processed
|
||||
return obj
|
||||
|
||||
@@ -826,7 +826,7 @@ class EmojiManager:
|
||||
Returns:
|
||||
return (Tuple[bool, MaiEmoji]): 返回是否成功构建描述,及表情包对象
|
||||
"""
|
||||
if not target_emoji.file_hash:
|
||||
if not target_emoji.file_hash or not target_emoji.image_format:
|
||||
# Should not happen, but just in case
|
||||
await target_emoji.calculate_hash_format()
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
|
||||
from src.A_memorix.host_service import a_memorix_host_service
|
||||
from src.learners.expression_auto_check_task import ExpressionAutoCheckTask
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||
|
||||
@@ -12,8 +12,8 @@ from PIL import Image as PILImage
|
||||
from PIL import ImageDraw, ImageFont
|
||||
from pydantic import BaseModel, Field as PydanticField
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.maisaka_tool import send_emoji_for_maisaka
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.data_models.message_component_data_model import ImageComponent, MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -352,6 +352,7 @@ class MaisakaReasoningEngine:
|
||||
timing_response: Optional[ChatResponse] = None
|
||||
timing_tool_results: Optional[list[str]] = None
|
||||
response: Optional[ChatResponse] = None
|
||||
tool_result_summaries: list[str] = []
|
||||
tool_monitor_results: list[dict[str, Any]] = []
|
||||
try:
|
||||
visual_refresh_started_at = time.time()
|
||||
@@ -377,14 +378,6 @@ class MaisakaReasoningEngine:
|
||||
selected_history_count=timing_response.selected_history_count,
|
||||
duration_ms=timing_duration_ms,
|
||||
)
|
||||
self._runtime._render_context_usage_panel(
|
||||
selected_history_count=timing_response.selected_history_count,
|
||||
prompt_tokens=timing_response.prompt_tokens,
|
||||
planner_response=timing_response.content or "",
|
||||
tool_calls=timing_response.tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
prompt_section=timing_response.prompt_section,
|
||||
)
|
||||
if timing_action != "continue":
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} Timing Gate 结束当前回合: "
|
||||
@@ -418,7 +411,6 @@ class MaisakaReasoningEngine:
|
||||
|
||||
self._last_reasoning_content = reasoning_content
|
||||
self._runtime._chat_history.append(response.raw_message)
|
||||
tool_result_summaries: list[str] = []
|
||||
tool_monitor_results = []
|
||||
|
||||
if response.tool_calls:
|
||||
@@ -429,25 +421,10 @@ class MaisakaReasoningEngine:
|
||||
anchor_message,
|
||||
)
|
||||
cycle_detail.time_records["tool_calls"] = time.time() - tool_started_at
|
||||
self._runtime._render_context_usage_panel(
|
||||
selected_history_count=response.selected_history_count,
|
||||
prompt_tokens=response.prompt_tokens,
|
||||
planner_response=response.content or "",
|
||||
tool_calls=response.tool_calls,
|
||||
tool_results=tool_result_summaries,
|
||||
tool_detail_results=tool_monitor_results,
|
||||
prompt_section=response.prompt_section,
|
||||
)
|
||||
if should_pause:
|
||||
break
|
||||
continue
|
||||
|
||||
self._runtime._render_context_usage_panel(
|
||||
selected_history_count=response.selected_history_count,
|
||||
prompt_tokens=response.prompt_tokens,
|
||||
planner_response=response.content or "",
|
||||
prompt_section=response.prompt_section,
|
||||
)
|
||||
if not response.content:
|
||||
break
|
||||
except ReqAbortException:
|
||||
@@ -462,6 +439,31 @@ class MaisakaReasoningEngine:
|
||||
break
|
||||
finally:
|
||||
completed_cycle = self._end_cycle(cycle_detail)
|
||||
self._runtime._render_context_usage_panel(
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
timing_selected_history_count=(
|
||||
timing_response.selected_history_count if timing_response is not None else None
|
||||
),
|
||||
timing_prompt_tokens=(
|
||||
timing_response.prompt_tokens if timing_response is not None else None
|
||||
),
|
||||
timing_action=timing_action or "",
|
||||
timing_response=timing_response.content or "" if timing_response is not None else "",
|
||||
timing_tool_calls=timing_response.tool_calls if timing_response is not None else None,
|
||||
timing_tool_results=timing_tool_results,
|
||||
timing_prompt_section=(
|
||||
timing_response.prompt_section if timing_response is not None else None
|
||||
),
|
||||
planner_selected_history_count=(
|
||||
response.selected_history_count if response is not None else None
|
||||
),
|
||||
planner_prompt_tokens=response.prompt_tokens if response is not None else None,
|
||||
planner_response=response.content or "" if response is not None else "",
|
||||
planner_tool_calls=response.tool_calls if response is not None else None,
|
||||
planner_tool_results=tool_result_summaries,
|
||||
planner_tool_detail_results=tool_monitor_results,
|
||||
planner_prompt_section=response.prompt_section if response is not None else None,
|
||||
)
|
||||
await emit_planner_finalized(
|
||||
session_id=self._runtime.session_id,
|
||||
cycle_id=cycle_detail.cycle_id,
|
||||
|
||||
@@ -58,6 +58,7 @@ class MaisakaHeartFlowChatting:
|
||||
self.chat_stream: BotChatSession = chat_stream
|
||||
|
||||
session_name = chat_manager.get_session_name(session_id) or session_id
|
||||
self.session_name = session_name
|
||||
self.log_prefix = f"[{session_name}]"
|
||||
self._chat_loop_service = MaisakaChatLoopService(
|
||||
session_id=session_id,
|
||||
@@ -692,28 +693,117 @@ class MaisakaHeartFlowChatting:
|
||||
def _render_context_usage_panel(
|
||||
self,
|
||||
*,
|
||||
selected_history_count: int,
|
||||
prompt_tokens: int,
|
||||
cycle_id: Optional[int] = None,
|
||||
timing_selected_history_count: Optional[int] = None,
|
||||
timing_prompt_tokens: Optional[int] = None,
|
||||
timing_action: str = "",
|
||||
timing_response: str = "",
|
||||
timing_tool_calls: Optional[list[Any]] = None,
|
||||
timing_tool_results: Optional[list[str]] = None,
|
||||
timing_tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
timing_prompt_section: Optional[RenderableType] = None,
|
||||
planner_selected_history_count: Optional[int] = None,
|
||||
planner_prompt_tokens: Optional[int] = None,
|
||||
planner_response: str = "",
|
||||
tool_calls: Optional[list[Any]] = None,
|
||||
tool_results: Optional[list[str]] = None,
|
||||
tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
prompt_section: Optional[RenderableType] = None,
|
||||
planner_tool_calls: Optional[list[Any]] = None,
|
||||
planner_tool_results: Optional[list[str]] = None,
|
||||
planner_tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
planner_prompt_section: Optional[RenderableType] = None,
|
||||
) -> None:
|
||||
"""在终端展示当前聊天流的上下文占用、规划结果与工具结果。"""
|
||||
"""在终端展示当前聊天流本轮 cycle 的最终结果。"""
|
||||
if not global_config.debug.show_maisaka_thinking:
|
||||
return
|
||||
|
||||
body_lines = [
|
||||
f"上下文占用:{selected_history_count}/{self._max_context_size} 条",
|
||||
f"本次请求token消耗:{format_token_count(prompt_tokens)}",
|
||||
f"聊天流名称:{getattr(self, 'session_name', self.session_id)}",
|
||||
f"聊天流ID:{self.session_id}",
|
||||
]
|
||||
if cycle_id is not None:
|
||||
body_lines.append(f"循环编号:{cycle_id}")
|
||||
|
||||
renderables: list[RenderableType] = [Text("\n".join(body_lines))]
|
||||
timing_panel = self._build_cycle_stage_panel(
|
||||
title="Timing Gate",
|
||||
border_style="bright_magenta",
|
||||
selected_history_count=timing_selected_history_count,
|
||||
prompt_tokens=timing_prompt_tokens,
|
||||
response_text=timing_response,
|
||||
tool_calls=timing_tool_calls,
|
||||
tool_results=timing_tool_results,
|
||||
tool_detail_results=timing_tool_detail_results,
|
||||
prompt_section=timing_prompt_section,
|
||||
extra_lines=[f"门控动作:{timing_action}"] if timing_action.strip() else None,
|
||||
)
|
||||
if timing_panel is not None:
|
||||
renderables.append(timing_panel)
|
||||
|
||||
planner_panel = self._build_cycle_stage_panel(
|
||||
title="Planner",
|
||||
border_style="green",
|
||||
selected_history_count=planner_selected_history_count,
|
||||
prompt_tokens=planner_prompt_tokens,
|
||||
response_text=planner_response,
|
||||
tool_calls=planner_tool_calls,
|
||||
tool_results=planner_tool_results,
|
||||
tool_detail_results=planner_tool_detail_results,
|
||||
prompt_section=planner_prompt_section,
|
||||
)
|
||||
if planner_panel is not None:
|
||||
renderables.append(planner_panel)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(*renderables),
|
||||
title="MaiSaka 循环",
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
def _build_cycle_stage_panel(
|
||||
self,
|
||||
*,
|
||||
title: str,
|
||||
border_style: str,
|
||||
selected_history_count: Optional[int],
|
||||
prompt_tokens: Optional[int],
|
||||
response_text: str = "",
|
||||
tool_calls: Optional[list[Any]] = None,
|
||||
tool_results: Optional[list[str]] = None,
|
||||
tool_detail_results: Optional[list[dict[str, Any]]] = None,
|
||||
prompt_section: Optional[RenderableType] = None,
|
||||
extra_lines: Optional[list[str]] = None,
|
||||
) -> Optional[Panel]:
|
||||
"""构建单个 cycle 阶段的展示卡片。"""
|
||||
|
||||
has_content = any([
|
||||
selected_history_count is not None,
|
||||
prompt_tokens is not None,
|
||||
bool(response_text.strip()),
|
||||
bool(tool_calls),
|
||||
bool(tool_results),
|
||||
bool(tool_detail_results),
|
||||
prompt_section is not None,
|
||||
bool(extra_lines),
|
||||
])
|
||||
if not has_content:
|
||||
return None
|
||||
|
||||
body_lines: list[str] = []
|
||||
if selected_history_count is not None:
|
||||
body_lines.append(f"上下文占用:{selected_history_count}/{self._max_context_size} 条")
|
||||
if prompt_tokens is not None:
|
||||
body_lines.append(f"本次请求token消耗:{format_token_count(prompt_tokens)}")
|
||||
if extra_lines:
|
||||
body_lines.extend([line for line in extra_lines if isinstance(line, str) and line.strip()])
|
||||
|
||||
renderables: list[RenderableType] = []
|
||||
if body_lines:
|
||||
renderables.append(Text("\n".join(body_lines)))
|
||||
if prompt_section is not None:
|
||||
renderables.append(prompt_section)
|
||||
|
||||
normalized_response = planner_response.strip()
|
||||
normalized_response = response_text.strip()
|
||||
if normalized_response:
|
||||
renderables.append(
|
||||
Panel(
|
||||
@@ -753,13 +843,11 @@ class MaisakaHeartFlowChatting:
|
||||
if detail_panels:
|
||||
renderables.extend(detail_panels)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(*renderables),
|
||||
title="MaiSaka 上下文与结果",
|
||||
border_style="bright_blue",
|
||||
padding=(0, 1),
|
||||
)
|
||||
return Panel(
|
||||
Group(*renderables),
|
||||
title=title,
|
||||
border_style=border_style,
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -63,7 +63,7 @@ class RuntimeDataCapabilityMixin:
|
||||
|
||||
@staticmethod
|
||||
def _build_emoji_temp_path() -> Path:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.emoji_system.emoji_manager import EMOJI_DIR
|
||||
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return EMOJI_DIR / f"emoji_cap_{int(time.time() * 1000000)}.png"
|
||||
@@ -463,7 +463,7 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _cap_emoji_get_by_description(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
description: str = args.get("description", "")
|
||||
if not description:
|
||||
@@ -485,7 +485,7 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _cap_emoji_get_random(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
count: int = args.get("count", 1)
|
||||
try:
|
||||
@@ -512,7 +512,7 @@ class RuntimeDataCapabilityMixin:
|
||||
|
||||
async def _cap_emoji_get_count(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
return {"success": True, "count": len(emoji_manager.emojis)}
|
||||
except Exception as e:
|
||||
@@ -521,7 +521,7 @@ class RuntimeDataCapabilityMixin:
|
||||
|
||||
async def _cap_emoji_get_emotions(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
emotions = sorted(
|
||||
{
|
||||
@@ -540,7 +540,7 @@ class RuntimeDataCapabilityMixin:
|
||||
|
||||
async def _cap_emoji_get_all(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
emojis = []
|
||||
for emoji in emoji_manager.emojis:
|
||||
@@ -556,7 +556,7 @@ class RuntimeDataCapabilityMixin:
|
||||
|
||||
async def _cap_emoji_get_info(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
from src.config.config import global_config
|
||||
|
||||
current_count = len(emoji_manager.emojis)
|
||||
@@ -573,7 +573,7 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _cap_emoji_register(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
emoji_base64: str = args.get("emoji_base64", "")
|
||||
if not emoji_base64:
|
||||
@@ -630,7 +630,7 @@ class RuntimeDataCapabilityMixin:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def _cap_emoji_delete(self, plugin_id: str, capability: str, args: Dict[str, Any]) -> Any:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
emoji_hash: str = args.get("emoji_hash", "")
|
||||
if not emoji_hash:
|
||||
|
||||
@@ -20,7 +20,7 @@ def _get_builtin_hook_spec_registrars() -> List[HookSpecRegistrar]:
|
||||
"""
|
||||
|
||||
from src.chat.message_receive.bot import register_chat_hook_specs
|
||||
from src.chat.emoji_system.emoji_manager import register_emoji_hook_specs
|
||||
from src.emoji_system.emoji_manager import register_emoji_hook_specs
|
||||
from src.learners.expression_learner import register_expression_hook_specs
|
||||
from src.learners.jargon_miner import register_jargon_hook_specs
|
||||
from src.maisaka.chat_loop_service import register_maisaka_hook_specs
|
||||
|
||||
Reference in New Issue
Block a user