merge: 同步 upstream/r-dev 并解决冲突

This commit is contained in:
DawnARC
2026-04-03 19:56:45 +08:00
186 changed files with 14212 additions and 6705 deletions

View File

@@ -1,25 +1,28 @@
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from rich.traceback import install
from sqlmodel import select
from typing import Any, Dict, List, Optional, Tuple
import asyncio
import hashlib
import heapq
import Levenshtein
import random
import re
from src.common.logger import get_logger
from rich.traceback import install
from sqlmodel import select
import Levenshtein
from src.common.data_models.image_data_model import MaiEmoji
from src.common.database.database_model import Images, ImageType
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.utils.utils_image import ImageUtils
from src.prompt.prompt_manager import prompt_manager
from src.config.config import config_manager, global_config
from src.common.data_models.llm_service_data_models import LLMGenerationOptions, LLMImageOptions
from src.common.database.database import get_db_session, get_db_session_manual
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.config.config import config_manager, global_config
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from src.prompt.prompt_manager import prompt_manager
from src.services.llm_service import LLMServiceClient
logger = get_logger("emoji")
@@ -33,6 +36,171 @@ EMOJI_REGISTERED_DIR = DATA_DIR / "emoji_registered" # 已注册的表情包注
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册表情包系统内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
emoji_schema = {
"type": "object",
"description": "当前表情包的序列化信息,主要包含 file_hash、description、emotions 等字段。",
}
string_array_schema = {
"type": "array",
"items": {"type": "string"},
}
return registry.register_hook_specs(
[
HookSpec(
name="emoji.maisaka.before_select",
description="Maisaka 表情发送工具选择表情前触发,可改写情绪、上下文和采样参数,或中止本次选择。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=["stream_id", "requested_emotion", "reasoning", "context_texts", "sample_size"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.maisaka.after_select",
description="Maisaka 已选出表情后触发,可替换选中的表情哈希、补充匹配情绪,或中止发送。",
parameters_schema=build_object_schema(
{
"stream_id": {"type": "string", "description": "目标会话 ID。"},
"requested_emotion": {"type": "string", "description": "请求的目标情绪标签。"},
"reasoning": {"type": "string", "description": "本次发送表情的推理理由。"},
"context_texts": {
**string_array_schema,
"description": "最近聊天上下文文本列表。",
},
"sample_size": {"type": "integer", "description": "候选表情采样数量。"},
"selected_emoji": emoji_schema,
"selected_emoji_hash": {"type": "string", "description": "选中的表情哈希。"},
"matched_emotion": {"type": "string", "description": "最终命中的情绪标签。"},
"abort_message": {
"type": "string",
"description": "当 Hook 主动中止时可附带的失败提示。",
},
},
required=[
"stream_id",
"requested_emotion",
"reasoning",
"context_texts",
"sample_size",
"matched_emotion",
],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_description",
description="表情包描述生成并通过内容审查后触发,可改写描述文本或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前生成出的表情包描述。"},
"image_format": {"type": "string", "description": "表情图片格式。"},
},
required=["emoji", "description", "image_format"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="emoji.register.after_build_emotion",
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前表情包描述。"},
"emotions": {
**string_array_schema,
"description": "当前生成出的情绪标签列表。",
},
},
required=["emoji", "description", "emotions"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
]
)
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, Any]]:
"""将表情包对象序列化为 Hook 可传输载荷。
Args:
emoji: 待序列化的表情包对象。
Returns:
Optional[Dict[str, Any]]: 序列化后的字典;当表情为空时返回 ``None``。
"""
if emoji is None:
return None
return {
"file_hash": str(emoji.file_hash or "").strip(),
"file_name": emoji.file_name,
"full_path": str(emoji.full_path),
"description": emoji.description,
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
"query_count": int(emoji.query_count),
}
def _normalize_string_list(raw_values: Any) -> List[str]:
"""将任意列表值规范化为字符串列表。
Args:
raw_values: 待规范化的原始值。
Returns:
List[str]: 去空白后的字符串列表。
"""
if not isinstance(raw_values, list):
return []
return [str(item).strip() for item in raw_values if str(item).strip()]
def _ensure_directories() -> None:
"""确保表情包相关目录存在"""
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
@@ -642,6 +810,22 @@ class EmojiManager:
if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_description",
emoji=_serialize_emoji_for_hook(target_emoji),
description=description,
image_format=image_format,
)
if hook_result.aborted:
logger.info(f"[构建描述] 表情包描述被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
normalized_description = str(hook_result.kwargs.get("description", description) or "").strip()
if not normalized_description:
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
description = normalized_description
target_emoji.description = description
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
return True, target_emoji
@@ -687,6 +871,23 @@ class EmojiManager:
elif len(emotions) > 2:
emotions = random.sample(emotions, 2)
hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_emotion",
emoji=_serialize_emoji_for_hook(target_emoji),
description=target_emoji.description,
emotions=list(emotions),
)
if hook_result.aborted:
logger.info(f"[构建情感标签] 表情包情感标签被 Hook 中止注册: {target_emoji.file_name}")
return False, target_emoji
raw_emotions = hook_result.kwargs.get("emotions")
if raw_emotions is not None:
emotions = _normalize_string_list(raw_emotions)
if not emotions:
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
logger.info(f"[构建情感标签] 成功为表情包构建情感标签: {','.join(emotions)}")
target_emoji.emotion = emotions
return True, target_emoji

View File

@@ -0,0 +1,349 @@
"""Maisaka 表情工具内置能力。"""
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence
import random
from src.chat.message_receive.chat_manager import chat_manager
from src.cli.maisaka_cli_sender import CLI_PLATFORM_NAME, render_cli_message
from src.common.data_models.image_data_model import MaiEmoji
from src.common.data_models.llm_service_data_models import LLMGenerationOptions
from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils
from src.services import send_service
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
logger = get_logger("emoji_maisaka_tool")
@dataclass(slots=True)
class MaisakaEmojiSendResult:
"""Maisaka 表情发送结果。"""
success: bool
message: str
emoji_base64: str = ""
description: str = ""
emotions: list[str] = field(default_factory=list)
requested_emotion: str = ""
matched_emotion: str = ""
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
def _coerce_positive_int(value: Any, default: int) -> int:
"""将任意值安全转换为正整数。
Args:
value: 待转换的值。
default: 转换失败时使用的默认值。
Returns:
int: 规范化后的正整数。
"""
try:
normalized_value = int(value)
except (TypeError, ValueError):
return default
return normalized_value if normalized_value > 0 else default
def _normalize_context_texts(context_texts: Sequence[str] | None) -> list[str]:
"""清洗 Hook 和调用链传入的上下文文本列表。
Args:
context_texts: 原始上下文文本序列。
Returns:
list[str]: 过滤空白后的上下文文本列表。
"""
if not context_texts:
return []
return [str(item).strip() for item in context_texts if str(item).strip()]
def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
"""根据 Hook 返回值解析目标表情包对象。
Args:
raw_value: Hook 返回的 ``selected_emoji`` 或 ``selected_emoji_hash``。
Returns:
Optional[MaiEmoji]: 命中的表情包对象;未命中时返回 ``None``。
"""
raw_hash: str = ""
if isinstance(raw_value, dict):
raw_hash = str(raw_value.get("file_hash") or raw_value.get("hash") or "").strip()
elif isinstance(raw_value, str):
raw_hash = raw_value.strip()
if not raw_hash:
return None
for emoji in emoji_manager.emojis:
if emoji.file_hash == raw_hash:
return emoji
return None
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。"""
return [str(item).strip() for item in emoji.emotion if str(item).strip()]
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:
"""构建供情绪判断使用的最近上下文文本。"""
normalized_items = [str(item).strip() for item in context_texts if str(item).strip()]
if not normalized_items:
return ""
return "\n".join(normalized_items[-max_items:])
async def _select_emoji_with_llm(
*,
sampled_emojis: Sequence[MaiEmoji],
reasoning: str,
context_text: str,
) -> tuple[MaiEmoji, str]:
"""让模型在采样表情中选择更合适的情绪标签。"""
emotion_map: dict[str, list[MaiEmoji]] = {}
for emoji in sampled_emojis:
for emotion in _normalize_emotions(emoji):
emotion_map.setdefault(emotion, []).append(emoji)
available_emotions = list(emotion_map.keys())
if not available_emotions:
return random.choice(list(sampled_emojis)), ""
prompt = (
"你正在为聊天场景选择一个最合适的表情包情绪标签。\n"
f"发送原因:{reasoning or '辅助表达当前语气和情绪'}\n"
f"最近聊天记录:\n{context_text or '(暂无额外上下文)'}\n\n"
"可选情绪标签如下:\n"
f"{chr(10).join(available_emotions)}\n\n"
"请只返回一个最匹配的情绪标签,不要解释。"
)
try:
llm_result = await emoji_manager_emotion_judge_llm.generate_response(
prompt,
options=LLMGenerationOptions(temperature=0.3, max_tokens=60),
)
chosen_emotion = (llm_result.response or "").strip().strip("\"'")
except Exception as exc:
logger.warning(f"使用 LLM 选择表情情绪失败,将回退为随机选择: {exc}")
chosen_emotion = ""
if chosen_emotion and chosen_emotion in emotion_map:
return random.choice(emotion_map[chosen_emotion]), chosen_emotion
return random.choice(list(sampled_emojis)), ""
async def select_emoji_for_maisaka(
*,
requested_emotion: str = "",
reasoning: str = "",
context_texts: Sequence[str] | None = None,
sample_size: int = 30,
) -> tuple[MaiEmoji | None, str]:
"""为 Maisaka 选择一个合适的表情。"""
available_emojis = list(emoji_manager.emojis)
if not available_emojis:
return None, ""
normalized_requested_emotion = requested_emotion.strip()
if normalized_requested_emotion:
matched_emojis = [
emoji
for emoji in available_emojis
if normalized_requested_emotion.lower() in (emotion.lower() for emotion in _normalize_emotions(emoji))
]
if matched_emojis:
return random.choice(matched_emojis), normalized_requested_emotion
sampled_emojis = random.sample(
available_emojis,
min(max(sample_size, 1), len(available_emojis)),
)
context_text = _build_recent_context_text(context_texts or [])
return await _select_emoji_with_llm(
sampled_emojis=sampled_emojis,
reasoning=reasoning,
context_text=context_text,
)
async def send_emoji_for_maisaka(
*,
stream_id: str,
requested_emotion: str = "",
reasoning: str = "",
context_texts: Sequence[str] | None = None,
) -> MaisakaEmojiSendResult:
"""为 Maisaka 选择并发送一个表情。"""
normalized_requested_emotion = requested_emotion.strip()
normalized_reasoning = reasoning.strip()
normalized_context_texts = _normalize_context_texts(context_texts)
sample_size = 30
before_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.before_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
abort_message="表情选择已被 Hook 中止。",
)
if before_select_result.aborted:
abort_message = str(before_select_result.kwargs.get("abort_message") or "表情选择已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情选择已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
)
before_select_kwargs = before_select_result.kwargs
normalized_requested_emotion = str(
before_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
normalized_reasoning = str(before_select_kwargs.get("reasoning", normalized_reasoning) or "").strip()
if isinstance(before_select_kwargs.get("context_texts"), list):
normalized_context_texts = _normalize_context_texts(before_select_kwargs.get("context_texts"))
sample_size = _coerce_positive_int(before_select_kwargs.get("sample_size"), sample_size)
selected_emoji, matched_emotion = await select_emoji_for_maisaka(
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=normalized_context_texts,
sample_size=sample_size,
)
after_select_result = await _get_runtime_manager().invoke_hook(
"emoji.maisaka.after_select",
stream_id=stream_id,
requested_emotion=normalized_requested_emotion,
reasoning=normalized_reasoning,
context_texts=list(normalized_context_texts),
sample_size=sample_size,
selected_emoji=_serialize_emoji_for_hook(selected_emoji),
selected_emoji_hash=str(selected_emoji.file_hash or "").strip() if selected_emoji is not None else "",
matched_emotion=matched_emotion,
abort_message="表情发送已被 Hook 中止。",
)
if after_select_result.aborted:
abort_message = str(after_select_result.kwargs.get("abort_message") or "表情发送已被 Hook 中止。").strip()
return MaisakaEmojiSendResult(
success=False,
message=abort_message or "表情发送已被 Hook 中止。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
after_select_kwargs = after_select_result.kwargs
normalized_requested_emotion = str(
after_select_kwargs.get("requested_emotion", normalized_requested_emotion) or ""
).strip()
matched_emotion = str(after_select_kwargs.get("matched_emotion", matched_emotion) or "").strip()
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji_hash"))
if override_emoji is None:
override_emoji = _resolve_selected_emoji(after_select_kwargs.get("selected_emoji"))
if override_emoji is not None:
selected_emoji = override_emoji
if selected_emoji is None:
return MaisakaEmojiSendResult(
success=False,
message="当前表情包库中没有可用表情。",
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
raise ValueError("表情图片转换为 base64 失败")
except Exception as exc:
return MaisakaEmojiSendResult(
success=False,
message=f"发送表情包失败:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
try:
target_session = chat_manager.get_session_by_session_id(stream_id)
if target_session is not None and target_session.platform == CLI_PLATFORM_NAME:
preview_message = (
f"已发送表情包:{selected_emoji.description.strip()}"
if selected_emoji.description.strip()
else "[表情包]"
)
render_cli_message(preview_message)
sent = True
else:
sent = await send_service.emoji_to_stream(
emoji_base64=emoji_base64,
stream_id=stream_id,
storage_message=True,
set_reply=False,
reply_message=None,
)
except Exception as exc:
return MaisakaEmojiSendResult(
success=False,
message=f"发送表情包时发生异常:{exc}",
description=selected_emoji.description.strip(),
emotions=_normalize_emotions(selected_emoji),
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
description = selected_emoji.description.strip()
emotions = _normalize_emotions(selected_emoji)
if not sent:
return MaisakaEmojiSendResult(
success=False,
message="发送表情包失败。",
description=description,
emotions=emotions,
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)
emoji_manager.update_emoji_usage(selected_emoji)
success_message = (
f"已发送表情包:{description}(情绪:{', '.join(emotions)}"
if emotions
else f"已发送表情包:{description}"
)
return MaisakaEmojiSendResult(
success=True,
message=success_message,
emoji_base64=emoji_base64,
description=description,
emotions=emotions,
requested_emotion=normalized_requested_emotion,
matched_emotion=matched_emotion,
)

View File

@@ -44,6 +44,12 @@ class ImageManager:
logger.info("图片管理器初始化完成")
def _get_image_record(self, image_hash: str) -> Optional[Images]:
"""根据哈希获取图片记录。"""
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
return session.exec(statement).first()
async def get_image_description(
self,
*,
@@ -76,9 +82,8 @@ class ImageManager:
hash_str = hashlib.sha256(image_bytes).hexdigest()
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
if record := session.exec(statement).first():
if record := self._get_image_record(hash_str):
if record.vlm_processed and record.description:
return record.description
except Exception as e:
logger.error(f"查询图片描述时发生错误: {e}")
@@ -86,12 +91,17 @@ class ImageManager:
if not image_bytes:
logger.warning("图片哈希值未找到,且未提供图片字节数据,返回无描述")
return ""
try:
await self.ensure_image_saved(image_bytes)
except Exception as e:
logger.error(f"保存图片文件时发生错误: {e}")
return ""
if not wait_for_build:
self._schedule_description_build(hash_str, image_bytes)
return ""
logger.info(f"图片描述未找到,哈希值: {hash_str},准备生成新描述")
try:
image = await self.save_image_and_process(image_bytes)
image = await self.build_image_description(image_bytes)
return image.description
except Exception as e:
logger.error(f"生成图片描述时发生错误: {e}")
@@ -120,7 +130,7 @@ class ImageManager:
"""
try:
logger.info(f"图片描述后台构建已开始,哈希值: {image_hash}")
await self.save_image_and_process(image_bytes)
await self.build_image_description(image_bytes)
logger.info(f"图片描述后台构建完成,哈希值: {image_hash}")
except Exception as exc:
logger.warning(f"图片描述后台构建失败,哈希值: {image_hash},错误: {exc}")
@@ -201,6 +211,7 @@ class ImageManager:
return False
record.description = image.description
record.last_used_time = datetime.now()
record.vlm_processed = image.vlm_processed
session.add(record)
logger.info(f"成功更新图片描述: {image.file_hash},新描述: {image.description}")
except Exception as e:
@@ -239,22 +250,13 @@ class ImageManager:
return False
return True
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
"""
保存图片并生成描述
Args:
image_bytes (bytes): 图片的字节数据
Returns:
return (MaiImage): 包含图片信息的 MaiImage 对象
Raises:
Exception: 如果在保存或处理过程中发生错误
"""
async def ensure_image_saved(self, image_bytes: bytes) -> MaiImage:
"""先保存图片记录,确保后续可以按哈希回填图片内容。"""
hash_str = hashlib.sha256(image_bytes).hexdigest()
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=hash_str).limit(1)
statement = select(Images).filter_by(image_hash=hash_str, image_type=ImageType.IMAGE).limit(1)
if record := session.exec(statement).first():
logger.info(f"图片已存在于数据库中,哈希值: {hash_str}")
record.last_used_time = datetime.now()
@@ -270,18 +272,38 @@ class ImageManager:
tmp_file_path = IMAGE_DIR / f"{hash_str}.tmp"
with tmp_file_path.open("wb") as f:
f.write(image_bytes)
mai_image = MaiImage(full_path=(IMAGE_DIR / f"{hash_str}.tmp"), image_bytes=image_bytes)
mai_image = MaiImage(full_path=tmp_file_path, image_bytes=image_bytes)
await mai_image.calculate_hash_format()
if not self.register_image_to_db(mai_image):
raise RuntimeError(f"保存图片记录到数据库失败: {hash_str}")
return mai_image
async def build_image_description(self, image_bytes: bytes) -> MaiImage:
"""在图片已保存的前提下生成或补齐图片描述。"""
mai_image = await self.ensure_image_saved(image_bytes)
if mai_image.vlm_processed and mai_image.description:
return mai_image
desc = await self._generate_image_description(image_bytes, mai_image.image_format)
mai_image.description = desc
mai_image.vlm_processed = True
try:
self.register_image_to_db(mai_image)
except Exception as e:
logger.error(f"保存新图片记录到数据库时发生错误: {e}")
raise e
if not self.update_image_description(mai_image):
raise RuntimeError(f"更新图片描述失败: {mai_image.file_hash}")
return mai_image
async def save_image_and_process(self, image_bytes: bytes) -> MaiImage:
"""
保存图片并生成描述
Args:
image_bytes (bytes): 图片的字节数据
Returns:
return (MaiImage): 包含图片信息的 MaiImage 对象
Raises:
Exception: 如果在保存或处理过程中发生错误
"""
return await self.build_image_description(image_bytes)
def cleanup_invalid_descriptions_in_db(self):
"""
清理数据库中无效的图片记录

View File

@@ -1,6 +1,7 @@
"""聊天消息入口与主链路调度。"""
from contextlib import suppress
from copy import deepcopy
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import os
import traceback
@@ -13,12 +14,15 @@ from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.platform_io.route_key_factory import RouteKeyFactory
from src.core.announcement_manager import global_announcement_manager
from src.plugin_runtime.component_query import component_query_service
from src.plugin_runtime.hook_payloads import deserialize_session_message, serialize_session_message
from src.plugin_runtime.hook_schema_utils import build_object_schema
from src.plugin_runtime.host.hook_dispatcher import HookDispatchResult
from src.plugin_runtime.host.hook_spec_registry import HookSpec, HookSpecRegistry
from .message import SessionMessage
from .chat_manager import chat_manager
from .message import SessionMessage
# 定义日志配置
@@ -29,7 +33,137 @@ PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..
logger = get_logger("chat")
def register_chat_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
"""注册聊天消息主链内置 Hook 规格。
Args:
registry: 目标 Hook 规格注册中心。
Returns:
List[HookSpec]: 实际注册的 Hook 规格列表。
"""
return registry.register_hook_specs(
[
HookSpec(
name="chat.receive.before_process",
description="在入站消息执行 `SessionMessage.process()` 之前触发,可拦截或改写消息。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前入站消息的序列化 SessionMessage。",
},
},
required=["message"],
),
default_timeout_ms=8000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.receive.after_process",
description="在入站消息完成轻量预处理后触发,可改写文本、消息体或中止后续链路。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "已完成 `process()` 的序列化 SessionMessage。",
},
},
required=["message"],
),
default_timeout_ms=8000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.command.before_execute",
description="在命令匹配成功、实际执行前触发,可拦截命令或改写命令上下文。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前命令消息的序列化 SessionMessage。",
},
"command_name": {
"type": "string",
"description": "命中的命令名称。",
},
"plugin_id": {
"type": "string",
"description": "命令所属插件 ID。",
},
"matched_groups": {
"type": "object",
"description": "命令正则命名捕获结果。",
},
},
required=["message", "command_name", "plugin_id", "matched_groups"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
HookSpec(
name="chat.command.after_execute",
description="在命令执行结束后触发,可调整返回文本和是否继续主链处理。",
parameters_schema=build_object_schema(
{
"message": {
"type": "object",
"description": "当前命令消息的序列化 SessionMessage。",
},
"command_name": {
"type": "string",
"description": "命令名称。",
},
"plugin_id": {
"type": "string",
"description": "命令所属插件 ID。",
},
"matched_groups": {
"type": "object",
"description": "命令正则命名捕获结果。",
},
"success": {
"type": "boolean",
"description": "命令执行是否成功。",
},
"response": {
"type": "string",
"description": "命令返回文本。",
},
"intercept_message_level": {
"type": "integer",
"description": "命令拦截等级。",
},
"continue_process": {
"type": "boolean",
"description": "命令执行后是否继续后续消息处理。",
},
},
required=[
"message",
"command_name",
"plugin_id",
"matched_groups",
"success",
"intercept_message_level",
"continue_process",
],
),
default_timeout_ms=5000,
allow_abort=False,
allow_kwargs_mutation=True,
),
]
)
class ChatBot:
"""聊天机器人入口协调器。"""
def __init__(self) -> None:
"""初始化聊天机器人入口。"""
@@ -44,6 +178,66 @@ class ChatBot:
self._started = True
@staticmethod
def _get_runtime_manager() -> Any:
"""获取插件运行时管理器。
Returns:
Any: 插件运行时管理器单例。
"""
from src.plugin_runtime.integration import get_plugin_runtime_manager
return get_plugin_runtime_manager()
@staticmethod
def _coerce_int(value: Any, default: int) -> int:
"""将任意值安全转换为整数。
Args:
value: 待转换的值。
default: 转换失败时的默认值。
Returns:
int: 转换后的整数结果。
"""
try:
return int(value)
except (TypeError, ValueError):
return default
async def _invoke_message_hook(
self,
hook_name: str,
message: SessionMessage,
**kwargs: Any,
) -> tuple[HookDispatchResult, SessionMessage]:
"""触发携带会话消息的命名 Hook。
Args:
hook_name: 目标 Hook 名称。
message: 当前会话消息。
**kwargs: 需要附带传递的额外参数。
Returns:
tuple[HookDispatchResult, SessionMessage]: Hook 聚合结果以及可能被改写后的消息对象。
"""
hook_result = await self._get_runtime_manager().invoke_hook(
hook_name,
message=serialize_session_message(message),
**kwargs,
)
mutated_message = message
raw_message = hook_result.kwargs.get("message")
if raw_message is not None:
try:
mutated_message = deserialize_session_message(raw_message)
except Exception as exc:
logger.warning(f"Hook {hook_name} 返回的 message 无法反序列化,已忽略: {exc}")
return hook_result, mutated_message
async def _process_commands(self, message: SessionMessage) -> tuple[bool, Optional[str], bool]:
"""使用统一组件注册表处理命令。
@@ -71,6 +265,25 @@ class ChatBot:
return False, None, True
message.is_command = True
before_result, message = await self._invoke_message_hook(
"chat.command.before_execute",
message,
command_name=command_name,
plugin_id=plugin_name,
matched_groups=dict(matched_groups),
)
if before_result.aborted:
logger.info(f"命令 {command_name} 被 Hook 中止,跳过命令执行")
return True, None, False
hook_kwargs = before_result.kwargs
command_name = str(hook_kwargs.get("command_name", command_name) or command_name)
plugin_name = str(hook_kwargs.get("plugin_id", plugin_name) or plugin_name)
matched_groups = (
dict(hook_kwargs["matched_groups"])
if isinstance(hook_kwargs.get("matched_groups"), dict)
else dict(matched_groups)
)
# 获取插件配置
plugin_config = component_query_service.get_plugin_config(plugin_name)
@@ -82,27 +295,43 @@ class ChatBot:
plugin_config=plugin_config,
matched_groups=matched_groups,
)
self._mark_command_message(message, intercept_message_level)
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_name} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_name} - {e}")
continue_process = not bool(intercept_message_level)
except Exception as exc:
logger.error(f"执行命令时出错: {command_name} - {exc}")
logger.error(traceback.format_exc())
success = False
response = str(exc)
intercept_message_level = 1
continue_process = False
# 命令出错时,根据命令的拦截设置决定是否继续处理消息
return True, str(e), False # 出错时继续处理消息
after_result, message = await self._invoke_message_hook(
"chat.command.after_execute",
message,
command_name=command_name,
plugin_id=plugin_name,
matched_groups=dict(matched_groups),
success=success,
response=response,
intercept_message_level=intercept_message_level,
continue_process=continue_process,
)
after_kwargs = after_result.kwargs
success = bool(after_kwargs.get("success", success))
raw_response = after_kwargs.get("response", response)
response = None if raw_response is None else str(raw_response)
intercept_message_level = self._coerce_int(
after_kwargs.get("intercept_message_level", intercept_message_level),
intercept_message_level,
)
continue_process = bool(after_kwargs.get("continue_process", continue_process))
self._mark_command_message(message, intercept_message_level)
if success:
logger.info(f"命令执行成功: {command_name} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_name} - {response}")
return True, response, continue_process
return False, None, True
@@ -138,6 +367,17 @@ class ChatBot:
cmd_result: Optional[str],
continue_process: bool,
) -> bool:
"""处理命令链结果并决定是否终止主消息链。
Args:
message: 当前命令消息。
cmd_result: 命令响应文本。
continue_process: 是否继续后续主链处理。
Returns:
bool: ``True`` 表示已经终止后续主链。
"""
if continue_process:
return False
@@ -145,9 +385,18 @@ class ChatBot:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return True
async def handle_notice_message(self, message: SessionMessage):
async def handle_notice_message(self, message: SessionMessage) -> bool:
"""处理通知类消息。
Args:
message: 当前通知消息。
Returns:
bool: 当前消息是否为通知消息。
"""
if message.message_id != "notice":
return
return False
message.is_notify = True
logger.debug("notice消息")
@@ -203,9 +452,12 @@ class ChatBot:
return True
async def echo_message_process(self, raw_data: Dict[str, Any]) -> None:
"""处理消息回送 ID 对应关系。
Args:
raw_data: 平台适配器上报的原始回送载荷。
"""
用于专门处理回送消息ID的函数
"""
message_data: Dict[str, Any] = raw_data.get("content", {})
if not message_data:
return
@@ -218,18 +470,10 @@ class ChatBot:
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
heart_flow模式使用思维流系统进行回复
- 包含思维流状态管理
- 在回复前进行观察和状态更新
- 回复后更新思维流状态
- 消息过滤
- 记忆激活
- 意愿计算
- 消息生成和发送
- 表情包处理
- 性能计时
"""处理统一格式的入站消息字典。
Args:
message_data: 适配器整理后的统一消息字典。
"""
try:
# 确保所有任务已启动
@@ -253,7 +497,13 @@ class ChatBot:
logger.error(f"预处理消息失败: {e}")
traceback.print_exc()
async def receive_message(self, message: SessionMessage):
async def receive_message(self, message: SessionMessage) -> None:
"""处理单条入站会话消息。
Args:
message: 待处理的会话消息。
"""
try:
group_info = message.message_info.group_info
user_info = message.message_info.user_info
@@ -272,6 +522,19 @@ class ChatBot:
)
message.session_id = session_id # 正确初始化session_id
before_process_result, message = await self._invoke_message_hook(
"chat.receive.before_process",
message,
)
if before_process_result.aborted:
logger.info(f"消息 {message.message_id} 在预处理前被 Hook 中止")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
additional_config = message.message_info.additional_config
if isinstance(additional_config, dict):
account_id, scope = RouteKeyFactory.extract_components(additional_config)
# TODO: 修复事件预处理部分
# continue_flag, modified_message = await events_manager.handle_mai_events(
@@ -286,14 +549,24 @@ class ChatBot:
# if await self.handle_notice_message(message):
# pass
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
if global_config.maisaka.direct_image_input:
message.maisaka_original_raw_message = deepcopy(message.raw_message) # type: ignore[attr-defined]
# 处理消息内容,识别表情包等二进制数据并转化为文本描述
# 如果 Maisaka 需要直接消费图片,会在后续构建 prompt 时按需回填图片二进制数据,
# 这里不再复制整条原始消息。
# 入站主链优先保证消息尽快入队,避免图片、表情包、语音分析阻塞适配器超时。
await message.process(
enable_heavy_media_analysis=False,
enable_voice_transcription=False,
)
after_process_result, message = await self._invoke_message_hook(
"chat.receive.after_process",
message,
)
if after_process_result.aborted:
logger.info(f"消息 {message.message_id} 在预处理后被 Hook 中止")
return
group_info = message.message_info.group_info
user_info = message.message_info.user_info
# 平台层的 @ 检测由底层 is_mentioned_bot_in_message 统一处理;此处不做用户名硬编码匹配

View File

@@ -1,11 +1,10 @@
import asyncio
from asyncio import Task
from typing import Dict, List, Sequence, Tuple
from rich.traceback import install
from sqlmodel import select
import asyncio
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
@@ -36,6 +35,102 @@ class MsgIDMapping:
class SessionMessage(MaiMessage):
#便于调试的打印函数
def __str__(self) -> str:
"""返回适合日志输出的消息摘要。"""
return self.to_debug_string()
def __repr__(self) -> str:
"""返回适合调试场景的消息摘要。"""
return self.to_debug_string()
def to_debug_string(self) -> str:
"""构建包含引用信息的调试字符串。
Returns:
str: 适合记录日志的消息摘要。
"""
user_info = self.message_info.user_info
group_info = self.message_info.group_info
chat_type = "group" if group_info else "private"
group_id = group_info.group_id if group_info else None
group_name = group_info.group_name if group_info else None
component_summaries = [self._summarize_component(component) for component in self.raw_message.components]
raw_components = ", ".join(component_summaries) if component_summaries else "empty"
return (
"SessionMessage("
f"message_id={self.message_id!r}, "
f"platform={self.platform!r}, "
f"chat_type={chat_type!r}, "
f"group_id={group_id!r}, "
f"group_name={group_name!r}, "
f"user_id={user_info.user_id!r}, "
f"user_nickname={user_info.user_nickname!r}, "
f"user_cardname={user_info.user_cardname!r}, "
f"reply_to={self.reply_to!r}, "
f"processed_plain_text={self._truncate_text(self.processed_plain_text)}, "
f"raw_components=[{raw_components}]"
")"
)
@staticmethod
def _truncate_text(text: str | None, max_length: int = 120) -> str:
"""截断较长文本,避免日志过长。
Args:
text: 原始文本。
max_length: 最大保留长度。
Returns:
str: 截断后的文本表示。
"""
if text is None:
return "None"
normalized_text = text.replace("\r", "\\r").replace("\n", "\\n")
if len(normalized_text) <= max_length:
return repr(normalized_text)
return repr(f"{normalized_text[:max_length]}...")
def _summarize_component(self, component: StandardMessageComponents) -> str:
"""生成单个消息组件的调试摘要。
Args:
component: 消息组件对象。
Returns:
str: 组件摘要文本。
"""
if isinstance(component, TextComponent):
return f"Text(text={self._truncate_text(component.text, 80)})"
if isinstance(component, ImageComponent):
return f"Image(content={self._truncate_text(component.content or None, 60)})"
if isinstance(component, EmojiComponent):
return f"Emoji(content={self._truncate_text(component.content or None, 60)})"
if isinstance(component, AtComponent):
target_name = component.target_user_cardname or component.target_user_nickname or component.target_user_id
return f"At(target={target_name!r})"
if isinstance(component, VoiceComponent):
return f"Voice(content={self._truncate_text(component.content or None, 60)})"
if isinstance(component, ReplyComponent):
sender_name = (
component.target_message_sender_cardname
or component.target_message_sender_nickname
or component.target_message_sender_id
)
return (
"Reply("
f"target_message_id={component.target_message_id!r}, "
f"target_sender={sender_name!r}, "
f"target_content={self._truncate_text(component.target_message_content, 80)}"
")"
)
if isinstance(component, ForwardNodeComponent):
return f"ForwardNode(count={len(component.forward_components)})"
return f"{component.__class__.__name__}"
#便于调试的打印函数end
async def process(
self,
*,

View File

@@ -18,28 +18,29 @@ install(extra_lines=3)
logger = get_logger("sender")
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str]]] = None
_webui_chat_broadcaster: Optional[Tuple[Any, Optional[str], Optional[str]]] = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str]]:
def get_webui_chat_broadcaster() -> Tuple[Any, Optional[str], Optional[str]]:
"""获取 WebUI 聊天室广播器。
Returns:
Tuple[Any, Optional[str]]: ``(chat_manager, platform_name)`` 元组;
Tuple[Any, Optional[str], Optional[str]]: ``(chat_manager, platform_name, default_group_id)`` 元组;
若 WebUI 相关模块不可用,则元素会退化为 ``None``。
"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.routers.chat import WEBUI_CHAT_PLATFORM, chat_manager
from src.webui.routers.chat.service import WEBUI_CHAT_GROUP_ID
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM, WEBUI_CHAT_GROUP_ID)
except ImportError:
_webui_chat_broadcaster = (None, None)
_webui_chat_broadcaster = (None, None, None)
return _webui_chat_broadcaster
@@ -76,7 +77,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
chat_manager, webui_platform, default_group_id = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
@@ -97,8 +98,9 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
message_type = "rich"
segments = message_segments
await chat_manager.broadcast(
{
await chat_manager.broadcast_to_group(
group_id=group_id or default_group_id or "",
message={
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": message_type,
@@ -110,7 +112,7 @@ async def _send_message(message: SessionMessage, show_log: bool = True) -> bool:
"avatar": None,
"is_bot": True,
},
}
},
)
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库

View File

@@ -35,7 +35,7 @@ from src.services import llm_service as llm_api
from src.chat.logger.plan_reply_logger import PlanReplyLogger
from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt
from src.learners.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon
from src.learners.jargon_explainer_old import explain_jargon_in_context
from src.chat.utils.common_utils import TempMethodsExpression
init_memory_retrieval_sys()
@@ -688,39 +688,41 @@ class DefaultReplyer:
return None
def get_chat_prompt_for_chat(self, chat_id: str) -> str:
"""
根据聊天流ID获取匹配的额外prompt仅匹配group类型
Args:
chat_id: 聊天流ID哈希值
Returns:
str: 匹配的额外prompt内容如果没有匹配则返回空字符串
"""
if not global_config.experimental.chat_prompts:
"""根据聊天流 ID 获取匹配的额外 prompt。"""
if not global_config.chat.chat_prompts:
return ""
for chat_prompt_str in global_config.experimental.chat_prompts:
if not isinstance(chat_prompt_str, str):
for chat_prompt_item in global_config.chat.chat_prompts:
if hasattr(chat_prompt_item, "rule_type") and hasattr(chat_prompt_item, "prompt"):
if str(chat_prompt_item.rule_type or "").strip() != "group":
continue
config_chat_id = self._build_chat_uid(
str(chat_prompt_item.platform or "").strip(),
str(chat_prompt_item.item_id or "").strip(),
True,
)
prompt_content = str(chat_prompt_item.prompt or "").strip()
if config_chat_id == chat_id and prompt_content:
logger.debug(f"匹配到群聊 prompt 配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
return prompt_content
continue
# 解析配置字符串检查类型是否为group
parts = chat_prompt_str.split(":", 3)
if len(parts) != 4:
if not isinstance(chat_prompt_item, str):
continue
stream_type = parts[2]
# 只匹配group类型
if stream_type != "group":
# 兼容旧格式的 platform:id:type:prompt 配置字符串。
parts = chat_prompt_item.split(":", 3)
if len(parts) != 4 or parts[2] != "group":
continue
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_str)
result = self._parse_chat_prompt_config_to_chat_id(chat_prompt_item)
if result is None:
continue
config_chat_id, prompt_content = result
if config_chat_id == chat_id:
logger.debug(f"匹配到群聊prompt配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
logger.debug(f"匹配到群聊 prompt 配置chat_id: {chat_id}, prompt: {prompt_content[:50]}...")
return prompt_content
return ""

View File

@@ -0,0 +1,453 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import random
import time
from sqlmodel import select
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.data_models.reply_generation_data_models import (
GenerationMetrics,
LLMCompletionResult,
ReplyGenerationResult,
)
from src.common.logger import get_logger
from src.common.prompt_i18n import load_prompt
from src.config.config import global_config
from src.core.types import ActionInfo
from src.services.llm_service import LLMServiceClient
from src.chat.message_receive.message import SessionMessage
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.maisaka.context_messages import AssistantMessage, LLMContextMessage, ReferenceMessage, SessionBackedMessage, ToolResultMessage
from src.maisaka.message_adapter import parse_speaker_content
logger = get_logger("replyer")
@dataclass
class MaisakaReplyContext:
"""Maisaka replyer 使用的回复上下文。"""
expression_habits: str = ""
selected_expression_ids: List[int] = field(default_factory=list)
@dataclass
class _ExpressionRecord:
"""表达方式的轻量记录。"""
expression_id: Optional[int]
situation: str
style: str
class MaisakaReplyGenerator:
"""生成 Maisaka 的最终可见回复。"""
def __init__(
self,
chat_stream: Optional[BotChatSession] = None,
request_type: str = "maisaka_replyer",
) -> None:
self.chat_stream = chat_stream
self.request_type = request_type
self.express_model = LLMServiceClient(
task_name="replyer",
request_type=request_type,
)
self._personality_prompt = self._build_personality_prompt()
def _build_personality_prompt(self) -> str:
"""构建 replyer 使用的人设描述。"""
try:
bot_name = global_config.bot.nickname
alias_names = global_config.bot.alias_names
bot_aliases = f",也有人叫你{','.join(alias_names)}" if alias_names else ""
prompt_personality = global_config.personality.personality
if (
hasattr(global_config.personality, "states")
and global_config.personality.states
and hasattr(global_config.personality, "state_probability")
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
return f"你的名字是{bot_name}{bot_aliases},你{prompt_personality};"
except Exception as exc:
logger.warning(f"构建 Maisaka 人设提示词失败: {exc}")
return "你的名字是麦麦,你是一个活泼可爱的 AI 助手。"
@staticmethod
def _normalize_content(content: str, limit: int = 500) -> str:
normalized = " ".join((content or "").split())
if len(normalized) > limit:
return normalized[:limit] + "..."
return normalized
@staticmethod
def _extract_visible_assistant_reply(message: AssistantMessage) -> str:
del message
return ""
def _extract_guided_bot_reply(self, message: SessionBackedMessage) -> str:
speaker_name, body = parse_speaker_content(message.processed_plain_text.strip())
bot_nickname = global_config.bot.nickname.strip() or "Bot"
if speaker_name == bot_nickname:
return self._normalize_content(body.strip())
return ""
@staticmethod
def _split_user_message_segments(raw_content: str) -> List[tuple[Optional[str], str]]:
"""按说话人拆分用户消息。"""
segments: List[tuple[Optional[str], str]] = []
current_speaker: Optional[str] = None
current_lines: List[str] = []
for raw_line in raw_content.splitlines():
speaker_name, content_body = parse_speaker_content(raw_line)
if speaker_name is not None:
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
current_speaker = speaker_name
current_lines = [content_body]
continue
current_lines.append(raw_line)
if current_lines:
segments.append((current_speaker, "\n".join(current_lines)))
return segments
def _build_system_prompt(
self,
reply_reason: str,
expression_habits: str = "",
) -> str:
"""构建 Maisaka replyer 使用的系统提示词。"""
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
system_prompt = load_prompt(
"maisaka_replyer",
bot_name=global_config.bot.nickname,
time_block=f"当前时间:{current_time}",
identity=self._personality_prompt,
reply_style=global_config.personality.reply_style,
)
except Exception:
system_prompt = "你是一个友好的 AI 助手,请根据聊天记录自然回复。"
extra_sections: List[str] = []
if expression_habits.strip():
extra_sections.append(expression_habits.strip())
if reply_reason.strip():
extra_sections.append(f"【回复信息参考】\n{reply_reason}")
if not extra_sections:
return system_prompt
return f"{system_prompt}\n\n" + "\n\n".join(extra_sections)
def _build_reply_instruction(self) -> str:
"""构建追加在上下文末尾的回复指令。"""
return "请基于以上逐条对话消息,自然地继续回复。直接输出你要说的话,不要额外解释。"
def _build_history_messages(self, chat_history: List[LLMContextMessage]) -> List[Message]:
"""将 replyer 上下文拆成多条 LLM 消息。"""
bot_nickname = global_config.bot.nickname.strip() or "Bot"
default_user_name = global_config.maisaka.user_name.strip() or "User"
messages: List[Message] = []
for message in chat_history:
if isinstance(message, (ReferenceMessage, ToolResultMessage)):
continue
if isinstance(message, SessionBackedMessage):
guided_reply = self._extract_guided_bot_reply(message)
if guided_reply:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(guided_reply).build()
)
continue
for speaker_name, content_body in self._split_user_message_segments(message.processed_plain_text):
content = self._normalize_content(content_body)
if not content:
continue
visible_speaker = speaker_name or default_user_name
if visible_speaker == bot_nickname:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(content).build()
)
continue
user_content = f"[{visible_speaker}]{content}"
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(user_content).build())
continue
if isinstance(message, AssistantMessage):
visible_reply = self._extract_visible_assistant_reply(message)
if visible_reply:
messages.append(
MessageBuilder().set_role(RoleType.Assistant).add_text_content(visible_reply).build()
)
return messages
def _build_request_messages(
self,
chat_history: List[LLMContextMessage],
reply_reason: str,
expression_habits: str = "",
) -> List[Message]:
"""构建发给大模型的消息列表。"""
messages: List[Message] = []
system_prompt = self._build_system_prompt(
reply_reason=reply_reason,
expression_habits=expression_habits,
)
instruction = self._build_reply_instruction()
messages.append(MessageBuilder().set_role(RoleType.System).add_text_content(system_prompt).build())
messages.extend(self._build_history_messages(chat_history))
messages.append(MessageBuilder().set_role(RoleType.User).add_text_content(instruction).build())
return messages
@staticmethod
def _build_request_prompt_preview(messages: List[Message]) -> str:
"""将消息列表转为便于调试的文本预览。"""
preview_lines: List[str] = []
for message in messages:
role_name = message.role.value.capitalize()
preview_lines.append(f"{role_name}: {message.get_text_content()}")
return "\n\n".join(preview_lines)
def _resolve_session_id(self, stream_id: Optional[str]) -> str:
"""解析当前回复使用的会话 ID。"""
if stream_id:
return stream_id
if self.chat_stream is not None:
return self.chat_stream.session_id
return ""
async def _build_reply_context(
self,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
stream_id: Optional[str],
) -> MaisakaReplyContext:
"""在 replyer 内部构建表达习惯和黑话解释。"""
session_id = self._resolve_session_id(stream_id)
if not session_id:
logger.warning("构建 Maisaka 回复上下文失败:缺少会话标识")
return MaisakaReplyContext()
expression_habits, selected_expression_ids = self._build_expression_habits(
session_id=session_id,
chat_history=chat_history,
reply_message=reply_message,
reply_reason=reply_reason,
)
return MaisakaReplyContext(
expression_habits=expression_habits,
selected_expression_ids=selected_expression_ids,
)
def _build_expression_habits(
self,
session_id: str,
chat_history: List[LLMContextMessage],
reply_message: Optional[SessionMessage],
reply_reason: str,
) -> tuple[str, List[int]]:
"""查询并格式化适合当前会话的表达习惯。"""
del chat_history
del reply_message
del reply_reason
expression_records = self._load_expression_records(session_id)
if not expression_records:
return "", []
lines: List[str] = []
selected_ids: List[int] = []
for expression in expression_records:
if expression.expression_id is not None:
selected_ids.append(expression.expression_id)
lines.append(f"- 当{expression.situation}时,可以自然地用{expression.style}这种表达习惯。")
block = "【表达习惯参考】\n" + "\n".join(lines)
logger.info(
f"已构建 Maisaka 表达习惯: 会话标识={session_id} "
f"数量={len(selected_ids)} 表达编号={selected_ids!r}"
)
return block, selected_ids
def _load_expression_records(self, session_id: str) -> List[_ExpressionRecord]:
"""提取表达方式静态数据,避免 detached ORM 对象。"""
with get_db_session(auto_commit=False) as session:
query = select(Expression).where(Expression.rejected.is_(False)) # type: ignore[attr-defined]
if global_config.expression.expression_checked_only:
query = query.where(Expression.checked.is_(True)) # type: ignore[attr-defined]
query = query.where(
(Expression.session_id == session_id) | (Expression.session_id.is_(None)) # type: ignore[attr-defined]
).order_by(Expression.count.desc(), Expression.last_active_time.desc()) # type: ignore[attr-defined]
expressions = session.exec(query.limit(5)).all()
return [
_ExpressionRecord(
expression_id=expression.id,
situation=expression.situation,
style=expression.style,
)
for expression in expressions
]
async def generate_reply_with_context(
self,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[object]] = None,
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[SessionMessage] = None,
reply_time_point: Optional[float] = None,
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
log_reply: bool = True,
chat_history: Optional[List[LLMContextMessage]] = None,
expression_habits: str = "",
selected_expression_ids: Optional[List[int]] = None,
) -> Tuple[bool, ReplyGenerationResult]:
"""结合上下文生成 Maisaka 的最终可见回复。"""
del available_actions
del chosen_actions
del extra_info
del from_plugin
del log_reply
del reply_time_point
del think_level
del unknown_words
result = ReplyGenerationResult()
if chat_history is None:
result.error_message = "聊天历史为空"
return False, result
logger.info(
f"Maisaka 回复器开始生成: 会话流标识={stream_id} 回复原因={reply_reason!r} "
f"历史消息数={len(chat_history)} 目标消息编号="
f"{reply_message.message_id if reply_message else None}"
)
filtered_history = [
message
for message in chat_history
if not isinstance(message, (ReferenceMessage, ToolResultMessage))
]
logger.debug(f"Maisaka 回复器过滤后历史消息数={len(filtered_history)}")
# Validate that express_model is properly initialized
if self.express_model is None:
logger.error("Maisaka 回复器的回复模型未初始化")
result.error_message = "回复模型尚未初始化"
return False, result
try:
reply_context = await self._build_reply_context(
chat_history=filtered_history,
reply_message=reply_message,
reply_reason=reply_reason or "",
stream_id=stream_id,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka 回复器构建回复上下文失败: {exc}\n{traceback.format_exc()}")
result.error_message = f"构建回复上下文失败: {exc}"
return False, result
merged_expression_habits = expression_habits.strip() or reply_context.expression_habits
result.selected_expression_ids = (
list(selected_expression_ids)
if selected_expression_ids is not None
else list(reply_context.selected_expression_ids)
)
logger.info(
f"Maisaka 回复上下文构建完成: 会话流标识={stream_id} "
f"已选表达编号={result.selected_expression_ids!r}"
)
try:
request_messages = self._build_request_messages(
chat_history=filtered_history,
reply_reason=reply_reason or "",
expression_habits=merged_expression_habits,
)
except Exception as exc:
import traceback
logger.error(f"Maisaka 回复器构建提示词失败: {exc}\n{traceback.format_exc()}")
result.error_message = f"构建提示词失败: {exc}"
return False, result
prompt_preview = self._build_request_prompt_preview(request_messages)
def message_factory(_client: object) -> List[Message]:
return request_messages
result.completion.request_prompt = prompt_preview
if global_config.debug.show_replyer_prompt:
logger.info(f"\nMaisaka 回复器提示词:\n{prompt_preview}\n")
started_at = time.perf_counter()
try:
generation_result = await self.express_model.generate_response_with_messages(message_factory=message_factory)
except Exception as exc:
logger.exception("Maisaka 回复器调用失败")
result.error_message = str(exc)
result.metrics = GenerationMetrics(
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
)
return False, result
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)
result.completion = LLMCompletionResult(
request_prompt=prompt_preview,
response_text=response_text,
reasoning_text=generation_result.reasoning or "",
model_name=generation_result.model_name or "",
tool_calls=generation_result.tool_calls or [],
)
result.metrics = GenerationMetrics(
overall_ms=round((time.perf_counter() - started_at) * 1000, 2),
)
if global_config.debug.show_replyer_reasoning and result.completion.reasoning_text:
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")
if not result.success:
result.error_message = "回复器返回了空内容"
logger.warning("Maisaka 回复器返回了空内容")
return False, result
logger.info(
f"Maisaka 回复器生成成功: 回复文本={response_text!r} "
f"总耗时毫秒={result.metrics.overall_ms} "
f"已选表达编号={result.selected_expression_ids!r}"
)
result.text_fragments = [response_text]
return True, result

View File

@@ -0,0 +1,21 @@
from typing import Type
from src.config.config import global_config
def get_maisaka_replyer_class() -> Type[object]:
"""根据配置返回 Maisaka replyer 类。"""
generator_type = global_config.maisaka.replyer_generator_type
if generator_type == "multi":
from .maisaka_generator_multi import MaisakaReplyGenerator
return MaisakaReplyGenerator
from .maisaka_generator import MaisakaReplyGenerator
return MaisakaReplyGenerator
def get_maisaka_replyer_generator_type() -> str:
"""返回当前配置的 Maisaka replyer 生成器类型。"""
return global_config.maisaka.replyer_generator_type

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.replyer.maisaka_replyer_factory import (
get_maisaka_replyer_class,
get_maisaka_replyer_generator_type,
)
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
from src.chat.replyer.private_generator import PrivateReplyer
logger = get_logger("ReplyerManager")
@@ -23,14 +26,15 @@ class ReplyerManager:
chat_id: Optional[str] = None,
request_type: str = "replyer",
replyer_type: str = "default",
) -> Optional["DefaultReplyer | MaisakaReplyGenerator | PrivateReplyer"]:
) -> Optional["DefaultReplyer | PrivateReplyer | Any"]:
"""按会话和 replyer 类型获取实例。"""
stream_id = chat_stream.session_id if chat_stream else chat_id
if not stream_id:
logger.warning("[ReplyerManager] 缺少 stream_id无法获取 replyer")
return None
cache_key = f"{replyer_type}:{stream_id}"
generator_type = get_maisaka_replyer_generator_type() if replyer_type == "maisaka" else ""
cache_key = f"{replyer_type}:{generator_type}:{stream_id}"
if cache_key in self._repliers:
logger.info(f"[ReplyerManager] 命中缓存 replyer: cache_key={cache_key}")
return self._repliers[cache_key]
@@ -47,10 +51,10 @@ class ReplyerManager:
try:
if replyer_type == "maisaka":
logger.info("[ReplyerManager] importing MaisakaReplyGenerator")
from src.chat.replyer.maisaka_generator import MaisakaReplyGenerator
logger.info(f"[ReplyerManager] 选择 MaisakaReplyGenerator: generator_type={generator_type}")
maisaka_replyer_class = get_maisaka_replyer_class()
replyer = MaisakaReplyGenerator(
replyer = maisaka_replyer_class(
chat_stream=target_stream,
request_type=request_type,
)

View File

@@ -15,7 +15,6 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, ModelUsage, OnlineTime, ToolRecord
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
from src.config.config import global_config
logger = get_logger("maibot_statistic")