Merge branch 'Mai-with-u:dev' into dev
This commit is contained in:
@@ -36,9 +36,9 @@ def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
|
||||
# 默认不再强制虚拟群聊;WebUI 默认走私聊频道,需要的话由调用者传入虚拟群 ID。
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, None)
|
||||
except ImportError:
|
||||
_webui_chat_broadcaster = (None, None, None)
|
||||
return _webui_chat_broadcaster
|
||||
@@ -98,6 +98,14 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
# 私聊场景下出站消息的 user_info 是机器人自己的身份,
|
||||
# 真正的接收者用户 ID 由 send_service 写入 ``platform_io_target_user_id``。
|
||||
target_user_id = ""
|
||||
additional_config = message.message_info.additional_config or {}
|
||||
raw_target_user_id = additional_config.get("platform_io_target_user_id")
|
||||
if raw_target_user_id:
|
||||
target_user_id = str(raw_target_user_id).strip()
|
||||
|
||||
await chat_manager.broadcast_to_group(
|
||||
group_id=group_id or default_group_id or "",
|
||||
message={
|
||||
@@ -113,6 +121,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"is_bot": True,
|
||||
},
|
||||
},
|
||||
user_id=target_user_id,
|
||||
)
|
||||
|
||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||
|
||||
@@ -83,6 +83,7 @@ class BaseMaisakaReplyGenerator:
|
||||
self.express_model = llm_client_cls(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
session_id=getattr(chat_stream, "session_id", "") if chat_stream is not None else "",
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.9.19"
|
||||
CONFIG_VERSION: str = "8.9.20"
|
||||
MODEL_CONFIG_VERSION: str = "1.14.3"
|
||||
|
||||
logger = get_logger("config")
|
||||
|
||||
@@ -1324,6 +1324,15 @@ class DebugConfig(ConfigBase):
|
||||
)
|
||||
"""是否记录 Replyer 请求体,默认关闭"""
|
||||
|
||||
enable_llm_cache_stats: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "chart-no-axes-column",
|
||||
},
|
||||
)
|
||||
"""是否记录 LLM prompt cache 调试统计,默认关闭"""
|
||||
|
||||
|
||||
class ExtraPromptItem(ConfigBase):
|
||||
platform: str = Field(
|
||||
|
||||
@@ -23,7 +23,7 @@ from .expression_utils import is_single_char_jargon
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
llm_inference = LLMServiceClient(task_name="replyer", request_type="jargon.inference")
|
||||
llm_inference = LLMServiceClient(task_name="utils", request_type="jargon.inference")
|
||||
|
||||
|
||||
class JargonEntry(TypedDict):
|
||||
|
||||
@@ -41,6 +41,11 @@ from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .visual_mode_utils import resolve_enable_visual_planner
|
||||
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
REQUEST_TYPE_BY_REQUEST_KIND = {
|
||||
"planner": "maisaka_planner",
|
||||
"timing_gate": "maisaka_timing_gate",
|
||||
}
|
||||
CONTEXT_SELECTION_CACHE_STABILITY_RATIO = 2.0
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -212,7 +217,7 @@ class MaisakaChatLoopService:
|
||||
self._chat_system_prompt = f"{self._personality_prompt}\n\nYou are a helpful AI assistant."
|
||||
else:
|
||||
self._chat_system_prompt = chat_system_prompt
|
||||
self._llm_chat = LLMServiceClient(task_name="planner", request_type="maisaka_planner")
|
||||
self._llm_chat_clients: dict[str, LLMServiceClient] = {}
|
||||
|
||||
@property
|
||||
def personality_prompt(self) -> str:
|
||||
@@ -220,6 +225,30 @@ class MaisakaChatLoopService:
|
||||
|
||||
return self._personality_prompt
|
||||
|
||||
@staticmethod
|
||||
def _resolve_llm_request_type(request_kind: str) -> str:
|
||||
"""根据 Maisaka 请求类型解析 LLM 统计口径。"""
|
||||
|
||||
normalized_request_kind = str(request_kind or "").strip()
|
||||
return REQUEST_TYPE_BY_REQUEST_KIND.get(
|
||||
normalized_request_kind,
|
||||
f"maisaka_{normalized_request_kind}" if normalized_request_kind else "maisaka_planner",
|
||||
)
|
||||
|
||||
def _get_llm_chat_client(self, request_kind: str) -> LLMServiceClient:
|
||||
"""获取当前请求类型对应的 planner LLM 客户端。"""
|
||||
|
||||
request_type = self._resolve_llm_request_type(request_kind)
|
||||
llm_client = self._llm_chat_clients.get(request_type)
|
||||
if llm_client is None:
|
||||
llm_client = LLMServiceClient(
|
||||
task_name="planner",
|
||||
request_type=request_type,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
self._llm_chat_clients[request_type] = llm_client
|
||||
return llm_client
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
@@ -321,7 +350,13 @@ class MaisakaChatLoopService:
|
||||
|
||||
@staticmethod
|
||||
def _build_time_block() -> str:
|
||||
"""构建当前时间提示块。"""
|
||||
"""构建静态时间提示块。"""
|
||||
|
||||
return "当前时间会在每次请求末尾以用户消息形式提供。"
|
||||
|
||||
@staticmethod
|
||||
def _build_current_time_user_message() -> str:
|
||||
"""构建追加到请求末尾的当前时间消息。"""
|
||||
|
||||
return f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
@@ -446,7 +481,11 @@ class MaisakaChatLoopService:
|
||||
messages.append(llm_message)
|
||||
|
||||
normalized_injected_messages: List[Message] = []
|
||||
for injected_message in injected_user_messages or []:
|
||||
final_user_messages = [
|
||||
*(injected_user_messages or []),
|
||||
self._build_current_time_user_message(),
|
||||
]
|
||||
for injected_message in final_user_messages:
|
||||
normalized_message = str(injected_message or "").strip()
|
||||
if not normalized_message:
|
||||
continue
|
||||
@@ -458,31 +497,10 @@ class MaisakaChatLoopService:
|
||||
)
|
||||
|
||||
if normalized_injected_messages:
|
||||
insertion_index = self._resolve_injected_user_messages_insertion_index(messages)
|
||||
messages[insertion_index:insertion_index] = normalized_injected_messages
|
||||
messages.extend(normalized_injected_messages)
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _resolve_injected_user_messages_insertion_index(messages: Sequence[Message]) -> int:
|
||||
"""计算 injected meta user messages 在请求中的插入位置。
|
||||
|
||||
规则与 deferred attachment 更接近:
|
||||
- 从尾部向前寻找最近的 stopping point;
|
||||
- stopping point 为 assistant 消息或 tool 结果消息;
|
||||
- 找到后插入到其后面;
|
||||
- 若不存在 stopping point,则退回到 system 消息之后。
|
||||
"""
|
||||
|
||||
for index in range(len(messages) - 1, -1, -1):
|
||||
message = messages[index]
|
||||
if message.role in {RoleType.Assistant, RoleType.Tool}:
|
||||
return index + 1
|
||||
|
||||
if messages and messages[0].role == RoleType.System:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
async def chat_loop_step(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
@@ -575,7 +593,8 @@ class MaisakaChatLoopService:
|
||||
tool_definitions=list(all_tools),
|
||||
)
|
||||
|
||||
generation_result = await self._llm_chat.generate_response_with_messages(
|
||||
llm_chat = self._get_llm_chat_client(request_kind)
|
||||
generation_result = await llm_chat.generate_response_with_messages(
|
||||
message_factory=message_factory,
|
||||
options=LLMGenerationOptions(
|
||||
tool_options=all_tools if all_tools else None,
|
||||
@@ -654,7 +673,11 @@ class MaisakaChatLoopService:
|
||||
chat_history,
|
||||
request_kind=request_kind,
|
||||
)
|
||||
effective_context_size = max(1, int(max_context_size or global_config.chat.max_context_size))
|
||||
base_context_size = max(1, int(max_context_size or global_config.chat.max_context_size))
|
||||
effective_context_size = max(
|
||||
base_context_size,
|
||||
int(base_context_size * CONTEXT_SELECTION_CACHE_STABILITY_RATIO),
|
||||
)
|
||||
selected_indices: List[int] = []
|
||||
counted_message_count = 0
|
||||
|
||||
@@ -690,9 +713,11 @@ class MaisakaChatLoopService:
|
||||
selected_history, _ = normalize_tool_result_order(selected_history)
|
||||
tool_message_count = sum(1 for message in selected_history if isinstance(message, ToolResultMessage))
|
||||
normal_message_count = len(selected_history) - tool_message_count
|
||||
stability_text = f"|cache_window {base_context_size}->{effective_context_size}"
|
||||
selection_reason = (
|
||||
f"实际发送 {len(selected_history)} 条消息"
|
||||
f"|消息 {normal_message_count} 条|tool {tool_message_count} 条"
|
||||
f"{stability_text}"
|
||||
)
|
||||
return (
|
||||
selected_history,
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
|
||||
from .context_messages import AssistantMessage, LLMContextMessage
|
||||
from .context_messages import LLMContextMessage
|
||||
from .history_utils import drop_leading_orphan_tool_results, drop_orphan_tool_results, normalize_tool_result_order
|
||||
|
||||
EARLY_TRIM_RATIO = 0.3
|
||||
TRIM_THRESHOLD_RATIO = 1.2
|
||||
TRIM_TARGET_RATIO = 1.0
|
||||
TRIM_THRESHOLD_RATIO = 2.0
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -36,21 +36,16 @@ def process_chat_history_after_cycle(
|
||||
compact_removed_count = 0
|
||||
trim_threshold = ceil(max_context_size * TRIM_THRESHOLD_RATIO)
|
||||
if remaining_context_count > trim_threshold:
|
||||
removed_early_message_count = _remove_early_history_messages(processed_history)
|
||||
processed_history, removed_after_message_trim_count, moved_after_message_trim_count = (
|
||||
_normalize_history_structure(processed_history)
|
||||
target_context_count = max(1, int(max_context_size * TRIM_TARGET_RATIO))
|
||||
removed_early_message_count = _trim_history_to_context_target(
|
||||
processed_history,
|
||||
target_context_count=target_context_count,
|
||||
)
|
||||
removed_assistant_thought_count = _remove_early_assistant_thoughts(processed_history)
|
||||
processed_history, removed_after_thought_trim_count, moved_after_thought_trim_count = (
|
||||
_normalize_history_structure(processed_history)
|
||||
processed_history, removed_after_trim_count, moved_after_trim_count = _normalize_history_structure(
|
||||
processed_history
|
||||
)
|
||||
compact_removed_count = (
|
||||
removed_early_message_count
|
||||
+ removed_after_message_trim_count
|
||||
+ removed_assistant_thought_count
|
||||
+ removed_after_thought_trim_count
|
||||
)
|
||||
moved_tool_result_count += moved_after_message_trim_count + moved_after_thought_trim_count
|
||||
compact_removed_count = removed_early_message_count + removed_after_trim_count
|
||||
moved_tool_result_count += moved_after_trim_count
|
||||
|
||||
remaining_context_count = sum(1 for message in processed_history if message.count_in_context)
|
||||
removed_count = normalized_removed_count + compact_removed_count
|
||||
@@ -78,42 +73,27 @@ def _normalize_history_structure(
|
||||
)
|
||||
|
||||
|
||||
def _remove_early_history_messages(chat_history: list[LLMContextMessage]) -> int:
|
||||
"""移除最早 30% 的全部历史消息。"""
|
||||
def _trim_history_to_context_target(
|
||||
chat_history: list[LLMContextMessage],
|
||||
*,
|
||||
target_context_count: int,
|
||||
) -> int:
|
||||
"""移除最早的一段历史,直到普通上下文消息数量降到目标值以内。"""
|
||||
|
||||
remaining_context_count = sum(1 for message in chat_history if message.count_in_context)
|
||||
if remaining_context_count <= target_context_count:
|
||||
return 0
|
||||
|
||||
remove_count = 0
|
||||
for message in chat_history:
|
||||
remove_count += 1
|
||||
if message.count_in_context:
|
||||
remaining_context_count -= 1
|
||||
if remaining_context_count <= target_context_count:
|
||||
break
|
||||
|
||||
remove_count = int(len(chat_history) * EARLY_TRIM_RATIO)
|
||||
if remove_count <= 0:
|
||||
return 0
|
||||
|
||||
del chat_history[:remove_count]
|
||||
return remove_count
|
||||
|
||||
|
||||
def _remove_early_assistant_thoughts(chat_history: list[LLMContextMessage]) -> int:
|
||||
"""移除最早 30% 的非工具 assistant 思考内容。"""
|
||||
|
||||
candidate_indexes = [
|
||||
index
|
||||
for index, message in enumerate(chat_history)
|
||||
if isinstance(message, AssistantMessage)
|
||||
and not message.tool_calls
|
||||
and message.source_kind != "perception"
|
||||
and bool(message.content.strip())
|
||||
]
|
||||
remove_count = int(len(candidate_indexes) * EARLY_TRIM_RATIO)
|
||||
if remove_count <= 0:
|
||||
return 0
|
||||
|
||||
removed_indexes = set(candidate_indexes[:remove_count])
|
||||
filtered_history: list[LLMContextMessage] = []
|
||||
removed_total = 0
|
||||
for index, message in enumerate(chat_history):
|
||||
if index in removed_indexes:
|
||||
removed_total += 1
|
||||
continue
|
||||
filtered_history.append(message)
|
||||
|
||||
chat_history[:] = filtered_history
|
||||
return removed_total
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = get_logger("maisaka_reasoning_engine")
|
||||
|
||||
TIMING_GATE_CONTEXT_LIMIT = 24
|
||||
TIMING_GATE_CONTEXT_DROP_HEAD_RATIO = 0.7
|
||||
TIMING_GATE_MAX_TOKENS = 384
|
||||
TIMING_GATE_MAX_ATTEMPTS = 3
|
||||
TIMING_GATE_TOOL_NAMES = {"continue", "no_reply", "wait"}
|
||||
@@ -124,7 +124,6 @@ class MaisakaReasoningEngine:
|
||||
async def _run_timing_gate_sub_agent(
|
||||
self,
|
||||
*,
|
||||
context_message_limit: int,
|
||||
system_prompt: str,
|
||||
tool_definitions: list[dict[str, Any]],
|
||||
) -> Any:
|
||||
@@ -134,7 +133,10 @@ class MaisakaReasoningEngine:
|
||||
"""
|
||||
|
||||
return await self._runtime.run_sub_agent(
|
||||
context_message_limit=context_message_limit,
|
||||
context_message_limit=self._runtime._max_context_size,
|
||||
drop_head_context_count=int(
|
||||
self._runtime._max_context_size * TIMING_GATE_CONTEXT_DROP_HEAD_RATIO,
|
||||
),
|
||||
system_prompt=system_prompt,
|
||||
request_kind="timing_gate",
|
||||
interrupt_flag=None,
|
||||
@@ -255,7 +257,6 @@ class MaisakaReasoningEngine:
|
||||
invalid_tool_text = ""
|
||||
for attempt_index in range(TIMING_GATE_MAX_ATTEMPTS):
|
||||
response = await self._run_timing_gate_sub_agent(
|
||||
context_message_limit=TIMING_GATE_CONTEXT_LIMIT,
|
||||
system_prompt=self._build_timing_gate_system_prompt(),
|
||||
tool_definitions=get_timing_tools(),
|
||||
)
|
||||
|
||||
@@ -45,6 +45,7 @@ from .context_messages import (
|
||||
from .display.display_utils import build_tool_call_summary_lines, format_token_count
|
||||
from .display.prompt_cli_renderer import PromptCLIVisualizer
|
||||
from .display.stage_status_board import remove_stage_status, update_stage_status
|
||||
from .history_utils import drop_leading_orphan_tool_results
|
||||
from .reasoning_engine import MaisakaReasoningEngine
|
||||
from .reply_effect import ReplyEffectTracker
|
||||
from .reply_effect.image_utils import extract_visual_attachments_from_sequence
|
||||
@@ -583,6 +584,7 @@ class MaisakaHeartFlowChatting:
|
||||
self,
|
||||
*,
|
||||
context_message_limit: int,
|
||||
drop_head_context_count: int = 0,
|
||||
system_prompt: str,
|
||||
request_kind: str = "sub_agent",
|
||||
extra_messages: Optional[Sequence[LLMContextMessage]] = None,
|
||||
@@ -598,7 +600,10 @@ class MaisakaHeartFlowChatting:
|
||||
request_kind=request_kind,
|
||||
max_context_size=context_message_limit,
|
||||
)
|
||||
sub_agent_history = list(selected_history)
|
||||
sub_agent_history = self._drop_head_context_messages(
|
||||
selected_history,
|
||||
drop_head_context_count,
|
||||
)
|
||||
if extra_messages:
|
||||
sub_agent_history.extend(list(extra_messages))
|
||||
|
||||
@@ -616,6 +621,31 @@ class MaisakaHeartFlowChatting:
|
||||
tool_definitions=[] if tool_definitions is None else tool_definitions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _drop_head_context_messages(
|
||||
chat_history: Sequence[LLMContextMessage],
|
||||
drop_context_count: int,
|
||||
) -> list[LLMContextMessage]:
|
||||
"""从已选上下文头部丢弃指定数量的普通上下文消息。"""
|
||||
|
||||
if drop_context_count <= 0:
|
||||
return list(chat_history)
|
||||
|
||||
first_kept_index = 0
|
||||
dropped_context_count = 0
|
||||
while (
|
||||
first_kept_index < len(chat_history)
|
||||
and dropped_context_count < drop_context_count
|
||||
):
|
||||
message = chat_history[first_kept_index]
|
||||
if message.count_in_context:
|
||||
dropped_context_count += 1
|
||||
first_kept_index += 1
|
||||
|
||||
trimmed_history = list(chat_history[first_kept_index:])
|
||||
trimmed_history, _ = drop_leading_orphan_tool_results(trimmed_history)
|
||||
return trimmed_history
|
||||
|
||||
async def _run_reply_effect_judge(self, prompt: str) -> str:
|
||||
"""运行回复效果观察器使用的临时 LLM 评审。"""
|
||||
|
||||
|
||||
1520
src/services/llm_cache_stats.py
Normal file
1520
src/services/llm_cache_stats.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,8 @@
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
|
||||
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||||
@@ -26,6 +28,7 @@ from src.llm_models.payload_content.message import Message, MessageBuilder, Role
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMOrchestrator
|
||||
from src.services.embedding_service import EmbeddingServiceClient
|
||||
from src.services.llm_cache_stats import record_llm_cache_usage
|
||||
from src.services.service_task_resolver import (
|
||||
get_available_models as _get_available_models,
|
||||
resolve_task_name as _resolve_task_name,
|
||||
@@ -46,7 +49,7 @@ class LLMServiceClient:
|
||||
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`)
|
||||
"""
|
||||
|
||||
def __init__(self, task_name: str, request_type: str = "") -> None:
|
||||
def __init__(self, task_name: str, request_type: str = "", session_id: str = "") -> None:
|
||||
"""初始化 LLM 服务门面。
|
||||
|
||||
Args:
|
||||
@@ -55,6 +58,7 @@ class LLMServiceClient:
|
||||
"""
|
||||
self.task_name = _resolve_task_name(task_name)
|
||||
self.request_type = request_type
|
||||
self.session_id = str(session_id or "").strip()
|
||||
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
||||
|
||||
@staticmethod
|
||||
@@ -85,6 +89,70 @@ class LLMServiceClient:
|
||||
return LLMImageOptions()
|
||||
return options
|
||||
|
||||
@staticmethod
|
||||
def _serialize_message_for_cache_stats(message: Message) -> Dict[str, Any]:
|
||||
parts: list[dict[str, Any]] = []
|
||||
for part in message.parts:
|
||||
if hasattr(part, "text"):
|
||||
parts.append({"type": "text", "text": part.text})
|
||||
continue
|
||||
|
||||
image_base64 = getattr(part, "image_base64", "")
|
||||
image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
|
||||
parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"format": getattr(part, "image_format", ""),
|
||||
"size": len(image_base64),
|
||||
"sha256": image_digest,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"role": str(message.role.value if hasattr(message.role, "value") else message.role),
|
||||
"parts": parts,
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"tool_name": message.tool_name,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.call_id,
|
||||
"name": tool_call.func_name,
|
||||
"arguments": tool_call.args,
|
||||
"extra_content": tool_call.extra_content,
|
||||
}
|
||||
for tool_call in (message.tool_calls or [])
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_cache_stats_prompt_text(
|
||||
cls,
|
||||
*,
|
||||
messages: List[Message],
|
||||
tool_options: Any,
|
||||
response_format: Any,
|
||||
) -> str:
|
||||
payload = {
|
||||
"messages": [cls._serialize_message_for_cache_stats(message) for message in messages],
|
||||
"tool_options": tool_options or [],
|
||||
"response_format": response_format,
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
def _record_cache_stats(self, result: LLMResponseResult, prompt_text: str | None = None) -> None:
|
||||
"""记录当前调用的 prompt cache 统计。"""
|
||||
|
||||
record_llm_cache_usage(
|
||||
task_name=self.task_name,
|
||||
request_type=self.request_type,
|
||||
model_name=result.model_name,
|
||||
session_id=self.session_id,
|
||||
prompt_tokens=result.prompt_tokens,
|
||||
prompt_cache_hit_tokens=result.prompt_cache_hit_tokens,
|
||||
prompt_cache_miss_tokens=result.prompt_cache_miss_tokens,
|
||||
prompt_text=prompt_text,
|
||||
)
|
||||
|
||||
async def generate_response(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -100,7 +168,12 @@ class LLMServiceClient:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_generation_options(options)
|
||||
return await self._orchestrator.generate_response_async(
|
||||
prompt_text = self._build_cache_stats_prompt_text(
|
||||
messages=[MessageBuilder().add_text_content(prompt).build()],
|
||||
tool_options=active_options.tool_options,
|
||||
response_format=active_options.response_format,
|
||||
)
|
||||
result = await self._orchestrator.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
@@ -109,6 +182,8 @@ class LLMServiceClient:
|
||||
raise_when_empty=active_options.raise_when_empty,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
self._record_cache_stats(result, prompt_text=prompt_text)
|
||||
return result
|
||||
|
||||
async def generate_response_with_messages(
|
||||
self,
|
||||
@@ -125,8 +200,22 @@ class LLMServiceClient:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_generation_options(options)
|
||||
return await self._orchestrator.generate_response_with_message_async(
|
||||
message_factory=message_factory,
|
||||
prompt_text_holder: dict[str, str] = {}
|
||||
|
||||
def cache_stats_message_factory(client: BaseClient, model_info: Any = None) -> List[Message]:
|
||||
if len(inspect.signature(message_factory).parameters) >= 2:
|
||||
messages = message_factory(client, model_info)
|
||||
else:
|
||||
messages = message_factory(client)
|
||||
prompt_text_holder["prompt_text"] = self._build_cache_stats_prompt_text(
|
||||
messages=messages,
|
||||
tool_options=active_options.tool_options,
|
||||
response_format=active_options.response_format,
|
||||
)
|
||||
return messages
|
||||
|
||||
result = await self._orchestrator.generate_response_with_message_async(
|
||||
message_factory=cache_stats_message_factory,
|
||||
temperature=active_options.temperature,
|
||||
max_tokens=active_options.max_tokens,
|
||||
tools=active_options.tool_options,
|
||||
@@ -134,6 +223,8 @@ class LLMServiceClient:
|
||||
raise_when_empty=active_options.raise_when_empty,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
self._record_cache_stats(result, prompt_text=prompt_text_holder.get("prompt_text"))
|
||||
return result
|
||||
|
||||
async def generate_response_for_image(
|
||||
self,
|
||||
@@ -154,7 +245,30 @@ class LLMServiceClient:
|
||||
LLMResponseResult: 统一文本生成结果。
|
||||
"""
|
||||
active_options = self._normalize_image_options(options)
|
||||
return await self._orchestrator.generate_response_for_image(
|
||||
image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
|
||||
prompt_text = json.dumps(
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"type": "text", "text": prompt},
|
||||
{
|
||||
"type": "image",
|
||||
"format": image_format,
|
||||
"size": len(image_base64),
|
||||
"sha256": image_digest,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
"tool_options": [],
|
||||
"response_format": None,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
sort_keys=True,
|
||||
)
|
||||
result = await self._orchestrator.generate_response_for_image(
|
||||
prompt=prompt,
|
||||
image_base64=image_base64,
|
||||
image_format=image_format,
|
||||
@@ -162,6 +276,8 @@ class LLMServiceClient:
|
||||
max_tokens=active_options.max_tokens,
|
||||
interrupt_flag=active_options.interrupt_flag,
|
||||
)
|
||||
self._record_cache_stats(result, prompt_text=prompt_text)
|
||||
return result
|
||||
|
||||
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
|
||||
"""执行音频转写请求。
|
||||
|
||||
@@ -70,11 +70,15 @@ class ConfigSchemaGenerator:
|
||||
) -> Dict[str, Any]:
|
||||
field_docs = config_class.get_class_field_docs()
|
||||
field_type = cls._map_field_type(annotation)
|
||||
raw_description = field_docs.get(field_name, field_info.description or "")
|
||||
# `_wrap_` 标记在配置类 docstring 中表示该说明应作为块级注释(独立成行)
|
||||
# 在前端展示时把它转为换行符,使描述以新行起始或在中间换行
|
||||
description = raw_description.replace("_wrap_", "\n").strip("\n")
|
||||
schema: Dict[str, Any] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"label": field_name,
|
||||
"description": field_docs.get(field_name, field_info.description or ""),
|
||||
"description": description,
|
||||
"required": field_info.is_required(),
|
||||
}
|
||||
|
||||
|
||||
@@ -13,10 +13,10 @@ from src.config.config import global_config
|
||||
from src.webui.dependencies import require_auth
|
||||
|
||||
from .service import (
|
||||
WEBUI_CHAT_GROUP_ID,
|
||||
WEBUI_CHAT_PLATFORM,
|
||||
chat_history,
|
||||
chat_manager,
|
||||
normalize_webui_user_id,
|
||||
)
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
@@ -30,10 +30,15 @@ async def get_chat_history(
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> Dict[str, object]:
|
||||
"""获取聊天历史记录。"""
|
||||
del user_id
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
"""获取聊天历史记录。
|
||||
|
||||
优先按 ``group_id`` 加载虚拟群聊历史;未提供时使用规范化后的 ``user_id`` 加载 WebUI 私聊历史。
|
||||
"""
|
||||
if group_id:
|
||||
history = chat_history.get_history(limit, group_id=group_id)
|
||||
else:
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
history = chat_history.get_history(limit, user_id=normalized_user_id)
|
||||
return {"success": True, "messages": history, "total": len(history)}
|
||||
|
||||
|
||||
@@ -100,10 +105,18 @@ async def get_persons_by_platform(
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None),
|
||||
) -> Dict[str, object]:
|
||||
"""清空聊天历史记录。"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
"""清空聊天历史记录。
|
||||
|
||||
优先按 ``group_id`` 清理虚拟群聊历史;未提供时使用规范化后的 ``user_id`` 清理 WebUI 私聊历史。
|
||||
"""
|
||||
if group_id:
|
||||
deleted = chat_history.clear_history(group_id=group_id)
|
||||
else:
|
||||
normalized_user_id = normalize_webui_user_id(user_id)
|
||||
deleted = chat_history.clear_history(user_id=normalized_user_id)
|
||||
return {"success": True, "message": f"已清空 {deleted} 条聊天记录"}
|
||||
|
||||
|
||||
@@ -113,6 +126,5 @@ async def get_chat_info() -> Dict[str, object]:
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ from src.common.message_repository import find_messages
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
|
||||
from .serializers import serialize_message_sequence
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
@@ -61,7 +63,7 @@ class ChatSessionConnection:
|
||||
client_session_id: str
|
||||
user_id: str
|
||||
user_name: str
|
||||
active_group_id: str
|
||||
channel_key: str
|
||||
virtual_config: Optional[VirtualIdentityConfig]
|
||||
sender: AsyncMessageSender
|
||||
|
||||
@@ -92,6 +94,21 @@ class ChatHistoryManager:
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
|
||||
# 将存库中的 raw_message 序列化为前端可识别的富文本消息段,
|
||||
# 避免“刚刚收到的机器人回复是富文本,刷新后变成纯文本”的体验不一致。
|
||||
segments: List[Dict[str, Any]] = []
|
||||
try:
|
||||
raw_message = getattr(msg, "raw_message", None)
|
||||
if raw_message is not None and getattr(raw_message, "components", None):
|
||||
segments = serialize_message_sequence(raw_message)
|
||||
except Exception as exc: # 仅记录警告,退化为纯文本
|
||||
logger.debug(f"序列化历史消息段失败,退化为纯文本: {exc}")
|
||||
segments = []
|
||||
|
||||
is_rich = bool(segments) and not (
|
||||
len(segments) == 1 and segments[0].get("type") == "text"
|
||||
)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
@@ -100,32 +117,119 @@ class ChatHistoryManager:
|
||||
"sender_name": user_info.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
"message_type": "rich" if is_rich else "text",
|
||||
"segments": segments if is_rich else None,
|
||||
}
|
||||
|
||||
def _resolve_session_id(self, group_id: Optional[str]) -> str:
|
||||
"""根据群组标识解析聊天会话 ID。
|
||||
def _enrich_reply_segments(
|
||||
self,
|
||||
segments: List[Dict[str, Any]],
|
||||
message_index: Dict[str, SessionMessage],
|
||||
session_id: Optional[str],
|
||||
) -> None:
|
||||
"""回填历史消息中 reply 段缺失的发送者/原内容字段。
|
||||
|
||||
DB 中持久化的 ReplyComponent 通常只保留了 ``target_message_id``,
|
||||
``target_message_content`` / ``target_message_sender_*`` 字段为空。
|
||||
这里基于当前会话已加载的消息列表(必要时回查数据库)进行补全。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
segments: 单条历史消息的消息段列表,原地修改。
|
||||
message_index: 当前会话已加载消息的 ``message_id -> SessionMessage`` 索引。
|
||||
session_id: 当前会话 ID,用于按 ID 单查时缩小范围。
|
||||
"""
|
||||
for segment in segments:
|
||||
if not isinstance(segment, dict) or segment.get("type") != "reply":
|
||||
continue
|
||||
data = segment.get("data")
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
target_message_id = data.get("target_message_id")
|
||||
if not target_message_id:
|
||||
continue
|
||||
|
||||
has_content = bool(str(data.get("target_message_content") or "").strip())
|
||||
has_sender = any(
|
||||
str(data.get(key) or "").strip()
|
||||
for key in (
|
||||
"target_message_sender_id",
|
||||
"target_message_sender_nickname",
|
||||
"target_message_sender_cardname",
|
||||
)
|
||||
)
|
||||
if has_content and has_sender:
|
||||
continue
|
||||
|
||||
target_msg = message_index.get(str(target_message_id))
|
||||
if target_msg is None:
|
||||
# 退化为按 ID 单查(仅当不在当前窗口内时才付出 DB 代价)
|
||||
try:
|
||||
from src.services.message_service import get_message_by_id
|
||||
|
||||
target_msg = get_message_by_id(str(target_message_id), session_id or None)
|
||||
except Exception as exc:
|
||||
logger.debug(f"按 ID 回查 reply 目标消息失败: {exc}")
|
||||
target_msg = None
|
||||
if target_msg is None:
|
||||
continue
|
||||
|
||||
user_info = target_msg.message_info.user_info
|
||||
if not has_content:
|
||||
content_text = (
|
||||
target_msg.processed_plain_text
|
||||
or target_msg.display_message
|
||||
or ""
|
||||
)
|
||||
data["target_message_content"] = content_text
|
||||
if not has_sender:
|
||||
data["target_message_sender_id"] = user_info.user_id or ""
|
||||
data["target_message_sender_nickname"] = user_info.user_nickname or ""
|
||||
data["target_message_sender_cardname"] = (
|
||||
getattr(user_info, "user_cardname", "") or ""
|
||||
)
|
||||
|
||||
def _resolve_session_id(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""根据会话标识解析内部聊天会话 ID。
|
||||
|
||||
优先按虚拟群聊解析;否则按 WebUI 私聊解析。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识(虚拟群聊模式)。
|
||||
user_id: 用户标识(私聊模式)。
|
||||
|
||||
Returns:
|
||||
str: 内部聊天会话 ID。
|
||||
Optional[str]: 内部聊天会话 ID;当 group_id 与 user_id 均未提供时返回 ``None``。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=target_group_id)
|
||||
if group_id:
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, group_id=group_id)
|
||||
if user_id:
|
||||
return SessionUtils.calculate_session_id(WEBUI_CHAT_PLATFORM, user_id=user_id)
|
||||
return None
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
def get_history(
|
||||
self,
|
||||
limit: int = 50,
|
||||
group_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
limit: 最大返回条数。
|
||||
group_id: 群组标识。
|
||||
group_id: 群组标识(虚拟群聊模式)。
|
||||
user_id: 用户标识(私聊模式)。
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 历史消息列表。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
session_id = self._resolve_session_id(group_id=group_id, user_id=user_id)
|
||||
if session_id is None:
|
||||
logger.debug("获取聊天历史时缺少 group_id 与 user_id,返回空列表")
|
||||
return []
|
||||
try:
|
||||
messages = find_messages(
|
||||
session_id=session_id,
|
||||
@@ -133,30 +237,54 @@ class ChatHistoryManager:
|
||||
limit_mode="latest",
|
||||
filter_command=False,
|
||||
)
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
# 构建 message_id -> SessionMessage 索引,用于回填历史中 reply 段的发送者/内容
|
||||
# (DB 中通常只存了 target_message_id,target_message_content/sender_* 缺失)。
|
||||
message_index: Dict[str, SessionMessage] = {}
|
||||
for m in messages:
|
||||
mid = getattr(m, "message_id", None)
|
||||
if mid:
|
||||
message_index[str(mid)] = m
|
||||
|
||||
result: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
item = self._message_to_dict(msg, group_id)
|
||||
segments = item.get("segments")
|
||||
if segments:
|
||||
self._enrich_reply_segments(segments, message_index, session_id)
|
||||
result.append(item)
|
||||
logger.debug(
|
||||
f"从数据库加载了 {len(result)} 条聊天记录 (group_id={group_id}, user_id={user_id})"
|
||||
)
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.error(f"从数据库加载聊天记录失败: {exc}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
def clear_history(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""清空指定会话的历史消息。
|
||||
|
||||
Args:
|
||||
group_id: 群组标识。
|
||||
group_id: 群组标识(虚拟群聊模式)。
|
||||
user_id: 用户标识(私聊模式)。
|
||||
|
||||
Returns:
|
||||
int: 被删除的消息数量。
|
||||
"""
|
||||
target_group_id = group_id or WEBUI_CHAT_GROUP_ID
|
||||
session_id = self._resolve_session_id(target_group_id)
|
||||
session_id = self._resolve_session_id(group_id=group_id, user_id=user_id)
|
||||
if session_id is None:
|
||||
return 0
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = delete(Messages).where(col(Messages.session_id) == session_id)
|
||||
result = session.exec(statement)
|
||||
deleted = result.rowcount or 0
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
logger.info(
|
||||
f"已清空 {deleted} 条聊天记录 (group_id={group_id}, user_id={user_id})"
|
||||
)
|
||||
return deleted
|
||||
except Exception as exc:
|
||||
logger.error(f"清空聊天记录失败: {exc}")
|
||||
@@ -174,30 +302,30 @@ class ChatConnectionManager:
|
||||
self.group_sessions: Dict[str, Set[str]] = {}
|
||||
self.user_sessions: Dict[str, Set[str]] = {}
|
||||
|
||||
def _bind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""为会话绑定群组索引。
|
||||
def _bind_channel(self, session_id: str, channel_key: str) -> None:
|
||||
"""为会话绑定逻辑频道索引。
|
||||
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
channel_key: 频道键(``group:<gid>`` 或 ``private:<uid>``)。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.setdefault(group_id, set())
|
||||
group_session_ids.add(session_id)
|
||||
channel_session_ids = self.group_sessions.setdefault(channel_key, set())
|
||||
channel_session_ids.add(session_id)
|
||||
|
||||
def _unbind_group(self, session_id: str, group_id: str) -> None:
|
||||
"""移除会话与群组的索引关系。
|
||||
def _unbind_channel(self, session_id: str, channel_key: str) -> None:
|
||||
"""移除会话与逻辑频道的索引关系。
|
||||
|
||||
Args:
|
||||
session_id: 内部会话 ID。
|
||||
group_id: 群组标识。
|
||||
channel_key: 频道键。
|
||||
"""
|
||||
group_session_ids = self.group_sessions.get(group_id)
|
||||
if group_session_ids is None:
|
||||
channel_session_ids = self.group_sessions.get(channel_key)
|
||||
if channel_session_ids is None:
|
||||
return
|
||||
|
||||
group_session_ids.discard(session_id)
|
||||
if not group_session_ids:
|
||||
del self.group_sessions[group_id]
|
||||
channel_session_ids.discard(session_id)
|
||||
if not channel_session_ids:
|
||||
del self.group_sessions[channel_key]
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
@@ -220,18 +348,39 @@ class ChatConnectionManager:
|
||||
virtual_config: 当前虚拟身份配置。
|
||||
sender: 发送消息到前端的异步回调。
|
||||
"""
|
||||
channel_key = compute_channel_key(virtual_config, user_id)
|
||||
existing_session_id = self.client_sessions.get((connection_id, client_session_id))
|
||||
if existing_session_id is not None and existing_session_id == session_id:
|
||||
# 同一物理连接 + 前端会话重复打开(常见于 React StrictMode 双挂载或客户端去抖失败),
|
||||
# 直接复用现有会话并仅刷新可变字段,避免反复断开/重连产生噪声日志。
|
||||
existing = self.active_connections.get(existing_session_id)
|
||||
if existing is not None:
|
||||
if existing.channel_key != channel_key:
|
||||
self._unbind_channel(existing_session_id, existing.channel_key)
|
||||
self._bind_channel(existing_session_id, channel_key)
|
||||
existing.channel_key = channel_key
|
||||
existing.user_id = user_id
|
||||
existing.user_name = user_name
|
||||
existing.virtual_config = virtual_config
|
||||
existing.sender = sender
|
||||
logger.debug(
|
||||
"WebUI 聊天会话复用: session=%s, connection=%s, client_session=%s, channel=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
channel_key,
|
||||
)
|
||||
return
|
||||
if existing_session_id is not None:
|
||||
self.disconnect(existing_session_id)
|
||||
|
||||
active_group_id = get_current_group_id(virtual_config)
|
||||
session_connection = ChatSessionConnection(
|
||||
session_id=session_id,
|
||||
connection_id=connection_id,
|
||||
client_session_id=client_session_id,
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
active_group_id=active_group_id,
|
||||
channel_key=channel_key,
|
||||
virtual_config=virtual_config,
|
||||
sender=sender,
|
||||
)
|
||||
@@ -240,14 +389,14 @@ class ChatConnectionManager:
|
||||
self.client_sessions[(connection_id, client_session_id)] = session_id
|
||||
self.connection_sessions.setdefault(connection_id, set()).add(session_id)
|
||||
self.user_sessions.setdefault(user_id, set()).add(session_id)
|
||||
self._bind_group(session_id, active_group_id)
|
||||
self._bind_channel(session_id, channel_key)
|
||||
logger.info(
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, group=%s",
|
||||
"WebUI 聊天会话已连接: session=%s, connection=%s, client_session=%s, user=%s, channel=%s",
|
||||
session_id,
|
||||
connection_id,
|
||||
client_session_id,
|
||||
user_id,
|
||||
active_group_id,
|
||||
channel_key,
|
||||
)
|
||||
|
||||
def disconnect(self, session_id: str) -> None:
|
||||
@@ -261,7 +410,7 @@ class ChatConnectionManager:
|
||||
return
|
||||
|
||||
self.client_sessions.pop((session_connection.connection_id, session_connection.client_session_id), None)
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
self._unbind_channel(session_id, session_connection.channel_key)
|
||||
|
||||
connection_session_ids = self.connection_sessions.get(session_connection.connection_id)
|
||||
if connection_session_ids is not None:
|
||||
@@ -327,11 +476,11 @@ class ChatConnectionManager:
|
||||
if session_connection is None:
|
||||
return
|
||||
|
||||
next_group_id = get_current_group_id(virtual_config)
|
||||
if next_group_id != session_connection.active_group_id:
|
||||
self._unbind_group(session_id, session_connection.active_group_id)
|
||||
self._bind_group(session_id, next_group_id)
|
||||
session_connection.active_group_id = next_group_id
|
||||
next_channel_key = compute_channel_key(virtual_config, session_connection.user_id)
|
||||
if next_channel_key != session_connection.channel_key:
|
||||
self._unbind_channel(session_id, session_connection.channel_key)
|
||||
self._bind_channel(session_id, next_channel_key)
|
||||
session_connection.channel_key = next_channel_key
|
||||
|
||||
session_connection.user_name = user_name
|
||||
session_connection.virtual_config = virtual_config
|
||||
@@ -361,16 +510,40 @@ class ChatConnectionManager:
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
async def broadcast_to_channel(self, channel_key: str, message: Dict[str, Any]) -> None:
    """Send a message to every session bound to a logical channel.

    Args:
        channel_key: Channel key (``group:<gid>`` or ``private:<uid>``).
        message: Message payload to broadcast.
    """
    # Iterate over a list copy of the bound session IDs rather than the live set.
    for session_id in list(self.group_sessions.get(channel_key, set())):
        await self.send_message(session_id, message)
|
||||
|
||||
async def broadcast_to_group(
    self,
    group_id: Optional[str],
    message: Dict[str, Any],
    *,
    user_id: Optional[str] = None,
) -> None:
    """Broadcast a message to a group channel or a private-chat channel.

    A non-empty ``group_id`` selects the group channel; otherwise the private
    channel of ``user_id`` is used. When both are empty nothing is sent.

    Args:
        group_id: Group identifier; when empty, ``user_id`` is used instead.
        message: Message payload to broadcast.
        user_id: Recipient user ID for private chat.
    """
    if group_id:
        channel_key = f"group:{group_id}"
    elif user_id:
        channel_key = f"private:{user_id}"
    else:
        # No addressable channel: silently drop the message.
        return
    await self.broadcast_to_channel(channel_key, message)
|
||||
|
||||
|
||||
chat_history = ChatHistoryManager()
|
||||
chat_manager = ChatConnectionManager()
|
||||
@@ -388,6 +561,24 @@ def is_virtual_mode_enabled(virtual_config: Optional[VirtualIdentityConfig]) ->
|
||||
return bool(virtual_config and virtual_config.enabled)
|
||||
|
||||
|
||||
def compute_channel_key(virtual_config: Optional[VirtualIdentityConfig], user_id: str) -> str:
    """Compute the logical channel key for the current session.

    When the virtual identity is enabled and carries a group ID, the virtual
    group chat is the channel; otherwise the current WebUI user's private
    channel is used.

    Args:
        virtual_config: Virtual identity configuration.
        user_id: Current WebUI user ID.

    Returns:
        str: Channel key in the form ``group:<gid>`` or ``private:<uid>``.
    """
    if virtual_config is not None and is_virtual_mode_enabled(virtual_config):
        group_id = virtual_config.group_id
        # An empty group ID is routed to the private channel by
        # ``broadcast_to_group``, so a ``group:`` key here would bind the
        # session to a channel that nothing is ever broadcast to.
        if group_id:
            return f"group:{group_id}"
    return f"private:{user_id}"
|
||||
|
||||
|
||||
def normalize_webui_user_id(user_id: Optional[str]) -> str:
|
||||
"""标准化 WebUI 用户 ID。
|
||||
|
||||
@@ -500,6 +691,8 @@ def build_session_info_message(
|
||||
Returns:
|
||||
Dict[str, Any]: 会话信息消息。
|
||||
"""
|
||||
# bot_qq 用于前端从 QQ 头像公开接口拉取机器人头像(qq_account == 0 表示未配置,不推送)。
|
||||
bot_qq_account = int(getattr(global_config.bot, "qq_account", 0) or 0)
|
||||
session_info_data: Dict[str, Any] = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
@@ -507,6 +700,8 @@ def build_session_info_message(
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
if bot_qq_account > 0:
|
||||
session_info_data["bot_qq"] = str(bot_qq_account)
|
||||
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
@@ -529,7 +724,7 @@ def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig])
|
||||
virtual_config: 虚拟身份配置。
|
||||
|
||||
Returns:
|
||||
Optional[str]: 虚拟身份启用时返回对应群组 ID。
|
||||
Optional[str]: 虚拟身份启用时返回对应群组 ID;否则返回 ``None`` 表示使用私聊。
|
||||
"""
|
||||
if is_virtual_mode_enabled(virtual_config):
|
||||
assert virtual_config is not None
|
||||
@@ -537,16 +732,16 @@ def get_active_history_group_id(virtual_config: Optional[VirtualIdentityConfig])
|
||||
return None
|
||||
|
||||
|
||||
def get_current_group_id(virtual_config: Optional[VirtualIdentityConfig]) -> Optional[str]:
    """Return the effective group ID for the current session.

    Args:
        virtual_config: Virtual identity configuration.

    Returns:
        Optional[str]: The group ID when the virtual identity is enabled;
        otherwise ``None`` (default private-chat mode).
    """
    # No fallback to a default group ID: ``None`` is the private-chat signal.
    return get_active_history_group_id(virtual_config)
|
||||
|
||||
|
||||
def build_welcome_message(virtual_config: Optional[VirtualIdentityConfig]) -> str:
|
||||
@@ -611,7 +806,12 @@ async def send_initial_chat_state(
|
||||
)
|
||||
|
||||
history_group_id = get_active_history_group_id(virtual_config)
|
||||
history = chat_history.get_history(50, history_group_id)
|
||||
history_user_id = None if history_group_id else user_id
|
||||
history = chat_history.get_history(
|
||||
50,
|
||||
group_id=history_group_id,
|
||||
user_id=history_user_id,
|
||||
)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
@@ -679,37 +879,42 @@ def create_message_data(
|
||||
|
||||
if virtual_config and virtual_config.enabled:
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
group_id: Optional[str] = (
|
||||
virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
)
|
||||
group_name: Optional[str] = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
actual_user_nickname = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
group_id = None
|
||||
group_name = None
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
actual_user_nickname = user_name
|
||||
|
||||
message_info: Dict[str, Any] = {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_nickname,
|
||||
"user_cardname": actual_user_nickname,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
}
|
||||
if group_id is not None:
|
||||
message_info["group_info"] = {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
}
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_info": message_info,
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
@@ -717,10 +922,6 @@ def create_message_data(
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
@@ -776,6 +977,7 @@ async def handle_chat_message(
|
||||
},
|
||||
"virtual_mode": is_virtual_mode_enabled(current_virtual_config),
|
||||
},
|
||||
user_id=normalized_user_id,
|
||||
)
|
||||
|
||||
message_data = create_message_data(
|
||||
@@ -788,13 +990,21 @@ async def handle_chat_message(
|
||||
)
|
||||
|
||||
try:
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": True})
|
||||
await chat_manager.broadcast_to_group(
|
||||
target_group_id,
|
||||
{"type": "typing", "is_typing": True},
|
||||
user_id=normalized_user_id,
|
||||
)
|
||||
await chat_bot.message_process(message_data)
|
||||
except Exception as exc:
|
||||
logger.error(f"处理消息时出错: {exc}")
|
||||
await send_chat_error(session_id, f"处理消息时出错: {str(exc)}")
|
||||
finally:
|
||||
await chat_manager.broadcast_to_group(target_group_id, {"type": "typing", "is_typing": False})
|
||||
await chat_manager.broadcast_to_group(
|
||||
target_group_id,
|
||||
{"type": "typing", "is_typing": False},
|
||||
user_id=normalized_user_id,
|
||||
)
|
||||
|
||||
return next_user_name
|
||||
|
||||
@@ -915,11 +1125,12 @@ async def enable_virtual_identity(
|
||||
return None
|
||||
|
||||
|
||||
async def disable_virtual_identity(session_id: str) -> None:
|
||||
async def disable_virtual_identity(session_id: str, normalized_user_id: str) -> None:
|
||||
"""关闭虚拟身份模式。
|
||||
|
||||
Args:
|
||||
session_id: 内部逻辑会话 ID。
|
||||
normalized_user_id: 规范化后的 WebUI 用户 ID,用于加载私聊历史。
|
||||
"""
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
@@ -933,8 +1144,8 @@ async def disable_virtual_identity(session_id: str) -> None:
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": chat_history.get_history(50, WEBUI_CHAT_GROUP_ID),
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"messages": chat_history.get_history(50, user_id=normalized_user_id),
|
||||
"group_id": None,
|
||||
},
|
||||
)
|
||||
await chat_manager.send_message(
|
||||
@@ -952,6 +1163,7 @@ async def handle_virtual_identity_update(
|
||||
session_id_prefix: str,
|
||||
data: Dict[str, Any],
|
||||
current_virtual_config: Optional[VirtualIdentityConfig],
|
||||
normalized_user_id: str,
|
||||
) -> Optional[VirtualIdentityConfig]:
|
||||
"""处理虚拟身份切换请求。
|
||||
|
||||
@@ -960,6 +1172,7 @@ async def handle_virtual_identity_update(
|
||||
session_id_prefix: 会话前缀。
|
||||
data: 前端提交的数据。
|
||||
current_virtual_config: 当前虚拟身份配置。
|
||||
normalized_user_id: 规范化后的 WebUI 用户 ID。
|
||||
|
||||
Returns:
|
||||
Optional[VirtualIdentityConfig]: 更新后的虚拟身份配置。
|
||||
@@ -969,7 +1182,7 @@ async def handle_virtual_identity_update(
|
||||
next_config = await enable_virtual_identity(session_id, session_id_prefix, virtual_data)
|
||||
return next_config if next_config is not None else current_virtual_config
|
||||
|
||||
await disable_virtual_identity(session_id)
|
||||
await disable_virtual_identity(session_id, normalized_user_id)
|
||||
return None
|
||||
|
||||
|
||||
@@ -1019,6 +1232,7 @@ async def dispatch_chat_event(
|
||||
session_id_prefix=session_id_prefix,
|
||||
data=data,
|
||||
current_virtual_config=current_virtual_config,
|
||||
normalized_user_id=normalized_user_id,
|
||||
)
|
||||
return current_user_name, next_virtual_config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user