merge: 同步 upstream/r-dev 并解决冲突
This commit is contained in:
@@ -1,25 +1,28 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import heapq
|
||||
import Levenshtein
|
||||
import random
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
|
||||
import Levenshtein
|
||||
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
|
||||
from src.common.database.database import get_db_session, get_db_session_manual
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.config.config import config_manager, global_config
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
logger = get_logger("emoji")
|
||||
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
|
||||
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
    """Register the emoji system's built-in hook specifications.

    Args:
        registry: Target hook-spec registry.

    Returns:
        List[HookSpec]: The hook specs that were actually registered.
    """

    emoji_schema = {
        "type": "object",
        "description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
    }
    string_array_schema = {
        "type": "array",
        "items": {"type": "string"},
    }
    # All four specs share the same timeout / abort / mutation policy.
    shared_policy = {
        "default_timeout_ms": 5000,
        "allow_abort": True,
        "allow_kwargs_mutation": True,
    }
    # Fields common to both Maisaka selection hooks (declaration order matters
    # for the emitted schema, so keep it identical to the original).
    selection_fields = {
        "stream_id": {"type": "string", "description": "目标会话 ID。"},
        "requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
        "reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
        "context_texts": {
            **string_array_schema,
            "description": "最近聊天上下文文本列表。",
        },
        "sample_size": {"type": "integer", "description": "候选表情采样数量。"},
    }
    abort_message_field = {
        "abort_message": {
            "type": "string",
            "description": "当 Hook 主动中止时可附带的失败提示。",
        },
    }

    specs = [
        HookSpec(
            name="emoji.maisaka.before_select",
            description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
            parameters_schema=build_object_schema(
                {**selection_fields, **abort_message_field},
                required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
            ),
            **shared_policy,
        ),
        HookSpec(
            name="emoji.maisaka.after_select",
            description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
            parameters_schema=build_object_schema(
                {
                    **selection_fields,
                    "selected_emoji": emoji_schema,
                    "selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
                    "matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
                    **abort_message_field,
                },
                required=[
                    "stream_id",
                    "requested_emotion",
                    "reasoning",
                    "context_texts",
                    "sample_size",
                    "matched_emotion",
                ],
            ),
            **shared_policy,
        ),
        HookSpec(
            name="emoji.register.after_build_description",
            description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
            parameters_schema=build_object_schema(
                {
                    "emoji": emoji_schema,
                    "description": {"type": "string", "description": "当前生成出的表情包描述。"},
                    "image_format": {"type": "string", "description": "表情图片格式。"},
                },
                required=["emoji", "description", "image_format"],
            ),
            **shared_policy,
        ),
        HookSpec(
            name="emoji.register.after_build_emotion",
            description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
            parameters_schema=build_object_schema(
                {
                    "emoji": emoji_schema,
                    "description": {"type": "string", "description": "当前表情包描述。"},
                    "emotions": {
                        **string_array_schema,
                        "description": "当前生成出的情绪标签列表。",
                    },
                },
                required=["emoji", "description", "emotions"],
            ),
            **shared_policy,
        ),
    ]
    return registry.register_hook_specs(specs)
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
    """Return the plugin runtime manager singleton.

    The import is deferred to call time to avoid a circular import with the
    plugin-runtime integration module.

    Returns:
        Any: The plugin runtime manager singleton.
    """
    from src.plugin_runtime.integration import get_plugin_runtime_manager as runtime_factory

    return runtime_factory()
|
||||
|
||||
|
||||
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
    """Serialize an emoji object into a hook-transportable payload.

    Args:
        emoji: The emoji object to serialize.

    Returns:
        Optional[Dict[str, Any]]: Serialized dict; ``None`` when *emoji* is ``None``.
    """
    if emoji is None:
        return None

    # Strip emotion tags and drop blanks before shipping them to hooks.
    cleaned_emotions: List[str] = []
    for raw_emotion in emoji.emotion:
        emotion_text = str(raw_emotion).strip()
        if emotion_text:
            cleaned_emotions.append(emotion_text)

    payload: Dict[str, Any] = {
        "file_hash": str(emoji.file_hash or "").strip(),
        "file_name": emoji.file_name,
        "full_path": str(emoji.full_path),
        "description": emoji.description,
        "emotions": cleaned_emotions,
        "query_count": int(emoji.query_count),
    }
    return payload
|
||||
|
||||
|
||||
def _normalize_string_list(raw_values: Any) -> List[str]:
|
||||
"""将任意列表值规范化为字符串列表。
|
||||
|
||||
Args:
|
||||
raw_values: 待规范化的原始值。
|
||||
|
||||
Returns:
|
||||
List[str]: 去空白后的字符串列表。
|
||||
"""
|
||||
|
||||
if not isinstance(raw_values, list):
|
||||
return []
|
||||
return [str(item).strip() for item in raw_values if str(item).strip()]
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -642,6 +810,22 @@ class EmojiManager:
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_description",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=description,
|
||||
image_format=image_format,
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
|
||||
if not normalized_description:
|
||||
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
description = normalized_description
|
||||
target_emoji.description = description
|
||||
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
@@ -687,6 +871,23 @@ class EmojiManager:
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_emotion",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=target_emoji.description,
|
||||
emotions=list(emotions),
|
||||
)
|
||||
if hook_result.aborted:
|
||||
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
raw_emotions = hook_result.kwargs.get("emotions")
|
||||
if raw_emotions is not None:
|
||||
emotions = _normalize_string_list(raw_emotions)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
|
||||
target_emoji.emotion = emotions
|
||||
return True, target_emoji
|
||||
|
||||
349
src/chat/emoji_system/maisaka_tool.py
Normal file
349
src/chat/emoji_system/maisaka_tool.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""Maisaka 表情工具内置能力。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Sequence
|
||||
|
||||
import random
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
|
||||
from src.common.data_models.image_data_model import MaiEmoji
|
||||
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.services import send_service
|
||||
|
||||
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
|
||||
|
||||
logger = get_logger("emoji_maisaka_tool")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MaisakaEmojiSendResult:
|
||||
"""Maisaka 表情发送结果。"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
emoji_base64: str = ""
|
||||
description: str = ""
|
||||
emotions: list[str] = field(default_factory=list)
|
||||
requested_emotion: str = ""
|
||||
matched_emotion: str = ""
|
||||
|
||||
|
||||
def _get_runtime_manager() -> Any:
    """Fetch the plugin runtime manager singleton.

    Imported lazily so that this module can be loaded before the plugin
    runtime integration layer is initialized.

    Returns:
        Any: The plugin runtime manager singleton.
    """
    from src.plugin_runtime.integration import get_plugin_runtime_manager as runtime_factory

    return runtime_factory()
|
||||
|
||||
|
||||
def _coerce_positive_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为正整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时使用的默认值。
|
||||
|
||||
Returns:
|
||||
int: 规范化后的正整数。
|
||||
"""
|
||||
|
||||
try:
|
||||
normalized_value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return normalized_value if normalized_value > 0 else default
|
||||
|
||||
|
||||
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
|
||||
"""清洗 Hook 和调用链传入的上下文文本列表。
|
||||
|
||||
Args:
|
||||
context_texts: 原始上下文文本序列。
|
||||
|
||||
Returns:
|
||||
list[str]: 过滤空白后的上下文文本列表。
|
||||
"""
|
||||
|
||||
if not context_texts:
|
||||
return []
|
||||
return [str(item).strip() for item in context_texts if str(item).strip()]
|
||||
|
||||
|
||||
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
    """Resolve the target emoji object from a hook's return value.

    Args:
        raw_value: The hook-provided ``selected_emoji`` payload (a dict)
            or a bare ``selected_emoji_hash`` string.

    Returns:
        Optional[MaiEmoji]: The matching emoji; ``None`` when no hash could
        be extracted or no registered emoji carries that hash.
    """
    if isinstance(raw_value, dict):
        # Payloads may use either "file_hash" or the shorter "hash" key.
        candidate_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
    elif isinstance(raw_value, str):
        candidate_hash = raw_value.strip()
    else:
        candidate_hash = ""

    if not candidate_hash:
        return None

    # Linear scan over the in-memory emoji registry.
    return next(
        (candidate for candidate in emoji_manager.emojis if candidate.file_hash == candidate_hash),
        None,
    )
|
||||
|
||||
|
||||
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
    """Return a single emoji's emotion tags, stripped and with blanks removed."""
    cleaned: list[str] = []
    for raw_tag in emoji.emotion:
        tag = str(raw_tag).strip()
        if tag:
            cleaned.append(tag)
    return cleaned
|
||||
|
||||
|
||||
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:
|
||||
"""构建供情绪判断使用的最近上下文文本。"""
|
||||
|
||||
normalized_items = [str(item).strip() for item in context_texts if str(item).strip()]
|
||||
if not normalized_items:
|
||||
return ""
|
||||
return "\n".join(normalized_items[-max_items:])
|
||||
|
||||
|
||||
async def _select_emoji_with_llm(
    *,
    sampled_emojis: Sequence[MaiEmoji],
    reasoning: str,
    context_text: str,
) -> tuple[MaiEmoji, str]:
    """Ask the LLM to pick the best-fitting emotion among sampled emojis.

    Falls back to a uniform random pick (with an empty emotion label) when
    no emotion tags exist, the LLM call fails, or the model answers with an
    unknown label.

    Returns:
        tuple[MaiEmoji, str]: The chosen emoji and the matched emotion label
        (empty string when the choice was random).
    """
    # Group the sampled emojis by each emotion tag they carry.
    emojis_by_emotion: dict[str, list[MaiEmoji]] = {}
    for candidate in sampled_emojis:
        for tag in _normalize_emotions(candidate):
            emojis_by_emotion.setdefault(tag, []).append(candidate)

    candidate_emotions = list(emojis_by_emotion)
    if not candidate_emotions:
        # No tags at all — nothing for the LLM to choose between.
        return random.choice(list(sampled_emojis)), ""

    prompt = (
        "你正在为聊天场景选择一个最合适的表情包情绪标签。\n"
        f"发送原因:{reasoning or '辅助表达当前语气和情绪'}\n"
        f"最近聊天记录:\n{context_text or '(暂无额外上下文)'}\n\n"
        "可选情绪标签如下:\n"
        f"{chr(10).join(candidate_emotions)}\n\n"
        "请只返回一个最匹配的情绪标签,不要解释。"
    )

    try:
        llm_reply = await emoji_manager_emotion_judge_llm.generate_response(
            prompt,
            options=LLMGenerationOptions(temperature=0.3, max_tokens=60),
        )
        # Models occasionally wrap the label in quotes; strip them off.
        picked_emotion = (llm_reply.response or "").strip().strip("\"'")
    except Exception as exc:
        logger.warning(f"使用 LLM 选择表情情绪失败,将回退为随机选择: {exc}")
        picked_emotion = ""

    if picked_emotion and picked_emotion in emojis_by_emotion:
        return random.choice(emojis_by_emotion[picked_emotion]), picked_emotion
    return random.choice(list(sampled_emojis)), ""
|
||||
|
||||
|
||||
async def select_emoji_for_maisaka(
    *,
    requested_emotion: str = "",
    reasoning: str = "",
    context_texts: Sequence[str] | None = None,
    sample_size: int = 30,
) -> tuple[MaiEmoji | None, str]:
    """Pick a suitable emoji for Maisaka.

    Tries an exact (case-insensitive) match on the requested emotion first;
    otherwise samples candidates and defers to the LLM-based selection.

    Returns:
        tuple[MaiEmoji | None, str]: The chosen emoji (``None`` when the
        registry is empty) and the matched emotion label.
    """
    candidates = list(emoji_manager.emojis)
    if not candidates:
        return None, ""

    wanted = requested_emotion.strip()
    if wanted:
        wanted_lower = wanted.lower()
        exact_matches = [
            candidate
            for candidate in candidates
            if wanted_lower in (tag.lower() for tag in _normalize_emotions(candidate))
        ]
        if exact_matches:
            return random.choice(exact_matches), wanted

    # Clamp the sample size into [1, len(candidates)] before drawing.
    draw_count = min(max(sample_size, 1), len(candidates))
    sampled = random.sample(candidates, draw_count)
    recent_context = _build_recent_context_text(context_texts or [])
    return await _select_emoji_with_llm(
        sampled_emojis=sampled,
        reasoning=reasoning,
        context_text=recent_context,
    )
|
||||
|
||||
|
||||
async def send_emoji_for_maisaka(
    *,
    stream_id: str,
    requested_emotion: str = "",
    reasoning: str = "",
    context_texts: Sequence[str] | None = None,
) -> MaisakaEmojiSendResult:
    """Select and send one emoji on behalf of Maisaka.

    Pipeline: normalize inputs -> ``emoji.maisaka.before_select`` hook (may
    mutate inputs or abort) -> pick an emoji -> ``emoji.maisaka.after_select``
    hook (may replace the pick or abort) -> encode to base64 and send.

    Args:
        stream_id: Target session ID.
        requested_emotion: Desired emotion label, optional.
        reasoning: Why an emoji is being sent, optional.
        context_texts: Recent chat context texts, optional.

    Returns:
        MaisakaEmojiSendResult: Outcome of the attempt; on success it carries
        the chosen emoji's base64 payload, description and emotion tags.
    """

    # Normalize caller-supplied inputs before exposing them to hooks.
    normalized_requested_emotion = requested_emotion.strip()
    normalized_reasoning = reasoning.strip()
    normalized_context_texts = _normalize_context_texts(context_texts)
    sample_size = 30  # default candidate sample size; hooks may override it

    before_select_result = await _get_runtime_manager().invoke_hook(
        "emoji.maisaka.before_select",
        stream_id=stream_id,
        requested_emotion=normalized_requested_emotion,
        reasoning=normalized_reasoning,
        context_texts=list(normalized_context_texts),
        sample_size=sample_size,
        abort_message="表情选择已被 Hook 中止。",
    )
    if before_select_result.aborted:
        # A hook vetoed the selection; surface its abort message if present.
        abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
        return MaisakaEmojiSendResult(
            success=False,
            message=abort_message or "表情选择已被 Hook 中止。",
            requested_emotion=normalized_requested_emotion,
        )

    # Re-read the possibly hook-mutated kwargs; each value falls back to the
    # locally normalized one when absent or empty.
    before_select_kwargs = before_select_result.kwargs
    normalized_requested_emotion = str(
        before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
    ).strip()
    normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
    if isinstance(before_select_kwargs.get("context_texts"), list):
        normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
    sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)

    selected_emoji, matched_emotion = await select_emoji_for_maisaka(
        requested_emotion=normalized_requested_emotion,
        reasoning=normalized_reasoning,
        context_texts=normalized_context_texts,
        sample_size=sample_size,
    )
    after_select_result = await _get_runtime_manager().invoke_hook(
        "emoji.maisaka.after_select",
        stream_id=stream_id,
        requested_emotion=normalized_requested_emotion,
        reasoning=normalized_reasoning,
        context_texts=list(normalized_context_texts),
        sample_size=sample_size,
        selected_emoji=_serialize_emoji_for_hook(selected_emoji),
        selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
        matched_emotion=matched_emotion,
        abort_message="表情发送已被 Hook 中止。",
    )
    if after_select_result.aborted:
        abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
        return MaisakaEmojiSendResult(
            success=False,
            message=abort_message or "表情发送已被 Hook 中止。",
            requested_emotion=normalized_requested_emotion,
            matched_emotion=matched_emotion,
        )

    after_select_kwargs = after_select_result.kwargs
    normalized_requested_emotion = str(
        after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
    ).strip()
    matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
    # A hook may redirect the send to another emoji: prefer an explicit hash,
    # then fall back to the serialized emoji payload.
    override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
    if override_emoji is None:
        override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
    if override_emoji is not None:
        selected_emoji = override_emoji

    if selected_emoji is None:
        # Nothing selectable (empty registry and no hook override).
        return MaisakaEmojiSendResult(
            success=False,
            message="当前表情包库中没有可用表情。",
            requested_emotion=normalized_requested_emotion,
            matched_emotion=matched_emotion,
        )

    try:
        emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
        if not emoji_base64:
            # Treat an empty conversion result the same as a conversion error.
            raise ValueError("表情图片转换为 base64 失败")
    except Exception as exc:
        return MaisakaEmojiSendResult(
            success=False,
            message=f"发送表情包失败:{exc}",
            description=selected_emoji.description.strip(),
            emotions=_normalize_emotions(selected_emoji),
            requested_emotion=normalized_requested_emotion,
            matched_emotion=matched_emotion,
        )

    try:
        # CLI sessions get a textual preview instead of a real emoji send.
        target_session = chat_manager.get_session_by_session_id(stream_id)
        if target_session is not None and target_session.platform == CLI_PLATFORM_NAME:
            preview_message = (
                f"已发送表情包:{selected_emoji.description.strip()}"
                if selected_emoji.description.strip()
                else "[表情包]"
            )
            render_cli_message(preview_message)
            sent = True
        else:
            sent = await send_service.emoji_to_stream(
                emoji_base64=emoji_base64,
                stream_id=stream_id,
                storage_message=True,
                set_reply=False,
                reply_message=None,
            )
    except Exception as exc:
        return MaisakaEmojiSendResult(
            success=False,
            message=f"发送表情包时发生异常:{exc}",
            description=selected_emoji.description.strip(),
            emotions=_normalize_emotions(selected_emoji),
            requested_emotion=normalized_requested_emotion,
            matched_emotion=matched_emotion,
        )

    description = selected_emoji.description.strip()
    emotions = _normalize_emotions(selected_emoji)
    if not sent:
        # The send service reported failure without raising.
        return MaisakaEmojiSendResult(
            success=False,
            message="发送表情包失败。",
            description=description,
            emotions=emotions,
            requested_emotion=normalized_requested_emotion,
            matched_emotion=matched_emotion,
        )

    # Record usage statistics only after a confirmed successful send.
    emoji_manager.update_emoji_usage(selected_emoji)
    success_message = (
        f"已发送表情包:{description}(情绪:{', '.join(emotions)})"
        if emotions
        else f"已发送表情包:{description}"
    )
    return MaisakaEmojiSendResult(
        success=True,
        message=success_message,
        emoji_base64=emoji_base64,
        description=description,
        emotions=emotions,
        requested_emotion=normalized_requested_emotion,
        matched_emotion=matched_emotion,
    )
|
||||
@@ -44,6 +44,12 @@ class ImageManager:
|
||||
|
||||
logger.info("图片管理器初始化完成")
|
||||
|
||||
def _get_image_record(self, image_hash: str) -> Optional[Images]:
|
||||
"""根据哈希获取图片记录。"""
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||
return session.exec(statement).first()
|
||||
|
||||
async def get_image_description(
|
||||
self,
|
||||
*,
|
||||
@@ -76,9 +82,8 @@ class ImageManager:
|
||||
hash_str = hashlib.sha256(image_bytes).hexdigest()
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||
if record := session.exec(statement).first():
|
||||
if record := self._get_image_record(hash_str):
|
||||
if record.vlm_processed and record.description:
|
||||
return record.description
|
||||
except Exception as e:
|
||||
logger.error(f"查询图片描述时发生错误: {e}")
|
||||
@@ -86,12 +91,17 @@ class ImageManager:
|
||||
if not image_bytes:
|
||||
logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述")
|
||||
return ""
|
||||
try:
|
||||
await self.ensure_image_saved(image_bytes)
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件时发生错误: {e}")
|
||||
return ""
|
||||
if not wait_for_build:
|
||||
self._schedule_description_build(hash_str, image_bytes)
|
||||
return ""
|
||||
logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述")
|
||||
try:
|
||||
image = await self.save_image_and_process(image_bytes)
|
||||
image = await self.build_image_description(image_bytes)
|
||||
return image.description
|
||||
except Exception as e:
|
||||
logger.error(f"生成图片描述时发生错误: {e}")
|
||||
@@ -120,7 +130,7 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"图片描述后台构建已开始,哈希值: {image_hash}")
|
||||
await self.save_image_and_process(image_bytes)
|
||||
await self.build_image_description(image_bytes)
|
||||
logger.info(f"图片描述后台构建完成,哈希值: {image_hash}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"图片描述后台构建失败,哈希值: {image_hash},错误: {exc}")
|
||||
@@ -201,6 +211,7 @@ class ImageManager:
|
||||
return False
|
||||
record.description = image.description
|
||||
record.last_used_time = datetime.now()
|
||||
record.vlm_processed = image.vlm_processed
|
||||
session.add(record)
|
||||
logger.info(f"成功更新图片描述: {image.file_hash},新描述: {image.description}")
|
||||
except Exception as e:
|
||||
@@ -239,22 +250,13 @@ class ImageManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
|
||||
"""
|
||||
保存图片并生成描述
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): 图片的字节数据
|
||||
Returns:
|
||||
return (MaiImage): 包含图片信息的 MaiImage 对象
|
||||
Raises:
|
||||
Exception: 如果在保存或处理过程中发生错误
|
||||
"""
|
||||
async def ensure_image_saved(self, image_bytes: bytes) -> MaiImage:
|
||||
"""先保存图片记录,确保后续可以按哈希回填图片内容。"""
|
||||
hash_str = hashlib.sha256(image_bytes).hexdigest()
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=hash_str).limit(1)
|
||||
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
|
||||
if record := session.exec(statement).first():
|
||||
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
|
||||
record.last_used_time = datetime.now()
|
||||
@@ -270,18 +272,38 @@ class ImageManager:
|
||||
tmp_file_path = IMAGE_DIR / f"{hash_str}.tmp"
|
||||
with tmp_file_path.open("wb") as f:
|
||||
f.write(image_bytes)
|
||||
mai_image = MaiImage(full_path=(IMAGE_DIR / f"{hash_str}.tmp"), image_bytes=image_bytes)
|
||||
mai_image = MaiImage(full_path=tmp_file_path, image_bytes=image_bytes)
|
||||
await mai_image.calculate_hash_format()
|
||||
if not self.register_image_to_db(mai_image):
|
||||
raise RuntimeError(f"保存图片记录到数据库失败: {hash_str}")
|
||||
return mai_image
|
||||
|
||||
async def build_image_description(self, image_bytes: bytes) -> MaiImage:
|
||||
"""在图片已保存的前提下生成或补齐图片描述。"""
|
||||
mai_image = await self.ensure_image_saved(image_bytes)
|
||||
if mai_image.vlm_processed and mai_image.description:
|
||||
return mai_image
|
||||
|
||||
desc = await self._generate_image_description(image_bytes, mai_image.image_format)
|
||||
mai_image.description = desc
|
||||
mai_image.vlm_processed = True
|
||||
try:
|
||||
self.register_image_to_db(mai_image)
|
||||
except Exception as e:
|
||||
logger.error(f"保存新图片记录到数据库时发生错误: {e}")
|
||||
raise e
|
||||
if not self.update_image_description(mai_image):
|
||||
raise RuntimeError(f"更新图片描述失败: {mai_image.file_hash}")
|
||||
return mai_image
|
||||
|
||||
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
|
||||
"""
|
||||
保存图片并生成描述
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): 图片的字节数据
|
||||
Returns:
|
||||
return (MaiImage): 包含图片信息的 MaiImage 对象
|
||||
Raises:
|
||||
Exception: 如果在保存或处理过程中发生错误
|
||||
"""
|
||||
return await self.build_image_description(image_bytes)
|
||||
|
||||
def cleanup_invalid_descriptions_in_db(self):
|
||||
"""
|
||||
清理数据库中无效的图片记录
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""聊天消息入口与主链路调度。"""
|
||||
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import os
|
||||
import traceback
|
||||
@@ -13,12 +14,15 @@ from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||
|
||||
from src.core.announcement_manager import global_announcement_manager
|
||||
from src.plugin_runtime.component_query import component_query_service
|
||||
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
|
||||
from src.plugin_runtime.hook_schema_utils import build_object_schema
|
||||
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
|
||||
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
|
||||
|
||||
from .message import SessionMessage
|
||||
from .chat_manager import chat_manager
|
||||
from .message import SessionMessage
|
||||
|
||||
# 定义日志配置
|
||||
|
||||
@@ -29,7 +33,137 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
def register_chat_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
    """Register the chat main-chain built-in hook specifications.

    Args:
        registry: Target hook-spec registry.

    Returns:
        List[HookSpec]: The hook specs that were actually registered.
    """

    def _message_field(description: str) -> Dict[str, Any]:
        # Every chat hook carries the serialized SessionMessage as "message".
        return {
            "type": "object",
            "description": description,
        }

    # Fields shared by both command hooks (declaration order is preserved).
    command_context_fields: Dict[str, Any] = {
        "plugin_id": {
            "type": "string",
            "description": "命令所属插件 ID。",
        },
        "matched_groups": {
            "type": "object",
            "description": "命令正则命名捕获结果。",
        },
    }

    specs = [
        HookSpec(
            name="chat.receive.before_process",
            description="在入站消息执行 `SessionMessage.process()` 之前触发,可拦截或改写消息。",
            parameters_schema=build_object_schema(
                {"message": _message_field("当前入站消息的序列化 SessionMessage。")},
                required=["message"],
            ),
            default_timeout_ms=8000,
            allow_abort=True,
            allow_kwargs_mutation=True,
        ),
        HookSpec(
            name="chat.receive.after_process",
            description="在入站消息完成轻量预处理后触发,可改写文本、消息体或中止后续链路。",
            parameters_schema=build_object_schema(
                {"message": _message_field("已完成 `process()` 的序列化 SessionMessage。")},
                required=["message"],
            ),
            default_timeout_ms=8000,
            allow_abort=True,
            allow_kwargs_mutation=True,
        ),
        HookSpec(
            name="chat.command.before_execute",
            description="在命令匹配成功、实际执行前触发,可拦截命令或改写命令上下文。",
            parameters_schema=build_object_schema(
                {
                    "message": _message_field("当前命令消息的序列化 SessionMessage。"),
                    "command_name": {
                        "type": "string",
                        "description": "命中的命令名称。",
                    },
                    **command_context_fields,
                },
                required=["message", "command_name", "plugin_id", "matched_groups"],
            ),
            default_timeout_ms=5000,
            allow_abort=True,
            allow_kwargs_mutation=True,
        ),
        HookSpec(
            name="chat.command.after_execute",
            description="在命令执行结束后触发,可调整返回文本和是否继续主链处理。",
            parameters_schema=build_object_schema(
                {
                    "message": _message_field("当前命令消息的序列化 SessionMessage。"),
                    "command_name": {
                        "type": "string",
                        "description": "命令名称。",
                    },
                    **command_context_fields,
                    "success": {
                        "type": "boolean",
                        "description": "命令执行是否成功。",
                    },
                    "response": {
                        "type": "string",
                        "description": "命令返回文本。",
                    },
                    "intercept_message_level": {
                        "type": "integer",
                        "description": "命令拦截等级。",
                    },
                    "continue_process": {
                        "type": "boolean",
                        "description": "命令执行后是否继续后续消息处理。",
                    },
                },
                required=[
                    "message",
                    "command_name",
                    "plugin_id",
                    "matched_groups",
                    "success",
                    "intercept_message_level",
                    "continue_process",
                ],
            ),
            default_timeout_ms=5000,
            # The post-execution hook may tweak outputs but cannot abort.
            allow_abort=False,
            allow_kwargs_mutation=True,
        ),
    ]
    return registry.register_hook_specs(specs)
|
||||
|
||||
|
||||
class ChatBot:
|
||||
"""聊天机器人入口协调器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化聊天机器人入口。"""
|
||||
|
||||
@@ -44,6 +178,66 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
"""获取插件运行时管理器。
|
||||
|
||||
Returns:
|
||||
Any: 插件运行时管理器单例。
|
||||
"""
|
||||
|
||||
from src.plugin_runtime.integration import get_plugin_runtime_manager
|
||||
|
||||
return get_plugin_runtime_manager()
|
||||
|
||||
@staticmethod
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
"""将任意值安全转换为整数。
|
||||
|
||||
Args:
|
||||
value: 待转换的值。
|
||||
default: 转换失败时的默认值。
|
||||
|
||||
Returns:
|
||||
int: 转换后的整数结果。
|
||||
"""
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
async def _invoke_message_hook(
|
||||
self,
|
||||
hook_name: str,
|
||||
message: SessionMessage,
|
||||
**kwargs: Any,
|
||||
) -> tuple[HookDispatchResult, SessionMessage]:
|
||||
"""触发携带会话消息的命名 Hook。
|
||||
|
||||
Args:
|
||||
hook_name: 目标 Hook 名称。
|
||||
message: 当前会话消息。
|
||||
**kwargs: 需要附带传递的额外参数。
|
||||
|
||||
Returns:
|
||||
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
|
||||
"""
|
||||
|
||||
hook_result = await self._get_runtime_manager().invoke_hook(
|
||||
hook_name,
|
||||
message=serialize_session_message(message),
|
||||
**kwargs,
|
||||
)
|
||||
mutated_message = message
|
||||
raw_message = hook_result.kwargs.get("message")
|
||||
if raw_message is not None:
|
||||
try:
|
||||
mutated_message = deserialize_session_message(raw_message)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
|
||||
return hook_result, mutated_message
|
||||
|
||||
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
|
||||
"""使用统一组件注册表处理命令。
|
||||
|
||||
@@ -71,6 +265,25 @@ class ChatBot:
|
||||
return False, None, True
|
||||
|
||||
message.is_command = True
|
||||
before_result, message = await self._invoke_message_hook(
|
||||
"chat.command.before_execute",
|
||||
message,
|
||||
command_name=command_name,
|
||||
plugin_id=plugin_name,
|
||||
matched_groups=dict(matched_groups),
|
||||
)
|
||||
if before_result.aborted:
|
||||
logger.info(f"命令 {command_name} 被 Hook 中止,跳过命令执行")
|
||||
return True, None, False
|
||||
|
||||
hook_kwargs = before_result.kwargs
|
||||
command_name = str(hook_kwargs.get("command_name", command_name) or command_name)
|
||||
plugin_name = str(hook_kwargs.get("plugin_id", plugin_name) or plugin_name)
|
||||
matched_groups = (
|
||||
dict(hook_kwargs["matched_groups"])
|
||||
if isinstance(hook_kwargs.get("matched_groups"), dict)
|
||||
else dict(matched_groups)
|
||||
)
|
||||
|
||||
# 获取插件配置
|
||||
plugin_config = component_query_service.get_plugin_config(plugin_name)
|
||||
@@ -82,27 +295,43 @@ class ChatBot:
|
||||
plugin_config=plugin_config,
|
||||
matched_groups=matched_groups,
|
||||
)
|
||||
self._mark_command_message(message, intercept_message_level)
|
||||
|
||||
# 记录命令执行结果
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
# 根据命令的拦截设置决定是否继续处理消息
|
||||
return (
|
||||
True,
|
||||
response,
|
||||
not bool(intercept_message_level),
|
||||
) # 找到命令,根据intercept_message决定是否继续
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令时出错: {command_name} - {e}")
|
||||
continue_process = not bool(intercept_message_level)
|
||||
except Exception as exc:
|
||||
logger.error(f"执行命令时出错: {command_name} - {exc}")
|
||||
logger.error(traceback.format_exc())
|
||||
success = False
|
||||
response = str(exc)
|
||||
intercept_message_level = 1
|
||||
continue_process = False
|
||||
|
||||
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
|
||||
return True, str(e), False # 出错时继续处理消息
|
||||
after_result, message = await self._invoke_message_hook(
|
||||
"chat.command.after_execute",
|
||||
message,
|
||||
command_name=command_name,
|
||||
plugin_id=plugin_name,
|
||||
matched_groups=dict(matched_groups),
|
||||
success=success,
|
||||
response=response,
|
||||
intercept_message_level=intercept_message_level,
|
||||
continue_process=continue_process,
|
||||
)
|
||||
after_kwargs = after_result.kwargs
|
||||
success = bool(after_kwargs.get("success", success))
|
||||
raw_response = after_kwargs.get("response", response)
|
||||
response = None if raw_response is None else str(raw_response)
|
||||
intercept_message_level = self._coerce_int(
|
||||
after_kwargs.get("intercept_message_level", intercept_message_level),
|
||||
intercept_message_level,
|
||||
)
|
||||
continue_process = bool(after_kwargs.get("continue_process", continue_process))
|
||||
self._mark_command_message(message, intercept_message_level)
|
||||
|
||||
if success:
|
||||
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
|
||||
else:
|
||||
logger.warning(f"命令执行失败: {command_name} - {response}")
|
||||
|
||||
return True, response, continue_process
|
||||
|
||||
return False, None, True
|
||||
|
||||
@@ -138,6 +367,17 @@ class ChatBot:
|
||||
cmd_result: Optional[str],
|
||||
continue_process: bool,
|
||||
) -> bool:
|
||||
"""处理命令链结果并决定是否终止主消息链。
|
||||
|
||||
Args:
|
||||
message: 当前命令消息。
|
||||
cmd_result: 命令响应文本。
|
||||
continue_process: 是否继续后续主链处理。
|
||||
|
||||
Returns:
|
||||
bool: ``True`` 表示已经终止后续主链。
|
||||
"""
|
||||
|
||||
if continue_process:
|
||||
return False
|
||||
|
||||
@@ -145,9 +385,18 @@ class ChatBot:
|
||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||
return True
|
||||
|
||||
async def handle_notice_message(self, message: SessionMessage):
|
||||
async def handle_notice_message(self, message: SessionMessage) -> bool:
|
||||
"""处理通知类消息。
|
||||
|
||||
Args:
|
||||
message: 当前通知消息。
|
||||
|
||||
Returns:
|
||||
bool: 当前消息是否为通知消息。
|
||||
"""
|
||||
|
||||
if message.message_id != "notice":
|
||||
return
|
||||
return False
|
||||
|
||||
message.is_notify = True
|
||||
logger.debug("notice消息")
|
||||
@@ -203,9 +452,12 @@ class ChatBot:
|
||||
return True
|
||||
|
||||
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
|
||||
"""处理消息回送 ID 对应关系。
|
||||
|
||||
Args:
|
||||
raw_data: 平台适配器上报的原始回送载荷。
|
||||
"""
|
||||
用于专门处理回送消息ID的函数
|
||||
"""
|
||||
|
||||
message_data: Dict[str, Any] = raw_data.get("content", {})
|
||||
if not message_data:
|
||||
return
|
||||
@@ -218,18 +470,10 @@ class ChatBot:
|
||||
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
heart_flow模式:使用思维流系统进行回复
|
||||
- 包含思维流状态管理
|
||||
- 在回复前进行观察和状态更新
|
||||
- 回复后更新思维流状态
|
||||
- 消息过滤
|
||||
- 记忆激活
|
||||
- 意愿计算
|
||||
- 消息生成和发送
|
||||
- 表情包处理
|
||||
- 性能计时
|
||||
"""处理统一格式的入站消息字典。
|
||||
|
||||
Args:
|
||||
message_data: 适配器整理后的统一消息字典。
|
||||
"""
|
||||
try:
|
||||
# 确保所有任务已启动
|
||||
@@ -253,7 +497,13 @@ class ChatBot:
|
||||
logger.error(f"预处理消息失败: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
async def receive_message(self, message: SessionMessage):
|
||||
async def receive_message(self, message: SessionMessage) -> None:
|
||||
"""处理单条入站会话消息。
|
||||
|
||||
Args:
|
||||
message: 待处理的会话消息。
|
||||
"""
|
||||
|
||||
try:
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
@@ -272,6 +522,19 @@ class ChatBot:
|
||||
)
|
||||
|
||||
message.session_id = session_id # 正确初始化session_id
|
||||
before_process_result, message = await self._invoke_message_hook(
|
||||
"chat.receive.before_process",
|
||||
message,
|
||||
)
|
||||
if before_process_result.aborted:
|
||||
logger.info(f"消息 {message.message_id} 在预处理前被 Hook 中止")
|
||||
return
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
additional_config = message.message_info.additional_config
|
||||
if isinstance(additional_config, dict):
|
||||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||||
|
||||
# TODO: 修复事件预处理部分
|
||||
# continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||
@@ -286,14 +549,24 @@ class ChatBot:
|
||||
# if await self.handle_notice_message(message):
|
||||
# pass
|
||||
|
||||
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
|
||||
if global_config.maisaka.direct_image_input:
|
||||
message.maisaka_original_raw_message = deepcopy(message.raw_message) # type: ignore[attr-defined]
|
||||
# 处理消息内容,识别表情包等二进制数据并转化为文本描述。
|
||||
# 如果 Maisaka 需要直接消费图片,会在后续构建 prompt 时按需回填图片二进制数据,
|
||||
# 这里不再复制整条原始消息。
|
||||
# 入站主链优先保证消息尽快入队,避免图片、表情包、语音分析阻塞适配器超时。
|
||||
await message.process(
|
||||
enable_heavy_media_analysis=False,
|
||||
enable_voice_transcription=False,
|
||||
)
|
||||
after_process_result, message = await self._invoke_message_hook(
|
||||
"chat.receive.after_process",
|
||||
message,
|
||||
)
|
||||
if after_process_result.aborted:
|
||||
logger.info(f"消息 {message.message_id} 在预处理后被 Hook 中止")
|
||||
return
|
||||
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
from asyncio import Task
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
@@ -36,6 +35,102 @@ class MsgIDMapping:
|
||||
|
||||
|
||||
class SessionMessage(MaiMessage):
|
||||
|
||||
#便于调试的打印函数
|
||||
def __str__(self) -> str:
|
||||
"""返回适合日志输出的消息摘要。"""
|
||||
return self.to_debug_string()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""返回适合调试场景的消息摘要。"""
|
||||
return self.to_debug_string()
|
||||
|
||||
def to_debug_string(self) -> str:
|
||||
"""构建包含引用信息的调试字符串。
|
||||
|
||||
Returns:
|
||||
str: 适合记录日志的消息摘要。
|
||||
"""
|
||||
user_info = self.message_info.user_info
|
||||
group_info = self.message_info.group_info
|
||||
chat_type = "group" if group_info else "private"
|
||||
group_id = group_info.group_id if group_info else None
|
||||
group_name = group_info.group_name if group_info else None
|
||||
component_summaries = [self._summarize_component(component) for component in self.raw_message.components]
|
||||
raw_components = ", ".join(component_summaries) if component_summaries else "empty"
|
||||
|
||||
return (
|
||||
"SessionMessage("
|
||||
f"message_id={self.message_id!r}, "
|
||||
f"platform={self.platform!r}, "
|
||||
f"chat_type={chat_type!r}, "
|
||||
f"group_id={group_id!r}, "
|
||||
f"group_name={group_name!r}, "
|
||||
f"user_id={user_info.user_id!r}, "
|
||||
f"user_nickname={user_info.user_nickname!r}, "
|
||||
f"user_cardname={user_info.user_cardname!r}, "
|
||||
f"reply_to={self.reply_to!r}, "
|
||||
f"processed_plain_text={self._truncate_text(self.processed_plain_text)}, "
|
||||
f"raw_components=[{raw_components}]"
|
||||
")"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _truncate_text(text: str | None, max_length: int = 120) -> str:
|
||||
"""截断较长文本,避免日志过长。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
max_length: 最大保留长度。
|
||||
|
||||
Returns:
|
||||
str: 截断后的文本表示。
|
||||
"""
|
||||
if text is None:
|
||||
return "None"
|
||||
normalized_text = text.replace("\r", "\\r").replace("\n", "\\n")
|
||||
if len(normalized_text) <= max_length:
|
||||
return repr(normalized_text)
|
||||
return repr(f"{normalized_text[:max_length]}...")
|
||||
|
||||
def _summarize_component(self, component: StandardMessageComponents) -> str:
|
||||
"""生成单个消息组件的调试摘要。
|
||||
|
||||
Args:
|
||||
component: 消息组件对象。
|
||||
|
||||
Returns:
|
||||
str: 组件摘要文本。
|
||||
"""
|
||||
if isinstance(component, TextComponent):
|
||||
return f"Text(text={self._truncate_text(component.text, 80)})"
|
||||
if isinstance(component, ImageComponent):
|
||||
return f"Image(content={self._truncate_text(component.content or None, 60)})"
|
||||
if isinstance(component, EmojiComponent):
|
||||
return f"Emoji(content={self._truncate_text(component.content or None, 60)})"
|
||||
if isinstance(component, AtComponent):
|
||||
target_name = component.target_user_cardname or component.target_user_nickname or component.target_user_id
|
||||
return f"At(target={target_name!r})"
|
||||
if isinstance(component, VoiceComponent):
|
||||
return f"Voice(content={self._truncate_text(component.content or None, 60)})"
|
||||
if isinstance(component, ReplyComponent):
|
||||
sender_name = (
|
||||
component.target_message_sender_cardname
|
||||
or component.target_message_sender_nickname
|
||||
or component.target_message_sender_id
|
||||
)
|
||||
return (
|
||||
"Reply("
|
||||
f"target_message_id={component.target_message_id!r}, "
|
||||
f"target_sender={sender_name!r}, "
|
||||
f"target_content={self._truncate_text(component.target_message_content, 80)}"
|
||||
")"
|
||||
)
|
||||
if isinstance(component, ForwardNodeComponent):
|
||||
return f"ForwardNode(count={len(component.forward_components)})"
|
||||
return f"{component.__class__.__name__}"
|
||||
#便于调试的打印函数end
|
||||
|
||||
async def process(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -18,28 +18,29 @@ install(extra_lines=3)
|
||||
logger = get_logger("sender")
|
||||
|
||||
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
|
||||
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
|
||||
|
||||
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
|
||||
# TODO: 重构完成后完成webui相关
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
|
||||
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
|
||||
"""获取 WebUI 聊天室广播器。
|
||||
|
||||
Returns:
|
||||
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 二元组;
|
||||
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 三元组;
|
||||
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
|
||||
"""
|
||||
global _webui_chat_broadcaster
|
||||
if _webui_chat_broadcaster is None:
|
||||
try:
|
||||
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
|
||||
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
|
||||
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
|
||||
except ImportError:
|
||||
_webui_chat_broadcaster = (None, None)
|
||||
_webui_chat_broadcaster = (None, None, None)
|
||||
return _webui_chat_broadcaster
|
||||
|
||||
|
||||
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
|
||||
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||
|
||||
if is_webui_message and chat_manager is not None:
|
||||
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
message_type = "rich"
|
||||
segments = message_segments
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
await chat_manager.broadcast_to_group(
|
||||
group_id=group_id or default_group_id or "",
|
||||
message={
|
||||
"type": "bot_message",
|
||||
"content": message.processed_plain_text,
|
||||
"message_type": message_type,
|
||||
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
|
||||
"avatar": None,
|
||||
"is_bot": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||
|
||||
@@ -35,7 +35,7 @@ from src.services import llm_service as llm_api
|
||||
|
||||
from src.chat.logger.plan_reply_logger import PlanReplyLogger
|
||||
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
|
||||
from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
|
||||
from src.learners.jargon_explainer_old import explain_jargon_in_context
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
init_memory_retrieval_sys()
|
||||
@@ -688,39 +688,41 @@ class DefaultReplyer:
|
||||
return None
|
||||
|
||||
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
|
||||
"""
|
||||
根据聊天流ID获取匹配的额外prompt(仅匹配group类型)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID(哈希值)
|
||||
|
||||
Returns:
|
||||
str: 匹配的额外prompt内容,如果没有匹配则返回空字符串
|
||||
"""
|
||||
if not global_config.experimental.chat_prompts:
|
||||
"""根据聊天流 ID 获取匹配的额外 prompt。"""
|
||||
if not global_config.chat.chat_prompts:
|
||||
return ""
|
||||
|
||||
for chat_prompt_str in global_config.experimental.chat_prompts:
|
||||
if not isinstance(chat_prompt_str, str):
|
||||
for chat_prompt_item in global_config.chat.chat_prompts:
|
||||
if hasattr(chat_prompt_item, "rule_type") and hasattr(chat_prompt_item, "prompt"):
|
||||
if str(chat_prompt_item.rule_type or "").strip() != "group":
|
||||
continue
|
||||
|
||||
config_chat_id = self._build_chat_uid(
|
||||
str(chat_prompt_item.platform or "").strip(),
|
||||
str(chat_prompt_item.item_id or "").strip(),
|
||||
True,
|
||||
)
|
||||
prompt_content = str(chat_prompt_item.prompt or "").strip()
|
||||
if config_chat_id == chat_id and prompt_content:
|
||||
logger.debug(f"匹配到群聊 prompt 配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
continue
|
||||
|
||||
# 解析配置字符串,检查类型是否为group
|
||||
parts = chat_prompt_str.split(":", 3)
|
||||
if len(parts) != 4:
|
||||
if not isinstance(chat_prompt_item, str):
|
||||
continue
|
||||
|
||||
stream_type = parts[2]
|
||||
# 只匹配group类型
|
||||
if stream_type != "group":
|
||||
# 兼容旧格式的 platform:id:type:prompt 配置字符串。
|
||||
parts = chat_prompt_item.split(":", 3)
|
||||
if len(parts) != 4 or parts[2] != "group":
|
||||
continue
|
||||
|
||||
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
|
||||
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_item)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
config_chat_id, prompt_content = result
|
||||
if config_chat_id == chat_id:
|
||||
logger.debug(f"匹配到群聊prompt配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
logger.debug(f"匹配到群聊 prompt 配置,chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
|
||||
return prompt_content
|
||||
|
||||
return ""
|
||||
|
||||
453
src/chat/replyer/maisaka_generator_multi.py
Normal file
453
src/chat/replyer/maisaka_generator_multi.py
Normal file
@@ -0,0 +1,453 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import random
|
||||
import time
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.data_models.reply_generation_data_models import (
|
||||
GenerationMetrics,
|
||||
LLMCompletionResult,
|
||||
ReplyGenerationResult,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.common.prompt_i18n import load_prompt
|
||||
from src.config.config import global_config
|
||||
from src.core.types import ActionInfo
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage
|
||||
from src.maisaka.message_adapter import parse_speaker_content
|
||||
|
||||
logger = get_logger("replyer")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaisakaReplyContext:
|
||||
"""Maisaka replyer 使用的回复上下文。"""
|
||||
|
||||
expression_habits: str = ""
|
||||
selected_expression_ids: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExpressionRecord:
|
||||
"""表达方式的轻量记录。"""
|
||||
|
||||
expression_id: Optional[int]
|
||||
situation: str
|
||||
style: str
|
||||
|
||||
|
||||
class MaisakaReplyGenerator:
|
||||
"""生成 Maisaka 的最终可见回复。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
request_type: str = "maisaka_replyer",
|
||||
) -> None:
|
||||
self.chat_stream = chat_stream
|
||||
self.request_type = request_type
|
||||
self.express_model = LLMServiceClient(
|
||||
task_name="replyer",
|
||||
request_type=request_type,
|
||||
)
|
||||
self._personality_prompt = self._build_personality_prompt()
|
||||
|
||||
def _build_personality_prompt(self) -> str:
|
||||
"""构建 replyer 使用的人设描述。"""
|
||||
try:
|
||||
bot_name = global_config.bot.nickname
|
||||
alias_names = global_config.bot.alias_names
|
||||
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
|
||||
|
||||
prompt_personality = global_config.personality.personality
|
||||
if (
|
||||
hasattr(global_config.personality, "states")
|
||||
and global_config.personality.states
|
||||
and hasattr(global_config.personality, "state_probability")
|
||||
and global_config.personality.state_probability > 0
|
||||
and random.random() < global_config.personality.state_probability
|
||||
):
|
||||
prompt_personality = random.choice(global_config.personality.states)
|
||||
|
||||
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
|
||||
except Exception as exc:
|
||||
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
|
||||
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
|
||||
|
||||
@staticmethod
|
||||
def _normalize_content(content: str, limit: int = 500) -> str:
|
||||
normalized = " ".join((content or "").split())
|
||||
if len(normalized) > limit:
|
||||
return normalized[:limit] + "..."
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
|
||||
del message
|
||||
return ""
|
||||
|
||||
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
|
||||
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
if speaker_name == bot_nickname:
|
||||
return self._normalize_content(body.strip())
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
|
||||
"""按说话人拆分用户消息。"""
|
||||
segments: List[tuple[Optional[str], str]] = []
|
||||
current_speaker: Optional[str] = None
|
||||
current_lines: List[str] = []
|
||||
|
||||
for raw_line in raw_content.splitlines():
|
||||
speaker_name, content_body = parse_speaker_content(raw_line)
|
||||
if speaker_name is not None:
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
current_speaker = speaker_name
|
||||
current_lines = [content_body]
|
||||
continue
|
||||
|
||||
current_lines.append(raw_line)
|
||||
|
||||
if current_lines:
|
||||
segments.append((current_speaker, "\n".join(current_lines)))
|
||||
|
||||
return segments
|
||||
|
||||
def _build_system_prompt(
|
||||
self,
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> str:
|
||||
"""构建 Maisaka replyer 使用的系统提示词。"""
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
try:
|
||||
system_prompt = load_prompt(
|
||||
"maisaka_replyer",
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_block=f"当前时间:{current_time}",
|
||||
identity=self._personality_prompt,
|
||||
reply_style=global_config.personality.reply_style,
|
||||
)
|
||||
except Exception:
|
||||
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
|
||||
|
||||
extra_sections: List[str] = []
|
||||
if expression_habits.strip():
|
||||
extra_sections.append(expression_habits.strip())
|
||||
if reply_reason.strip():
|
||||
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
|
||||
|
||||
if not extra_sections:
|
||||
return system_prompt
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
"""构建追加在上下文末尾的回复指令。"""
|
||||
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
|
||||
|
||||
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
|
||||
"""将 replyer 上下文拆成多条 LLM 消息。"""
|
||||
bot_nickname = global_config.bot.nickname.strip() or "Bot"
|
||||
default_user_name = global_config.maisaka.user_name.strip() or "User"
|
||||
messages: List[Message] = []
|
||||
|
||||
for message in chat_history:
|
||||
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
|
||||
continue
|
||||
|
||||
if isinstance(message, SessionBackedMessage):
|
||||
guided_reply = self._extract_guided_bot_reply(message)
|
||||
if guided_reply:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(guided_reply).build()
|
||||
)
|
||||
continue
|
||||
|
||||
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
|
||||
content = self._normalize_content(content_body)
|
||||
if not content:
|
||||
continue
|
||||
|
||||
visible_speaker = speaker_name or default_user_name
|
||||
if visible_speaker == bot_nickname:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
|
||||
)
|
||||
continue
|
||||
|
||||
user_content = f"[{visible_speaker}]{content}"
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
|
||||
continue
|
||||
|
||||
if isinstance(message, AssistantMessage):
|
||||
visible_reply = self._extract_visible_assistant_reply(message)
|
||||
if visible_reply:
|
||||
messages.append(
|
||||
MessageBuilder().set_role(RoleType.Assistant).add_text_content(visible_reply).build()
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def _build_request_messages(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_reason: str,
|
||||
expression_habits: str = "",
|
||||
) -> List[Message]:
|
||||
"""构建发给大模型的消息列表。"""
|
||||
messages: List[Message] = []
|
||||
system_prompt = self._build_system_prompt(
|
||||
reply_reason=reply_reason,
|
||||
expression_habits=expression_habits,
|
||||
)
|
||||
instruction = self._build_reply_instruction()
|
||||
|
||||
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
|
||||
messages.extend(self._build_history_messages(chat_history))
|
||||
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _build_request_prompt_preview(messages: List[Message]) -> str:
|
||||
"""将消息列表转为便于调试的文本预览。"""
|
||||
preview_lines: List[str] = []
|
||||
for message in messages:
|
||||
role_name = message.role.value.capitalize()
|
||||
preview_lines.append(f"{role_name}: {message.get_text_content()}")
|
||||
return "\n\n".join(preview_lines)
|
||||
|
||||
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
|
||||
"""解析当前回复使用的会话 ID。"""
|
||||
if stream_id:
|
||||
return stream_id
|
||||
if self.chat_stream is not None:
|
||||
return self.chat_stream.session_id
|
||||
return ""
|
||||
|
||||
async def _build_reply_context(
|
||||
self,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
stream_id: Optional[str],
|
||||
) -> MaisakaReplyContext:
|
||||
"""在 replyer 内部构建表达习惯和黑话解释。"""
|
||||
session_id = self._resolve_session_id(stream_id)
|
||||
if not session_id:
|
||||
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
|
||||
return MaisakaReplyContext()
|
||||
|
||||
expression_habits, selected_expression_ids = self._build_expression_habits(
|
||||
session_id=session_id,
|
||||
chat_history=chat_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason,
|
||||
)
|
||||
return MaisakaReplyContext(
|
||||
expression_habits=expression_habits,
|
||||
selected_expression_ids=selected_expression_ids,
|
||||
)
|
||||
|
||||
def _build_expression_habits(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_history: List[LLMContextMessage],
|
||||
reply_message: Optional[SessionMessage],
|
||||
reply_reason: str,
|
||||
) -> tuple[str, List[int]]:
|
||||
"""查询并格式化适合当前会话的表达习惯。"""
|
||||
del chat_history
|
||||
del reply_message
|
||||
del reply_reason
|
||||
|
||||
expression_records = self._load_expression_records(session_id)
|
||||
if not expression_records:
|
||||
return "", []
|
||||
|
||||
lines: List[str] = []
|
||||
selected_ids: List[int] = []
|
||||
for expression in expression_records:
|
||||
if expression.expression_id is not None:
|
||||
selected_ids.append(expression.expression_id)
|
||||
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
|
||||
|
||||
block = "【表达习惯参考】\n" + "\n".join(lines)
|
||||
logger.info(
|
||||
f"已构建 Maisaka 表达习惯: 会话标识={session_id} "
|
||||
f"数量={len(selected_ids)} 表达编号={selected_ids!r}"
|
||||
)
|
||||
return block, selected_ids
|
||||
|
||||
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
|
||||
"""提取表达方式静态数据,避免 detached ORM 对象。"""
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
|
||||
if global_config.expression.expression_checked_only:
|
||||
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
|
||||
|
||||
query = query.where(
|
||||
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
|
||||
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
|
||||
|
||||
expressions = session.exec(query.limit(5)).all()
|
||||
return [
|
||||
_ExpressionRecord(
|
||||
expression_id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
)
|
||||
for expression in expressions
|
||||
]
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
extra_info: str = "",
|
||||
reply_reason: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
chosen_actions: Optional[List[object]] = None,
|
||||
from_plugin: bool = True,
|
||||
stream_id: Optional[str] = None,
|
||||
reply_message: Optional[SessionMessage] = None,
|
||||
reply_time_point: Optional[float] = None,
|
||||
think_level: int = 1,
|
||||
unknown_words: Optional[List[str]] = None,
|
||||
log_reply: bool = True,
|
||||
chat_history: Optional[List[LLMContextMessage]] = None,
|
||||
expression_habits: str = "",
|
||||
selected_expression_ids: Optional[List[int]] = None,
|
||||
) -> Tuple[bool, ReplyGenerationResult]:
|
||||
"""结合上下文生成 Maisaka 的最终可见回复。"""
|
||||
del available_actions
|
||||
del chosen_actions
|
||||
del extra_info
|
||||
del from_plugin
|
||||
del log_reply
|
||||
del reply_time_point
|
||||
del think_level
|
||||
del unknown_words
|
||||
|
||||
result = ReplyGenerationResult()
|
||||
if chat_history is None:
|
||||
result.error_message = "聊天历史为空"
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
|
||||
f"历史消息数={len(chat_history)} 目标消息编号="
|
||||
f"{reply_message.message_id if reply_message else None}"
|
||||
)
|
||||
|
||||
filtered_history = [
|
||||
message
|
||||
for message in chat_history
|
||||
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
|
||||
]
|
||||
|
||||
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
|
||||
|
||||
# Validate that express_model is properly initialized
|
||||
if self.express_model is None:
|
||||
logger.error("Maisaka 回复器的回复模型未初始化")
|
||||
result.error_message = "回复模型尚未初始化"
|
||||
return False, result
|
||||
|
||||
try:
|
||||
reply_context = await self._build_reply_context(
|
||||
chat_history=filtered_history,
|
||||
reply_message=reply_message,
|
||||
reply_reason=reply_reason or "",
|
||||
stream_id=stream_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建回复上下文失败: {exc}"
|
||||
return False, result
|
||||
|
||||
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
|
||||
result.selected_expression_ids = (
|
||||
list(selected_expression_ids)
|
||||
if selected_expression_ids is not None
|
||||
else list(reply_context.selected_expression_ids)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
|
||||
try:
|
||||
request_messages = self._build_request_messages(
|
||||
chat_history=filtered_history,
|
||||
reply_reason=reply_reason or "",
|
||||
expression_habits=merged_expression_habits,
|
||||
)
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
|
||||
result.error_message = f"构建提示词失败: {exc}"
|
||||
return False, result
|
||||
|
||||
prompt_preview = self._build_request_prompt_preview(request_messages)
|
||||
|
||||
def message_factory(_client: object) -> List[Message]:
|
||||
return request_messages
|
||||
|
||||
result.completion.request_prompt = prompt_preview
|
||||
|
||||
if global_config.debug.show_replyer_prompt:
|
||||
logger.info(f"\nMaisaka 回复器提示词:\n{prompt_preview}\n")
|
||||
|
||||
started_at = time.perf_counter()
|
||||
try:
|
||||
generation_result = await self.express_model.generate_response_with_messages(message_factory=message_factory)
|
||||
except Exception as exc:
|
||||
logger.exception("Maisaka 回复器调用失败")
|
||||
result.error_message = str(exc)
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
|
||||
)
|
||||
return False, result
|
||||
|
||||
response_text = (generation_result.response or "").strip()
|
||||
result.success = bool(response_text)
|
||||
result.completion = LLMCompletionResult(
|
||||
request_prompt=prompt_preview,
|
||||
response_text=response_text,
|
||||
reasoning_text=generation_result.reasoning or "",
|
||||
model_name=generation_result.model_name or "",
|
||||
tool_calls=generation_result.tool_calls or [],
|
||||
)
|
||||
result.metrics = GenerationMetrics(
|
||||
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
|
||||
)
|
||||
|
||||
if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text:
|
||||
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
|
||||
|
||||
if not result.success:
|
||||
result.error_message = "回复器返回了空内容"
|
||||
logger.warning("Maisaka 回复器返回了空内容")
|
||||
return False, result
|
||||
|
||||
logger.info(
|
||||
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
|
||||
f"总耗时毫秒={result.metrics.overall_ms} "
|
||||
f"已选表达编号={result.selected_expression_ids!r}"
|
||||
)
|
||||
result.text_fragments = [response_text]
|
||||
return True, result
|
||||
21
src/chat/replyer/maisaka_replyer_factory.py
Normal file
21
src/chat/replyer/maisaka_replyer_factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Type
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
def get_maisaka_replyer_class() -> Type[object]:
    """Resolve the Maisaka replyer class selected by configuration.

    Reads ``global_config.maisaka.replyer_generator_type``: the value
    ``"multi"`` selects the multi-generator implementation; any other
    value falls back to the default single generator. Imports are kept
    function-scoped so only the chosen module is loaded.

    Returns:
        The replyer class to instantiate (not an instance).
    """
    if global_config.maisaka.replyer_generator_type == "multi":
        from .maisaka_generator_multi import MaisakaReplyGenerator as chosen_class
    else:
        from .maisaka_generator import MaisakaReplyGenerator as chosen_class
    return chosen_class
|
||||
|
||||
|
||||
def get_maisaka_replyer_generator_type() -> str:
    """Return the Maisaka replyer generator type currently configured.

    Thin accessor over ``global_config.maisaka.replyer_generator_type``;
    used (among other things) to build cache keys that distinguish
    generator variants.
    """
    configured_type: str = global_config.maisaka.replyer_generator_type
    return configured_type
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +1,14 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.maisaka_replyer_factory import (
|
||||
get_maisaka_replyer_class,
|
||||
get_maisaka_replyer_generator_type,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
logger = get_logger("ReplyerManager")
|
||||
@@ -23,14 +26,15 @@ class ReplyerManager:
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
replyer_type: str = "default",
|
||||
) -> Optional["DefaultReplyer | MaisakaReplyGenerator | PrivateReplyer"]:
|
||||
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
|
||||
"""按会话和 replyer 类型获取实例。"""
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取 replyer")
|
||||
return None
|
||||
|
||||
cache_key = f"{replyer_type}:{stream_id}"
|
||||
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
|
||||
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
|
||||
if cache_key in self._repliers:
|
||||
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
|
||||
return self._repliers[cache_key]
|
||||
@@ -47,10 +51,10 @@ class ReplyerManager:
|
||||
|
||||
try:
|
||||
if replyer_type == "maisaka":
|
||||
logger.info("[ReplyerManager] importing MaisakaReplyGenerator")
|
||||
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
|
||||
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
|
||||
maisaka_replyer_class = get_maisaka_replyer_class()
|
||||
|
||||
replyer = MaisakaReplyGenerator(
|
||||
replyer = maisaka_replyer_class(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("maibot_statistic")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user