feat:统一planner的模式

This commit is contained in:
SengokuCola
2026-04-12 16:35:13 +08:00
parent fc08e44293
commit 412166ed7e
4 changed files with 158 additions and 49 deletions

View File

@@ -0,0 +1,58 @@
"""Maisaka Prompt 预览路径工具。"""
from __future__ import annotations
from pathlib import Path
from urllib.parse import quote
import re
from src.chat.message_receive.chat_manager import chat_manager
REPO_ROOT = Path(__file__).parent.parent.parent.parent.absolute().resolve()
SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
def normalize_preview_name(value: str) -> str:
normalized_value = SAFE_NAME_PATTERN.sub("_", str(value or "").strip()).strip("._")
if normalized_value:
return normalized_value
return "unknown"
def normalize_platform_name(platform: str) -> str:
normalized_platform = str(platform or "").strip().lower()
platform_aliases = {
"telegram": "tg",
}
return normalize_preview_name(platform_aliases.get(normalized_platform, normalized_platform))
def build_preview_chat_dir_name(chat_id: str) -> str:
session = chat_manager.get_session_by_session_id(chat_id)
if session is not None:
platform = normalize_platform_name(session.platform)
if session.is_group_session and session.group_id:
return f"{platform}_group_{normalize_preview_name(session.group_id)}"
if session.user_id:
return f"{platform}_private_{normalize_preview_name(session.user_id)}"
normalized_chat_id = normalize_preview_name(chat_id)
if normalized_chat_id != "unknown":
return normalized_chat_id
return "unknown_chat"
def build_display_path(file_path: Path) -> str:
"""构造用于展示的路径,项目内文件优先显示相对路径。"""
resolved_path = file_path.resolve()
try:
return resolved_path.relative_to(REPO_ROOT).as_posix()
except ValueError:
return resolved_path.as_posix()
def build_file_uri(file_path: Path) -> str:
normalized = file_path.resolve().as_posix()
return f"file:///{quote(normalized, safe='/:')}"

View File

@@ -7,7 +7,6 @@ from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal
from urllib.parse import quote
import hashlib
import html
@@ -27,10 +26,10 @@ from .display_utils import (
get_role_badge_label as get_shared_role_badge_label,
get_role_badge_style as get_shared_role_badge_style,
)
from .preview_path_utils import build_display_path, build_file_uri, REPO_ROOT
from .prompt_preview_logger import PromptPreviewLogger
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
DATA_IMAGE_DIR = PROJECT_ROOT / "data" / "images"
DATA_IMAGE_DIR = REPO_ROOT / "data" / "images"
class PromptImageDisplayMode(str, Enum):
@@ -115,11 +114,6 @@ class PromptCLIVisualizer:
digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest()
return root / f"{digest}.{image_format}"
@staticmethod
def _build_file_uri(file_path: Path) -> str:
normalized = file_path.resolve().as_posix()
return f"file:///{quote(normalized, safe='/:')}"
@staticmethod
def _build_official_image_path(image_format: str, image_base64: str) -> Path | None:
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format)
@@ -140,7 +134,7 @@ class PromptCLIVisualizer:
normalized_format = PromptCLIVisualizer._normalize_image_format(image_format) or "bin"
official_path = PromptCLIVisualizer._build_official_image_path(image_format, image_base64)
if official_path is not None:
return PromptCLIVisualizer._build_file_uri(official_path), official_path
return build_file_uri(official_path), official_path
try:
image_bytes = b64decode(image_base64)
@@ -153,7 +147,7 @@ class PromptCLIVisualizer:
path.write_bytes(image_bytes)
except Exception:
return None
return PromptCLIVisualizer._build_file_uri(path), path
return build_file_uri(path), path
@classmethod
def _render_image_item(cls, image_format: str, image_base64: str, settings: PromptImageDisplaySettings) -> Panel:
@@ -169,8 +163,9 @@ class PromptCLIVisualizer:
path_result = cls._build_image_file_link(image_format, image_base64)
if path_result is not None:
file_uri, file_path = path_result
display_path = build_display_path(file_path)
preview_parts: List[RenderableType] = [
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{file_path}", style="magenta")
Text(f"图片格式 image/{normalized_format} {size_text} 路径:{display_path}", style="magenta")
]
preview_parts.append(Text.from_markup(f"[link={file_uri}]点击打开图片[/link]", style="cyan"))
@@ -437,17 +432,44 @@ class PromptCLIVisualizer:
)
file_uri, file_path = path_result
display_path = build_display_path(file_path)
return (
"<div class='image-card'>"
f"<div class='image-meta'>图片 image/{html.escape(normalized_format)} {html.escape(size_text)}</div>"
f"<a class='image-preview-link' href='{html.escape(file_uri, quote=True)}'>"
f"<img class='image-preview' src='{html.escape(file_uri, quote=True)}' alt='图片预览' />"
"</a>"
f"<div class='image-path'>{html.escape(str(file_path))}</div>"
f"<div class='image-path'>{html.escape(display_path)}</div>"
f"<a class='image-link' href='{html.escape(file_uri, quote=True)}'>打开图片</a>"
"</div>"
)
@staticmethod
def _build_preview_access_body(
*,
viewer_label: str,
viewer_path: Path,
viewer_link_text: str,
dump_label: str,
dump_path: Path,
dump_link_text: str,
) -> RenderableType:
viewer_uri = build_file_uri(viewer_path)
dump_uri = build_file_uri(dump_path)
viewer_display_path = build_display_path(viewer_path)
dump_display_path = build_display_path(dump_path)
return Group(
Text.from_markup(
f"[bold green]{viewer_label}{viewer_display_path}[/bold green] "
f"[link={viewer_uri}]{viewer_link_text}[/link]"
),
Text.from_markup(
f"[magenta]{dump_label}{dump_display_path}[/magenta] "
f"[cyan][link={dump_uri}]{dump_link_text}[/link][/cyan]"
),
)
@classmethod
def _build_html_role_class(cls, role: str) -> str:
return {
@@ -823,18 +845,13 @@ class PromptCLIVisualizer:
)
viewer_html_path = saved_paths[".html"]
prompt_dump_path = saved_paths[".txt"]
viewer_uri = cls._build_file_uri(viewer_html_path)
dump_uri = cls._build_file_uri(prompt_dump_path)
body = Group(
Text.from_markup(
f"[bold green]html预览{viewer_html_path}[/bold green] "
f"[link={viewer_uri}]在浏览器打开 Prompt [/link]"
),
Text.from_markup(
f"[magenta]原始文本:{prompt_dump_path}[/magenta] "
f"[cyan][link={dump_uri}]点击打开 Prompt 文本[/link][/cyan]"
),
body = cls._build_preview_access_body(
viewer_label="html预览",
viewer_path=viewer_html_path,
viewer_link_text="在浏览器打开 Prompt",
dump_label="原始文本",
dump_path=prompt_dump_path,
dump_link_text="点击打开 Prompt 文本",
)
return body
@@ -989,18 +1006,13 @@ class PromptCLIVisualizer:
)
viewer_html_path = saved_paths[".html"]
text_dump_path = saved_paths[".txt"]
viewer_uri = cls._build_file_uri(viewer_html_path)
dump_uri = cls._build_file_uri(text_dump_path)
body = Group(
Text.from_markup(
f"[bold green]富文本预览:{viewer_html_path}[/bold green] "
f"[link={viewer_uri}]点击在浏览器打开富文本 Prompt 视图[/link]"
),
Text.from_markup(
f"[magenta]原始文本备份:{text_dump_path}[/magenta] "
f"[cyan][link={dump_uri}]点击直接打开 Prompt 文本[/link][/cyan]"
),
body = cls._build_preview_access_body(
viewer_label="富文本预览",
viewer_path=viewer_html_path,
viewer_link_text="点击在浏览器打开富文本 Prompt 视图",
dump_label="原始文本备份",
dump_path=text_dump_path,
dump_link_text="点击直接打开 Prompt 文本",
)
return body

View File

@@ -2,11 +2,11 @@
from __future__ import annotations
import re
import time
from pathlib import Path
from typing import Dict
from uuid import uuid4
from .preview_path_utils import build_preview_chat_dir_name, normalize_preview_name
class PromptPreviewLogger:
@@ -15,14 +15,16 @@ class PromptPreviewLogger:
_BASE_DIR = Path("logs") / "maisaka_prompt"
_MAX_PREVIEW_GROUPS_PER_CHAT = 1024
_TRIM_COUNT = 100
_SAFE_NAME_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
@classmethod
def _normalize_chat_id(cls, chat_id: str) -> str:
normalized_chat_id = cls._SAFE_NAME_PATTERN.sub("_", str(chat_id or "").strip()).strip("._")
if normalized_chat_id:
return normalized_chat_id
return "unknown_chat"
def _build_file_stem(cls, chat_dir: Path) -> str:
base_stem = str(int(time.time() * 1000))
candidate_stem = base_stem
suffix_index = 1
while any((chat_dir / f"{candidate_stem}{suffix}").exists() for suffix in (".html", ".txt")):
candidate_stem = f"{base_stem}_{suffix_index}"
suffix_index += 1
return candidate_stem
@classmethod
def save_preview_files(
@@ -33,10 +35,10 @@ class PromptPreviewLogger:
) -> Dict[str, Path]:
"""保存同一份 Prompt 预览的多个文件并执行超量清理。"""
normalized_category = cls._normalize_chat_id(category)
chat_dir = (cls._BASE_DIR / normalized_category / cls._normalize_chat_id(chat_id)).resolve()
normalized_category = normalize_preview_name(category)
chat_dir = (cls._BASE_DIR / normalized_category / build_preview_chat_dir_name(chat_id)).resolve()
chat_dir.mkdir(parents=True, exist_ok=True)
stem = f"{int(time.time() * 1000)}_{uuid4().hex[:8]}"
stem = cls._build_file_stem(chat_dir)
saved_paths: Dict[str, Path] = {}
try:
for suffix, content in files.items():

View File

@@ -14,7 +14,7 @@ from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.config.config import config_manager, global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.exceptions import ReqAbortException
from src.llm_models.payload_content.tool_option import ToolCall
@@ -738,10 +738,47 @@ class MaisakaReasoningEngine:
planner_prefix: str,
) -> MessageSequence:
message_sequence = build_prefixed_message_sequence(message.raw_message, planner_prefix)
if global_config.visual.multimodal_planner:
if self._resolve_enable_visual_planner():
await self._hydrate_visual_components(message_sequence.components)
return message_sequence
@staticmethod
def _resolve_enable_visual_planner() -> bool:
planner_mode = global_config.visual.planner_mode
planner_task_config = config_manager.get_model_config().model_task_config.planner
models_by_name = {model.name: model for model in config_manager.get_model_config().models}
if planner_mode == "text":
return False
planner_models: list[str] = list(planner_task_config.model_list)
missing_models = [model_name for model_name in planner_models if model_name not in models_by_name]
non_visual_models = [
model_name for model_name in planner_models if model_name in models_by_name and not models_by_name[model_name].visual
]
if planner_mode == "multimodal":
if missing_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未定义的模型:"
f"{', '.join(missing_models)}"
)
if non_visual_models:
raise ValueError(
"planner_mode=multimodal但 planner 任务存在未开启 visual 的模型:"
f"{', '.join(non_visual_models)}"
)
return True
if missing_models:
logger.warning(
"planner_mode=auto 时发现 planner 任务存在未定义模型:"
f"{', '.join(missing_models)},将退化为纯文本 planner"
)
return False
return bool(planner_models) and not non_visual_models
async def _hydrate_visual_components(self, planner_components: list[object]) -> None:
"""在 Maisaka 真正需要图片或表情时,按需回填二进制数据。"""
load_tasks: list[asyncio.Task[None]] = []