feat: 更新多个文件以使用 SessionMessage 替代 MaiMessage,并调整相关逻辑

This commit is contained in:
DrSmoothl
2026-03-28 13:39:48 +08:00
parent a3bc145051
commit 7a460a474d
15 changed files with 136 additions and 84 deletions

View File

@@ -13,7 +13,7 @@ from rich.markdown import Markdown
from rich.panel import Panel
from rich.text import Text
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.config.config import global_config
from .config import (
@@ -26,8 +26,8 @@ from .config import (
from .input_reader import InputReader
from .knowledge import retrieve_relevant_knowledge
from .knowledge_store import get_knowledge_store
from .llm_service import MaiSakaLLMService, build_message, remove_last_perception
from .message_adapter import format_speaker_content
from .llm_service import MaiSakaLLMService
from .message_adapter import build_message, format_speaker_content, remove_last_perception
from .mcp_client import MCPManager
from .tool_handlers import (
ToolHandlerContext,
@@ -47,7 +47,7 @@ class BufferCLI:
def __init__(self):
self.llm_service: Optional[MaiSakaLLMService] = None
self._reader = InputReader()
self._chat_history: Optional[list[MaiMessage]] = None
self._chat_history: Optional[list[SessionMessage]] = None
self._knowledge_store = get_knowledge_store()
knowledge_stats = self._knowledge_store.get_stats()
@@ -122,7 +122,7 @@ class BufferCLI:
await self._run_llm_loop(self._chat_history)
async def _run_llm_loop(self, chat_history: list[MaiMessage]):
async def _run_llm_loop(self, chat_history: list[SessionMessage]):
"""
Main inner loop for the Maisaka planner.
@@ -318,7 +318,7 @@ class BufferCLI:
)
)
async def _generate_visible_reply(self, chat_history: list[MaiMessage], latest_thought: str) -> str:
async def _generate_visible_reply(self, chat_history: list[SessionMessage], latest_thought: str) -> str:
"""Generate and emit a visible reply based on the latest thought."""
if not self.llm_service or not latest_thought:
return ""

View File

@@ -4,7 +4,7 @@ MaiSaka knowledge retrieval helpers.
from typing import List
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from .knowledge_store import KNOWLEDGE_CATEGORIES, get_knowledge_store
@@ -43,7 +43,7 @@ def extract_category_ids_from_result(result: str) -> List[str]:
async def retrieve_relevant_knowledge(
llm_service,
chat_history: List[MaiMessage],
chat_history: List[SessionMessage],
) -> str:
"""Retrieve formatted knowledge snippets relevant to the current chat history."""
store = get_knowledge_store()

View File

@@ -19,7 +19,8 @@ from rich.panel import Panel
from rich.pretty import Pretty
from rich.text import Text
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import config_manager, global_config
@@ -31,7 +32,6 @@ from src.llm_models.payload_content.tool_option import (
ToolOption,
normalize_tool_options,
)
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.services.llm_service import LLMServiceClient
from . import config
@@ -55,7 +55,7 @@ class ChatResponse:
content: Optional[str]
tool_calls: List[ToolCall]
raw_message: MaiMessage
raw_message: SessionMessage
class MaiSakaLLMService:
@@ -428,7 +428,7 @@ class MaiSakaLLMService:
padding=(0, 1),
)
async def chat_loop_step(self, chat_history: List[MaiMessage]) -> ChatResponse:
async def chat_loop_step(self, chat_history: List[SessionMessage]) -> ChatResponse:
"""执行主对话循环的一步。
Args:
@@ -514,7 +514,7 @@ class MaiSakaLLMService:
source="assistant",
tool_calls=tool_calls or None,
)
logger.info("已将规划模型响应转换为 MaiMessage")
logger.info("已将规划模型响应转换为 SessionMessage")
return ChatResponse(
content=response,
@@ -522,7 +522,7 @@ class MaiSakaLLMService:
raw_message=raw_message,
)
def _filter_for_api(self, chat_history: List[MaiMessage]) -> str:
def _filter_for_api(self, chat_history: List[SessionMessage]) -> str:
"""将对话历史过滤为简单文本格式。
Args:
@@ -555,14 +555,14 @@ class MaiSakaLLMService:
return "\n\n".join(parts)
def build_chat_context(self, user_text: str) -> List[MaiMessage]:
def build_chat_context(self, user_text: str) -> List[SessionMessage]:
"""构建新的对话上下文。
Args:
user_text: 用户输入文本。
Returns:
List[MaiMessage]: 初始对话上下文消息列表。
List[SessionMessage]: 初始对话上下文消息列表。
"""
return [
build_message(
@@ -572,7 +572,7 @@ class MaiSakaLLMService:
)
]
async def _removed_analyze_timing(self, chat_history: List[MaiMessage], timing_info: str) -> str:
async def _removed_analyze_timing(self, chat_history: List[SessionMessage], timing_info: str) -> str:
"""执行时间节奏分析。
Args:
@@ -623,7 +623,7 @@ class MaiSakaLLMService:
# ──────── 回复生成(使用 replyer 模型) ────────
async def generate_reply(self, reason: str, chat_history: List[MaiMessage]) -> str:
async def generate_reply(self, reason: str, chat_history: List[SessionMessage]) -> str:
"""生成最终回复文本。
Args:

View File

@@ -1,5 +1,5 @@
"""
MaiSaka message adapters built on top of the main project's MaiMessage model.
MaiSaka 内部消息适配器。
"""
from copy import deepcopy
@@ -12,7 +12,8 @@ import re
from PIL import Image as PILImage
from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, MessageInfo, UserInfo
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo, UserInfo
from src.common.data_models.message_component_data_model import EmojiComponent, ImageComponent, MessageSequence, TextComponent
from src.config.config import global_config
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
@@ -77,11 +78,11 @@ def build_message(
group_info: Optional[GroupInfo] = None,
raw_message: Optional[MessageSequence] = None,
display_text: Optional[str] = None,
) -> MaiMessage:
"""Build a MaiMessage for the Maisaka session history."""
) -> SessionMessage:
""" MaiSaka 会话历史构建内部 ``SessionMessage``。"""
resolved_timestamp = timestamp or datetime.now()
resolved_role = role.value if isinstance(role, RoleType) else role
message = MaiMessage(
message = SessionMessage(
message_id=message_id or f"maisaka_{uuid4().hex}",
timestamp=resolved_timestamp,
platform=platform,
@@ -104,6 +105,7 @@ def build_message(
visible_text = display_text if display_text is not None else content
message.processed_plain_text = visible_text
message.display_message = visible_text
message.initialized = True
return message
@@ -160,7 +162,7 @@ def _guess_image_format(image_bytes: bytes) -> Optional[str]:
return None
def get_message_text(message: MaiMessage) -> str:
def get_message_text(message: SessionMessage) -> str:
if message.processed_plain_text is not None:
return message.processed_plain_text
if message.display_message is not None:
@@ -174,42 +176,42 @@ def get_message_text(message: MaiMessage) -> str:
return "".join(parts)
def get_message_role(message: MaiMessage) -> str:
def get_message_role(message: SessionMessage) -> str:
return str(message.message_info.additional_config.get(LLM_ROLE_KEY, RoleType.User.value))
def get_message_kind(message: MaiMessage) -> str:
def get_message_kind(message: SessionMessage) -> str:
return str(message.message_info.additional_config.get(MESSAGE_KIND_KEY, "normal"))
def get_message_source(message: MaiMessage) -> str:
def get_message_source(message: SessionMessage) -> str:
return str(message.message_info.additional_config.get(SOURCE_KEY, get_message_role(message)))
def is_perception_message(message: MaiMessage) -> bool:
def is_perception_message(message: SessionMessage) -> bool:
return get_message_kind(message) == "perception"
def get_tool_call_id(message: MaiMessage) -> Optional[str]:
def get_tool_call_id(message: SessionMessage) -> Optional[str]:
value = message.message_info.additional_config.get(TOOL_CALL_ID_KEY)
return str(value) if value else None
def get_tool_calls(message: MaiMessage) -> list[ToolCall]:
def get_tool_calls(message: SessionMessage) -> list[ToolCall]:
raw_tool_calls = message.message_info.additional_config.get(TOOL_CALLS_KEY, [])
if not isinstance(raw_tool_calls, list):
return []
return [_deserialize_tool_call(item) for item in raw_tool_calls if isinstance(item, dict)]
def remove_last_perception(messages: list[MaiMessage]) -> None:
def remove_last_perception(messages: list[SessionMessage]) -> None:
for index in range(len(messages) - 1, -1, -1):
if is_perception_message(messages[index]):
messages.pop(index)
break
def to_llm_message(message: MaiMessage) -> Optional[Message]:
def to_llm_message(message: SessionMessage) -> Optional[Message]:
role = get_message_role(message)
tool_call_id = get_tool_call_id(message)
tool_calls = get_tool_calls(message)

View File

@@ -4,7 +4,7 @@ MaiSaka reply helper.
from typing import Optional
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.config.config import global_config
from .config import USER_NAME
@@ -19,17 +19,17 @@ def _normalize_content(content: str, limit: int = 500) -> str:
return normalized
def _format_message_time(message: MaiMessage) -> str:
def _format_message_time(message: SessionMessage) -> str:
return message.timestamp.strftime("%H:%M:%S")
def _extract_visible_assistant_reply(message: MaiMessage) -> str:
def _extract_visible_assistant_reply(message: SessionMessage) -> str:
if is_perception_message(message):
return ""
return ""
def _extract_guided_bot_reply(message: MaiMessage) -> str:
def _extract_guided_bot_reply(message: SessionMessage) -> str:
speaker_name, body = parse_speaker_content(get_message_text(message).strip())
bot_nickname = global_config.bot.nickname.strip() or "Bot"
if speaker_name == bot_nickname:
@@ -64,7 +64,7 @@ def _split_user_message_segments(raw_content: str) -> list[tuple[Optional[str],
return segments
def format_chat_history(messages: list[MaiMessage]) -> str:
def format_chat_history(messages: list[SessionMessage]) -> str:
"""Format visible chat history for reply generation."""
bot_nickname = global_config.bot.nickname.strip() or "Bot"
parts: list[str] = []
@@ -109,7 +109,7 @@ class Replyer:
def set_enabled(self, enabled: bool) -> None:
self._enabled = enabled
async def reply(self, reason: str, chat_history: list[MaiMessage]) -> str:
async def reply(self, reason: str, chat_history: list[SessionMessage]) -> str:
if not self._enabled or not reason or self._llm_service is None:
return "..."

View File

@@ -12,7 +12,7 @@ import asyncio
from src.chat.heart_flow.heartFC_utils import CycleDetail
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.mai_message_data_model import GroupInfo, MaiMessage, UserInfo
from src.common.data_models.mai_message_data_model import GroupInfo, UserInfo
from src.common.data_models.message_component_data_model import MessageSequence
from src.common.logger import get_logger
from src.config.config import global_config
@@ -56,7 +56,7 @@ class MaisakaHeartFlowChatting:
session_name = chat_manager.get_session_name(session_id) or session_id
self.log_prefix = f"[{session_name}]"
self._llm_service = MaiSakaLLMService(api_key="", base_url=None, model="")
self._chat_history: list[MaiMessage] = []
self._chat_history: list[SessionMessage] = []
self.history_loop: list[CycleDetail] = []
self.message_cache: list[SessionMessage] = []
self._mcp_manager: Optional[MCPManager] = None
@@ -227,7 +227,7 @@ class MaisakaHeartFlowChatting:
return merged_sequence
async def _build_user_history_message(self, message: SessionMessage) -> Optional[MaiMessage]:
async def _build_user_history_message(self, message: SessionMessage) -> Optional[SessionMessage]:
user_sequence = await self._build_message_sequence(message)
visible_text = build_visible_text_from_sequence(user_sequence).strip()
if not user_sequence.components:
@@ -498,7 +498,7 @@ class MaisakaHeartFlowChatting:
)
return True
def _build_tool_message(self, tool_call: ToolCall, content: str) -> MaiMessage:
def _build_tool_message(self, tool_call: ToolCall, content: str) -> SessionMessage:
return build_message(
role="tool",
content=content,

View File

@@ -11,7 +11,7 @@ import os
from rich.panel import Panel
from src.common.data_models.mai_message_data_model import MaiMessage
from src.chat.message_receive.message import SessionMessage
from src.llm_models.payload_content.tool_option import ToolCall
from .config import console
@@ -41,7 +41,7 @@ class ToolHandlerContext:
self.last_user_input_time: Optional[datetime] = None
async def handle_stop(tc: ToolCall, chat_history: list[MaiMessage]) -> None:
async def handle_stop(tc: ToolCall, chat_history: list[SessionMessage]) -> None:
"""Handle the stop tool."""
console.print("[accent]Calling tool: stop()[/accent]")
chat_history.append(
@@ -49,7 +49,7 @@ async def handle_stop(tc: ToolCall, chat_history: list[MaiMessage]) -> None:
)
async def handle_wait(tc: ToolCall, chat_history: list[MaiMessage], ctx: ToolHandlerContext) -> str:
async def handle_wait(tc: ToolCall, chat_history: list[SessionMessage], ctx: ToolHandlerContext) -> str:
"""Handle the wait tool."""
seconds = (tc.args or {}).get("seconds", 30)
seconds = max(5, min(seconds, 300))
@@ -86,7 +86,7 @@ async def _do_wait(seconds: int, ctx: ToolHandlerContext) -> str:
return f"User input received: {user_input}"
async def handle_mcp_tool(tc: ToolCall, chat_history: list[MaiMessage], mcp_manager: "MCPManager") -> None:
async def handle_mcp_tool(tc: ToolCall, chat_history: list[SessionMessage], mcp_manager: "MCPManager") -> None:
"""Handle an MCP tool call."""
args_str = _json.dumps(tc.args or {}, ensure_ascii=False)
args_preview = args_str if len(args_str) <= 120 else args_str[:120] + "..."
@@ -107,13 +107,13 @@ async def handle_mcp_tool(tc: ToolCall, chat_history: list[MaiMessage], mcp_mana
chat_history.append(build_message(role="tool", content=result, tool_call_id=tc.call_id))
async def handle_unknown_tool(tc: ToolCall, chat_history: list[MaiMessage]) -> None:
async def handle_unknown_tool(tc: ToolCall, chat_history: list[SessionMessage]) -> None:
"""Handle an unknown tool call."""
console.print(f"[accent]Calling unknown tool: {tc.func_name}({tc.args})[/accent]")
chat_history.append(build_message(role="tool", content=f"Unknown tool: {tc.func_name}", tool_call_id=tc.call_id))
async def handle_write_file(tc: ToolCall, chat_history: list[MaiMessage]) -> None:
async def handle_write_file(tc: ToolCall, chat_history: list[SessionMessage]) -> None:
"""Write a file under the local mai_files workspace."""
filename = (tc.args or {}).get("filename", "")
content = (tc.args or {}).get("content", "")
@@ -149,7 +149,7 @@ async def handle_write_file(tc: ToolCall, chat_history: list[MaiMessage]) -> Non
chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id))
async def handle_read_file(tc: ToolCall, chat_history: list[MaiMessage]) -> None:
async def handle_read_file(tc: ToolCall, chat_history: list[SessionMessage]) -> None:
"""Read a file from the local mai_files workspace."""
filename = (tc.args or {}).get("filename", "")
console.print(f'[accent]Calling tool: read_file("{filename}")[/accent]')
@@ -190,7 +190,7 @@ async def handle_read_file(tc: ToolCall, chat_history: list[MaiMessage]) -> None
chat_history.append(build_message(role="tool", content=error_msg, tool_call_id=tc.call_id))
async def handle_list_files(tc: ToolCall, chat_history: list[MaiMessage]) -> None:
async def handle_list_files(tc: ToolCall, chat_history: list[SessionMessage]) -> None:
"""List files under the local mai_files workspace."""
console.print("[accent]Calling tool: list_files()[/accent]")