Files
mai-bot/src/maisaka/chat_loop_service.py
2026-03-29 18:28:56 +08:00

425 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from dataclasses import dataclass
from base64 import b64decode
from datetime import datetime
from io import BytesIO
from time import perf_counter
from typing import Any, Dict, List, Optional
import asyncio
import random
from PIL import Image as PILImage
from rich.console import Group
from rich.panel import Panel
from rich.pretty import Pretty
from rich.text import Text
from src.chat.message_receive.message import SessionMessage
from src.cli.console import console
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.know_u.knowledge import extract_category_ids_from_result
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionInput, ToolOption, normalize_tool_options
from src.services.llm_service import LLMServiceClient
from .builtin_tools import get_builtin_tools
from .message_adapter import (
build_message,
format_speaker_content,
get_message_role,
to_llm_message,
)
@dataclass(slots=True)
class ChatResponse:
    """Single-step response of the LLM chat loop."""

    # Text returned by the model for this step; may be None.
    content: Optional[str]
    # Tool calls requested by the model in this step.
    tool_calls: List[ToolCall]
    # The step's output wrapped as a session message.
    raw_message: SessionMessage
# Module-level logger registered under the name "maisaka_chat_loop".
logger = get_logger("maisaka_chat_loop")
class MaisakaChatLoopService:
"""负责 Maisaka 主对话循环、系统提示词和终端渲染。"""
def __init__(
    self,
    chat_system_prompt: Optional[str] = None,
    temperature: float = 0.5,
    max_tokens: int = 2048,
) -> None:
    """Initialise sampling settings, prompt state and the planner LLM client.

    Args:
        chat_system_prompt: Initial system prompt. When None, a default prompt
            derived from the personality prompt is used.
        temperature: Sampling temperature passed to chat-loop requests.
        max_tokens: Token limit passed to chat-loop requests.
    """
    self._temperature = temperature
    self._max_tokens = max_tokens
    self._extra_tools: List[ToolOption] = []

    # State used by ensure_chat_prompt_loaded to load the template only once.
    self._prompts_loaded = False
    self._prompt_load_lock = asyncio.Lock()

    # The personality prompt must exist before the default system prompt,
    # because the default embeds it.
    self._personality_prompt = self._build_personality_prompt()
    self._chat_system_prompt = (
        chat_system_prompt
        if chat_system_prompt is not None
        else f"{self._personality_prompt}\n\nYou are a helpful AI assistant."
    )

    self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
@property
def personality_prompt(self) -> str:
    """Return the personality prompt stored on this instance."""
    return self._personality_prompt
def _build_personality_prompt(self) -> str:
    """Build the identity line describing the bot's name and persona.

    The name and optional aliases come from ``global_config.bot``. The persona
    is ``global_config.personality.personality`` unless a random draw below
    ``state_probability`` selects one of the configured ``states`` instead.

    Returns:
        The identity sentence. If anything goes wrong while reading the
        configuration, a fixed default identity is returned and the failure
        is logged instead of being dropped silently.
    """
    try:
        bot_name = global_config.bot.nickname
        alias_names = global_config.bot.alias_names
        bot_nickname = f", also known as {','.join(alias_names)}" if alias_names else ""

        personality = global_config.personality
        prompt_personality = personality.personality

        # Both attributes are optional; missing ones disable the state swap.
        states = getattr(personality, "states", None)
        state_probability = getattr(personality, "state_probability", 0)
        if states and state_probability > 0 and random.random() < state_probability:
            prompt_personality = random.choice(states)

        return f"Your name is {bot_name}{bot_nickname}; persona: {prompt_personality};"
    except Exception as exc:
        # Best-effort by design: never fail construction over a bad config,
        # but make the fallback visible so misconfiguration can be diagnosed.
        logger.warning(f"Failed to build personality prompt, using default identity: {exc!r}")
        return "Your name is MaiMai; persona: lively and cute AI assistant."
async def ensure_chat_prompt_loaded(self, tools_section: str = "") -> None:
    """Load the chat system prompt template once and cache the result.

    Only the first call does any work; later calls return immediately, so a
    ``tools_section`` passed after the first load has no effect.

    Args:
        tools_section: Text substituted into the template as
            ``file_tools_section``.

    If loading the template fails, the system prompt falls back to the
    personality prompt plus a generic assistant sentence. The failure is
    logged, because the fallback is kept for the lifetime of the instance.
    """
    if self._prompts_loaded:
        return

    async with self._prompt_load_lock:
        # Re-check under the lock: another task may have completed the load
        # while this one was waiting.
        if self._prompts_loaded:
            return

        try:
            self._chat_system_prompt = load_prompt(
                "maidairy_chat",
                file_tools_section=tools_section,
                bot_name=global_config.bot.nickname,
                identity=self._personality_prompt,
            )
        except Exception as exc:
            logger.warning(f"Failed to load chat prompt 'maidairy_chat', using fallback prompt: {exc!r}")
            self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant."

        self._prompts_loaded = True
def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None:
    """Replace the extra tools offered in addition to the built-in ones.

    The input is passed through ``normalize_tool_options``; a falsy result is
    stored as an empty list.
    """
    self._extra_tools = normalize_tool_options(tools) or []
async def analyze_knowledge_need(
    self,
    chat_history: List[SessionMessage],
    categories_summary: str,
) -> List[str]:
    """Ask the LLM which knowledge-base categories the conversation needs.

    Only the last eight messages with non-empty content are shown to the
    model, each formatted as ``"<role>: <content>"``.

    Args:
        chat_history: Conversation messages, oldest first.
        categories_summary: Description of the available categories.

    Returns:
        The category ids extracted from the model's reply. An empty list is
        returned when there is nothing to show, when the summary is blank, or
        when the LLM request raises.
    """
    recent_lines: List[str] = [
        f"{getattr(entry, 'role', '')}: {entry.content}"
        for entry in chat_history[-8:]
        if entry.content
    ]
    if not (recent_lines and categories_summary.strip()):
        return []

    history_block = "\n".join(recent_lines)
    prompt = (
        "你需要判断当前对话是否需要查询知识库。\n"
        "请只返回最相关的分类编号,多个编号用空格分隔;如果完全不需要,返回 none。\n\n"
        f"【可用分类】\n{categories_summary}\n\n"
        f"【最近对话】\n{history_block}"
    )

    try:
        # Low temperature and a small token budget: only ids are expected back.
        llm_result = await self._llm_chat.generate_response(
            prompt=prompt,
            options=LLMGenerationOptions(temperature=0.1, max_tokens=64),
        )
    except Exception:
        return []

    return extract_category_ids_from_result(llm_result.response or "")
@staticmethod
def _get_role_badge_style(role: str) -> str:
if role == "system":
return "bold white on blue"
if role == "user":
return "bold black on green"
if role == "assistant":
return "bold black on yellow"
if role == "tool":
return "bold white on magenta"
return "bold white on bright_black"
@staticmethod
def _build_terminal_image_preview(image_base64: str) -> Optional[str]:
    """Render a base64-encoded image as ASCII art for the terminal.

    The image is converted to greyscale and resized to the configured preview
    width (at least 8 columns). The height is scaled by the same ratio and
    then halved. Each pixel maps to one of ten characters by brightness.

    Returns:
        The preview as newline-joined rows, or None if decoding or resizing
        fails or the image has a non-positive dimension.
    """
    palette = " .:-=+*#%@"
    palette_size = len(palette)

    try:
        decoded = b64decode(image_base64)
        with PILImage.open(BytesIO(decoded)) as source_image:
            gray = source_image.convert("L")
            source_width, source_height = gray.size
            if source_width <= 0 or source_height <= 0:
                return None
            columns = max(8, int(global_config.maisaka.terminal_image_preview_width))
            line_count = max(1, int(source_height * (columns / source_width) * 0.5))
            scaled = gray.resize((columns, line_count))
            luminance = list(scaled.getdata())
    except Exception:
        return None

    # Slice the flat pixel list into rows of `columns` values each.
    lines: List[str] = [
        "".join(
            palette[min(palette_size - 1, value * palette_size // 256)]
            for value in luminance[offset : offset + columns]
        )
        for offset in range(0, line_count * columns, columns)
    ]
    return "\n".join(lines)
@classmethod
def _render_message_content(cls, content: Any) -> object:
    """Convert message content into a renderable for the terminal.

    Strings become ``Text``; None becomes an empty ``Text``. A list is
    rendered item by item:

    * a string becomes ``Text``;
    * a 2-tuple of strings is treated as ``(image_format, image_base64)`` and
      shown as a panel with an approximate size and, when enabled in the
      config, an ASCII preview;
    * a dict with ``type == "text"`` and a string ``text`` becomes ``Text``;
    * anything else is pretty-printed.

    Any other content type is pretty-printed as a whole.
    """
    if content is None:
        return Text("")
    if isinstance(content, str):
        return Text(content)
    if not isinstance(content, list):
        return Pretty(content, expand_all=True)

    rendered: List[object] = []
    for element in content:
        if isinstance(element, str):
            rendered.append(Text(element))
        elif (
            isinstance(element, tuple)
            and len(element) == 2
            and isinstance(element[0], str)
            and isinstance(element[1], str)
        ):
            image_format, image_base64 = element
            # Base64 encodes 3 bytes in 4 characters, hence the 3/4 estimate.
            approx_size = max(0, len(image_base64) * 3 // 4)
            if approx_size >= 1024:
                size_text = f"{approx_size / 1024:.1f} KB"
            else:
                size_text = f"{approx_size} B"
            image_parts: List[object] = [
                Text(f"image/{image_format} {size_text}\nbase64 omitted", style="magenta")
            ]
            if global_config.maisaka.terminal_image_preview:
                ascii_preview = cls._build_terminal_image_preview(image_base64)
                if ascii_preview:
                    image_parts.append(Text(ascii_preview, style="white"))
            rendered.append(
                Panel(
                    Group(*image_parts),
                    border_style="magenta",
                    padding=(0, 1),
                )
            )
        elif (
            isinstance(element, dict)
            and element.get("type") == "text"
            and isinstance(element.get("text"), str)
        ):
            rendered.append(Text(element["text"]))
        else:
            rendered.append(Pretty(element, expand_all=True))

    if not rendered:
        return Text("")
    return Group(*rendered)
@staticmethod
def _format_tool_call_for_display(tool_call: Any) -> Dict[str, Any]:
if isinstance(tool_call, dict):
function_info = tool_call.get("function", {})
return {
"id": tool_call.get("id"),
"name": function_info.get("name", tool_call.get("name")),
"arguments": function_info.get("arguments", tool_call.get("arguments")),
}
return {
"id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
"name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
"arguments": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
}
def _render_tool_call_panel(self, tool_call: Any, index: int, parent_index: int) -> Panel:
    """Build the panel that displays one tool call of a message.

    Args:
        tool_call: The tool call to display.
        index: Position of the tool call within its message.
        parent_index: Position of the owning message; shown as
            ``#<parent_index>.<index>`` in the title.
    """
    badge = Text(" TOOL CALL ", style="bold white on magenta")
    position = Text(f" #{parent_index}.{index}", style="muted")
    details = Pretty(self._format_tool_call_for_display(tool_call), expand_all=True)
    return Panel(
        details,
        title=Text.assemble(badge, position),
        border_style="magenta",
        padding=(0, 1),
    )
def _render_message_panel(self, message: Any, index: int) -> Panel:
    """Build the panel that displays one message of the LLM request.

    ``role``, ``content`` and ``tool_call_id`` are read from dict keys when
    the message is a dict, otherwise from attributes. A role object with a
    ``value`` attribute is shown by that value.

    Args:
        message: The message to display.
        index: Position shown as ``#<index>`` in the panel title.
    """
    if isinstance(message, dict):
        raw_role, content, tool_call_id = (
            message.get("role", "unknown"),
            message.get("content"),
            message.get("tool_call_id"),
        )
    else:
        raw_role, content, tool_call_id = (
            getattr(message, "role", "unknown"),
            getattr(message, "content", None),
            getattr(message, "tool_call_id", None),
        )

    if hasattr(raw_role, "value"):
        role = raw_role.value
    else:
        role = str(raw_role)

    header = Text.assemble(
        Text(f" {role.upper()} ", style=self._get_role_badge_style(role)),
        Text(f" #{index}", style="muted"),
    )

    sections: List[object] = []
    if content not in (None, "", []):
        sections.extend(
            [
                Text(" message ", style="bold cyan"),
                self._render_message_content(content),
            ]
        )
    if tool_call_id:
        sections.append(
            Text.assemble(
                Text(" tool_call_id ", style="bold magenta"),
                Text(" "),
                Text(str(tool_call_id), style="magenta"),
            )
        )
    if not sections:
        # Keep the panel visible even when there is nothing to show.
        sections.append(Text("[empty message]", style="muted"))

    return Panel(
        Group(*sections),
        title=header,
        border_style="dim",
        padding=(0, 1),
    )
async def chat_loop_step(self, chat_history: List[SessionMessage]) -> ChatResponse:
    """Run one planner step: send the selected context to the LLM.

    The system prompt is loaded on first use, the context window is chosen
    by ``_select_llm_context_messages``, and the built-in plus extra tools
    are offered to the model. When ``global_config.maisaka.show_thinking`` is
    enabled the outgoing request is printed to the console first.

    Args:
        chat_history: Full conversation history, oldest first.

    Returns:
        A ``ChatResponse`` carrying the model's text (possibly None), its
        tool calls (possibly empty) and an assistant session message built
        from them.
    """
    await self.ensure_chat_prompt_loaded()
    selected_history, selection_reason = self._select_llm_context_messages(chat_history)

    def message_factory(_client: BaseClient) -> List[Message]:
        """Build the request messages: system prompt followed by the context."""
        system_msg = MessageBuilder().set_role(RoleType.System)
        system_msg.add_text_content(self._chat_system_prompt)
        messages: List[Message] = [system_msg.build()]
        for msg in selected_history:
            llm_message = to_llm_message(msg)
            if llm_message is not None:
                messages.append(llm_message)
        return messages

    all_tools = [*get_builtin_tools(), *self._extra_tools]

    # Only build the preview when it will actually be printed; rendering every
    # message is wasted work otherwise.
    if global_config.maisaka.show_thinking:
        ordered_panels: List[Panel] = []
        # The factory ignores its client argument, so None is sufficient here.
        for index, msg in enumerate(message_factory(None), start=1):
            ordered_panels.append(self._render_message_panel(msg, index))
            for tool_call_index, tool_call in enumerate(getattr(msg, "tool_calls", None) or [], start=1):
                ordered_panels.append(self._render_tool_call_panel(tool_call, tool_call_index, index))
        if ordered_panels:
            console.print(
                Panel(
                    Group(*ordered_panels),
                    title="MaiSaka LLM Request - chat_loop_step",
                    subtitle=selection_reason,
                    border_style="cyan",
                    padding=(0, 1),
                )
            )

    request_started_at = perf_counter()
    generation_result = await self._llm_chat.generate_response_with_messages(
        message_factory=message_factory,
        options=LLMGenerationOptions(
            tool_options=all_tools if all_tools else None,
            temperature=self._temperature,
            max_tokens=self._max_tokens,
        ),
    )
    elapsed_seconds = perf_counter() - request_started_at
    logger.debug(f"Maisaka planner request took {elapsed_seconds:.2f}s")

    tool_calls = generation_result.tool_calls or []
    tool_call_summaries = [
        {
            "id": getattr(tool_call, "call_id", getattr(tool_call, "id", None)),
            "name": getattr(tool_call, "func_name", getattr(tool_call, "name", None)),
            "args": getattr(tool_call, "args", getattr(tool_call, "arguments", None)),
        }
        for tool_call in tool_calls
    ]
    logger.info(
        f"Maisaka planner returned content={generation_result.response or ''!r} "
        f"tool_calls={tool_call_summaries}"
    )

    raw_message = build_message(
        role=RoleType.Assistant.value,
        content=generation_result.response or "",
        source="assistant",
        # An empty tool-call list is passed on as None, as before.
        tool_calls=tool_calls or None,
    )
    return ChatResponse(
        content=generation_result.response,
        tool_calls=tool_calls,
        raw_message=raw_message,
    )
@staticmethod
def _select_llm_context_messages(chat_history: List[SessionMessage]) -> tuple[List[SessionMessage], str]:
    """Select the context messages that are actually sent to the LLM.

    The history is walked from newest to oldest. Messages that
    ``to_llm_message`` cannot convert are skipped. Every other message is
    kept, but only ``user`` and ``assistant`` messages count towards the
    limit ``global_config.chat.max_context_size`` (at least 1). The walk
    stops once that many counted messages have been collected.

    Args:
        chat_history: Full conversation history, oldest first.

    Returns:
        A pair ``(selected messages in original order, description)``. The
        description is a human-readable summary of the selection, used as the
        subtitle of the request preview.
    """
    max_context_size = max(1, int(global_config.chat.max_context_size))
    counted_roles = {"user", "assistant"}
    window_label = f"上下文判定:最近 {max_context_size} 条 user/assistant"

    newest_first: List[SessionMessage] = []
    counted_message_count = 0
    for message in reversed(chat_history):
        if to_llm_message(message) is None:
            continue
        newest_first.append(message)
        if get_message_role(message) in counted_roles:
            counted_message_count += 1
            if counted_message_count >= max_context_size:
                break

    if not newest_first:
        # The original text had a closing parenthesis without an opening one.
        return [], f"{window_label}(当前 0 条)"

    selected_history = newest_first[::-1]
    # A separator keeps the two clauses from running into each other.
    return selected_history, f"{window_label}, 展示并发送窗口内消息 {len(selected_history)}"
@staticmethod
def build_chat_context(user_text: str) -> List[SessionMessage]:
    """Create a one-message chat history from a piece of user text.

    The speaker name is the configured user name with surrounding whitespace
    removed, or "用户" when that is empty. The content is produced by
    ``format_speaker_content`` with the current local time.

    Args:
        user_text: The text entered by the user.

    Returns:
        A list holding a single user-role session message.
    """
    speaker_name = global_config.maisaka.user_name.strip() or "用户"
    formatted_content = format_speaker_content(speaker_name, user_text, datetime.now())
    user_message = build_message(
        role=RoleType.User.value,
        content=formatted_content,
        source="user",
    )
    return [user_message]