feat:重构maisaka的消息类型,添加打断功能
This commit is contained in:
@@ -14,9 +14,9 @@ from rich.panel import Panel
|
||||
from rich.pretty import Pretty
|
||||
from rich.text import Text
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.cli.console import console
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, TextComponent
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
@@ -27,12 +27,8 @@ from src.llm_models.payload_content.tool_option import ToolCall, ToolDefinitionI
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from .builtin_tools import get_builtin_tools
|
||||
from .message_adapter import (
|
||||
build_message,
|
||||
format_speaker_content,
|
||||
get_message_role,
|
||||
to_llm_message,
|
||||
)
|
||||
from .context_messages import AssistantMessage, LLMContextMessage, SessionBackedMessage
|
||||
from .message_adapter import format_speaker_content
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -41,7 +37,7 @@ class ChatResponse:
|
||||
|
||||
content: Optional[str]
|
||||
tool_calls: List[ToolCall]
|
||||
raw_message: SessionMessage
|
||||
raw_message: AssistantMessage
|
||||
|
||||
|
||||
logger = get_logger("maisaka_chat_loop")
|
||||
@@ -59,6 +55,7 @@ class MaisakaChatLoopService:
|
||||
self._temperature = temperature
|
||||
self._max_tokens = max_tokens
|
||||
self._extra_tools: List[ToolOption] = []
|
||||
self._interrupt_flag: asyncio.Event | None = None
|
||||
self._prompts_loaded = False
|
||||
self._prompt_load_lock = asyncio.Lock()
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
@@ -117,18 +114,21 @@ class MaisakaChatLoopService:
|
||||
def set_extra_tools(self, tools: List[ToolDefinitionInput]) -> None:
|
||||
self._extra_tools = normalize_tool_options(tools) or []
|
||||
|
||||
def set_interrupt_flag(self, interrupt_flag: asyncio.Event | None) -> None:
|
||||
"""设置当前 planner 请求使用的中断标记。"""
|
||||
self._interrupt_flag = interrupt_flag
|
||||
|
||||
async def analyze_knowledge_need(
|
||||
self,
|
||||
chat_history: List[SessionMessage],
|
||||
chat_history: List[LLMContextMessage],
|
||||
categories_summary: str,
|
||||
) -> List[str]:
|
||||
"""分析当前对话是否需要检索知识库分类。"""
|
||||
visible_history: List[str] = []
|
||||
for message in chat_history[-8:]:
|
||||
if not message.content:
|
||||
if not message.processed_plain_text:
|
||||
continue
|
||||
role = getattr(message, "role", "")
|
||||
visible_history.append(f"{role}: {message.content}")
|
||||
visible_history.append(f"{message.role}: {message.processed_plain_text}")
|
||||
|
||||
if not visible_history or not categories_summary.strip():
|
||||
return []
|
||||
@@ -302,7 +302,7 @@ class MaisakaChatLoopService:
|
||||
padding=(0, 1),
|
||||
)
|
||||
|
||||
async def chat_loop_step(self, chat_history: List[SessionMessage]) -> ChatResponse:
|
||||
async def chat_loop_step(self, chat_history: List[LLMContextMessage]) -> ChatResponse:
|
||||
await self.ensure_chat_prompt_loaded()
|
||||
selected_history, selection_reason = self._select_llm_context_messages(chat_history)
|
||||
|
||||
@@ -313,7 +313,7 @@ class MaisakaChatLoopService:
|
||||
messages.append(system_msg.build())
|
||||
|
||||
for msg in selected_history:
|
||||
llm_message = to_llm_message(msg)
|
||||
llm_message = msg.to_llm_message()
|
||||
if llm_message is not None:
|
||||
messages.append(llm_message)
|
||||
|
||||
@@ -342,15 +342,24 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
request_started_at = perf_counter()
|
||||
logger.info(
|
||||
"planner 请求开始: "
|
||||
f"selected_history={len(selected_history)} "
|
||||
f"llm_messages={len(built_messages)} "
|
||||
f"tool_count={len(all_tools)} "
|
||||
f"interrupt_enabled={self._interrupt_flag is not None}"
|
||||
)
|
||||
generation_result = await self._llm_chat.generate_response_with_messages(
|
||||
message_factory=message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
tool_options=all_tools if all_tools else None,
|
||||
temperature=self._temperature,
|
||||
max_tokens=self._max_tokens,
|
||||
interrupt_flag=self._interrupt_flag,
|
||||
),
|
||||
)
|
||||
_ = perf_counter() - request_started_at
|
||||
request_elapsed = perf_counter() - request_started_at
|
||||
logger.info(f"planner 请求完成,elapsed={request_elapsed:.3f}s")
|
||||
|
||||
tool_call_summaries = [
|
||||
{
|
||||
@@ -365,11 +374,10 @@ class MaisakaChatLoopService:
|
||||
f"tool_calls={tool_call_summaries}"
|
||||
)
|
||||
|
||||
raw_message = build_message(
|
||||
role=RoleType.Assistant.value,
|
||||
raw_message = AssistantMessage(
|
||||
content=generation_result.response or "",
|
||||
source="assistant",
|
||||
tool_calls=generation_result.tool_calls or None,
|
||||
timestamp=datetime.now(),
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
)
|
||||
return ChatResponse(
|
||||
content=generation_result.response,
|
||||
@@ -378,20 +386,19 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _select_llm_context_messages(chat_history: List[SessionMessage]) -> tuple[List[SessionMessage], str]:
|
||||
def _select_llm_context_messages(chat_history: List[LLMContextMessage]) -> tuple[List[LLMContextMessage], str]:
|
||||
"""选择真正发送给 LLM 的上下文消息。"""
|
||||
max_context_size = max(1, int(global_config.chat.max_context_size))
|
||||
counted_roles = {"user", "assistant"}
|
||||
selected_indices: List[int] = []
|
||||
counted_message_count = 0
|
||||
|
||||
for index in range(len(chat_history) - 1, -1, -1):
|
||||
message = chat_history[index]
|
||||
if to_llm_message(message) is None:
|
||||
if message.to_llm_message() is None:
|
||||
continue
|
||||
|
||||
selected_indices.append(index)
|
||||
if get_message_role(message) in counted_roles:
|
||||
if message.count_in_context:
|
||||
counted_message_count += 1
|
||||
if counted_message_count >= max_context_size:
|
||||
break
|
||||
@@ -410,15 +417,25 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_chat_context(user_text: str) -> List[SessionMessage]:
|
||||
def build_chat_context(user_text: str) -> List[LLMContextMessage]:
|
||||
timestamp = datetime.now()
|
||||
visible_text = format_speaker_content(
|
||||
global_config.maisaka.user_name.strip() or "用户",
|
||||
user_text,
|
||||
timestamp,
|
||||
)
|
||||
planner_prefix = (
|
||||
f"[时间]{timestamp.strftime('%H:%M:%S')}\n"
|
||||
f"[用户]{global_config.maisaka.user_name.strip() or '用户'}\n"
|
||||
"[用户群昵称]\n"
|
||||
"[msg_id]\n"
|
||||
"[发言内容]"
|
||||
)
|
||||
return [
|
||||
build_message(
|
||||
role=RoleType.User.value,
|
||||
content=format_speaker_content(
|
||||
global_config.maisaka.user_name.strip() or "用户",
|
||||
user_text,
|
||||
datetime.now(),
|
||||
),
|
||||
source="user",
|
||||
SessionBackedMessage(
|
||||
raw_message=MessageSequence([TextComponent(f"{planner_prefix}{user_text}")]),
|
||||
visible_text=visible_text,
|
||||
timestamp=timestamp,
|
||||
source_kind="user",
|
||||
)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user