feat:修复gemini tool问题,简化表情包识别,修复非多模态plan图片识别

This commit is contained in:
SengokuCola
2026-04-05 14:50:52 +08:00
parent 18d48e0145
commit d82b37a08f
18 changed files with 533 additions and 158 deletions

View File

@@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Set, Tuple
import asyncio import asyncio
import hashlib import hashlib
@@ -129,24 +129,6 @@ def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
allow_abort=True, allow_abort=True,
allow_kwargs_mutation=True, allow_kwargs_mutation=True,
), ),
HookSpec(
name="emoji.register.after_build_emotion",
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
parameters_schema=build_object_schema(
{
"emoji": emoji_schema,
"description": {"type": "string", "description": "当前表情包描述。"},
"emotions": {
**string_array_schema,
"description": "当前生成出的情绪标签列表。",
},
},
required=["emoji", "description", "emotions"],
),
default_timeout_ms=5000,
allow_abort=True,
allow_kwargs_mutation=True,
),
] ]
) )
@@ -181,7 +163,7 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
"file_name": emoji.file_name, "file_name": emoji.file_name,
"full_path": str(emoji.full_path), "full_path": str(emoji.full_path),
"description": emoji.description, "description": emoji.description,
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()], "emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
"query_count": int(emoji.query_count), "query_count": int(emoji.query_count),
} }
@@ -201,6 +183,39 @@ def _normalize_string_list(raw_values: Any) -> List[str]:
return [str(item).strip() for item in raw_values if str(item).strip()] return [str(item).strip() for item in raw_values if str(item).strip()]
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
"""将文本或标签列表转为去重的情绪标签列表。"""
if isinstance(raw_values, str):
if not raw_values:
return []
parts = re.split(r"[,,、;\s]+", raw_values.strip())
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
elif isinstance(raw_values, list):
normalized_tags: List[str] = []
for value in raw_values:
normalized_tags.extend(_normalize_emoji_tag_text(value))
else:
return []
deduped_tags: List[str] = []
seen: Set[str] = set()
for tag in normalized_tags:
normalized_tag = tag.strip()
if not normalized_tag:
continue
lowered = normalized_tag.lower()
if lowered in seen:
continue
seen.add(lowered)
deduped_tags.append(normalized_tag)
return deduped_tags
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
"""获取兼容旧数据的表情包情绪标签。"""
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
def _ensure_directories() -> None: def _ensure_directories() -> None:
"""确保表情包相关目录存在""" """确保表情包相关目录存在"""
EMOJI_DIR.mkdir(parents=True, exist_ok=True) EMOJI_DIR.mkdir(parents=True, exist_ok=True)
@@ -269,20 +284,23 @@ class EmojiManager:
Exception: 如果在缓存表情包的过程中发生错误,则抛出异常 Exception: 如果在缓存表情包的过程中发生错误,则抛出异常
""" """
# 先查找 # 先查找
if emoji_hash is None and emoji_bytes is not None: if emoji_hash is None:
if emoji_bytes is None:
raise ValueError("获取表情包描述失败: 既没有提供表情包字节数据,也没有提供表情包哈希值")
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest() emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
else:
emoji_hash = emoji_hash
if not emoji_hash:
raise ValueError("获取表情包描述失败: 既没有提供表情包字节数据,也没有提供表情包哈希值")
if emoji := self.get_emoji_by_hash(emoji_hash): if emoji := self.get_emoji_by_hash(emoji_hash):
return emoji.description, emoji.emotion or [] return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
try: try:
with get_db_session() as session: with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1) statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if result := session.exec(statement).first(): if result := session.exec(statement).first():
return result.description, result.emotion.split(",") if result.emotion else [] cached_description = result.description or result.emotion or ""
cached_emotions = _normalize_emoji_tag_text(cached_description)
return (
cached_description,
cached_emotions,
)
except Exception as e: except Exception as e:
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述") logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
@@ -407,24 +425,19 @@ class EmojiManager:
logger.error("Build emoji description failed") logger.error("Build emoji description failed")
return None return None
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji) # 情绪标签已在 build_emoji_description 内一次性生成,这里仅做兼容性兜底处理
if not success_emotion:
logger.error("Build emoji emotion labels failed")
return None
with get_db_session() as session: with get_db_session() as session:
try: try:
statement = select(Images).filter_by(image_hash=new_emoji.file_hash, image_type=ImageType.EMOJI).limit(1) statement = select(Images).filter_by(image_hash=new_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first(): if image_record := session.exec(statement).first():
image_record.full_path = str(new_emoji.full_path) image_record.full_path = str(new_emoji.full_path)
image_record.description = new_emoji.description image_record.description = new_emoji.description
image_record.emotion = ",".join(new_emoji.emotion) if new_emoji.emotion else None
image_record.no_file_flag = False image_record.no_file_flag = False
image_record.is_banned = False image_record.is_banned = False
session.add(image_record) session.add(image_record)
except Exception as exc: except Exception as exc:
logger.error(f"Update cached emoji description failed: {exc}") logger.error(f"Update cached emoji description failed: {exc}")
return new_emoji.description, new_emoji.emotion or [] return new_emoji.description, _get_emoji_emotions(new_emoji)
def load_emojis_from_db(self) -> None: def load_emojis_from_db(self) -> None:
@@ -512,7 +525,6 @@ class EmojiManager:
existing_record.is_banned = False existing_record.is_banned = False
existing_record.full_path = str(emoji.full_path) existing_record.full_path = str(emoji.full_path)
existing_record.description = emoji.description existing_record.description = emoji.description
existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
existing_record.query_count = emoji.query_count existing_record.query_count = emoji.query_count
existing_record.last_used_time = emoji.last_used_time existing_record.last_used_time = emoji.last_used_time
existing_record.register_time = emoji.register_time existing_record.register_time = emoji.register_time
@@ -639,7 +651,7 @@ class EmojiManager:
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1) statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first(): if image_record := session.exec(statement).first():
image_record.description = emoji.description image_record.description = emoji.description
image_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None image_record.emotion = None
session.add(image_record) session.add(image_record)
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}") logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
else: else:
@@ -734,7 +746,11 @@ class EmojiManager:
selected_emoji, similarity = random.choice(top_emojis) selected_emoji, similarity = random.choice(top_emojis)
self.update_emoji_usage(selected_emoji) self.update_emoji_usage(selected_emoji)
logger.info( logger.info(
f"[获取表情包] 为[{emotion_label}]选中表情包: {selected_emoji.file_name}({selected_emoji.emotion}),相似度: {similarity:.4f}" "[获取表情包] 为[%s]选中表情包: %s(%s),相似度: %.4f",
emotion_label,
selected_emoji.file_name,
",".join(_get_emoji_emotions(selected_emoji)),
similarity,
) )
return selected_emoji return selected_emoji
@@ -833,7 +849,11 @@ class EmojiManager:
except Exception as e: except Exception as e:
logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}") logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}")
return False, target_emoji return False, target_emoji
prompt: str = "这是一个动态图表情包每一张图代表了动态图的某一帧黑色背景代表透明简短描述一下表情包表达的情感和内容从互联网梗、meme的角度去分析精简回答" prompt: str = (
"这是一个动态图表情包,每一张图代表了动态图的一帧。"
"请只返回该表情包常见的情绪/场景标签,最多 5 个,"
"使用逗号分隔,标签可为中文或英文,不要附带解释。"
)
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes) image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
description_result = await emoji_manager_vlm.generate_response_for_image( description_result = await emoji_manager_vlm.generate_response_for_image(
prompt, prompt,
@@ -843,7 +863,10 @@ class EmojiManager:
) )
description = description_result.response description = description_result.response
else: else:
prompt: str = "这是一个表情包请详细描述一下表情包所表达的情感和内容简短描述细节从互联网梗、meme的角度去分析精简回答" prompt: str = (
"这是一个表情包图片,请提取该表情主要表达的情绪或语气标签,"
"最多 5 个,使用逗号分隔,返回纯文本标签列表,不要解释,不要输出其他内容。"
)
description_result = await emoji_manager_vlm.generate_response_for_image( description_result = await emoji_manager_vlm.generate_response_for_image(
prompt, prompt,
image_base64, image_base64,
@@ -878,10 +901,14 @@ class EmojiManager:
if "" in llm_response: if "" in llm_response:
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}") logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
return False, target_emoji return False, target_emoji
normalized_description = str(description).strip()
if not normalized_description:
logger.warning(f"[构建描述] 视觉模型返回空标签,跳过注册: {target_emoji.file_name}")
return False, target_emoji
hook_result = await _get_runtime_manager().invoke_hook( hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_description", "emoji.register.after_build_description",
emoji=_serialize_emoji_for_hook(target_emoji), emoji=_serialize_emoji_for_hook(target_emoji),
description=description, description=normalized_description,
image_format=image_format, image_format=image_format,
) )
if hook_result.aborted: if hook_result.aborted:
@@ -893,9 +920,14 @@ class EmojiManager:
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}") logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
return False, target_emoji return False, target_emoji
description = normalized_description normalized_emotions = _normalize_emoji_tag_text(normalized_description)
target_emoji.description = description if not normalized_emotions:
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}") logger.warning(f"[构建描述] Hook 返回标签为空,拒绝注册: {target_emoji.file_name}")
return False, target_emoji
target_emoji.description = ",".join(normalized_emotions)
target_emoji.emotion = normalized_emotions
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
return True, target_emoji return True, target_emoji
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]: async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
@@ -911,34 +943,11 @@ class EmojiManager:
logger.error("[构建情感标签] 表情包描述为空,无法构建情感标签") logger.error("[构建情感标签] 表情包描述为空,无法构建情感标签")
return False, target_emoji return False, target_emoji
# 获取Prompt emotions = _normalize_emoji_tag_text(target_emoji.description)
emotion_prompt_template = prompt_manager.get_prompt("emoji_content_analysis") if not emotions:
emotion_prompt_template.add_context("description", target_emoji.description) logger.warning(f"[构建情感标签] 表情包标签为空,跳过注册: {target_emoji.file_name}")
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
# 调用LLM生成情感标签
try:
emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
emotion_prompt,
options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
)
emotion_result = emotion_generation_result.response
except Exception as e:
logger.error(f"[构建情感标签] 调用模型生成情感标签时出错: {e}")
return False, target_emoji return False, target_emoji
if not emotion_result:
logger.warning(f"[构建情感标签] 情感标签结果为空,跳过注册: {target_emoji.file_name}")
return False, target_emoji
# 解析情感标签结果
emotions = [e.strip() for e in emotion_result.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5:
emotions = random.sample(emotions, 3)
elif len(emotions) > 2:
emotions = random.sample(emotions, 2)
hook_result = await _get_runtime_manager().invoke_hook( hook_result = await _get_runtime_manager().invoke_hook(
"emoji.register.after_build_emotion", "emoji.register.after_build_emotion",
emoji=_serialize_emoji_for_hook(target_emoji), emoji=_serialize_emoji_for_hook(target_emoji),
@@ -951,7 +960,7 @@ class EmojiManager:
raw_emotions = hook_result.kwargs.get("emotions") raw_emotions = hook_result.kwargs.get("emotions")
if raw_emotions is not None: if raw_emotions is not None:
emotions = _normalize_string_list(raw_emotions) emotions = _normalize_emoji_tag_text(raw_emotions)
if not emotions: if not emotions:
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}") logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
return False, target_emoji return False, target_emoji
@@ -1100,18 +1109,13 @@ class EmojiManager:
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash): if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}") logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
return False return False
# 3. 构建描述 # 3. 构建描述(包含情绪标签)
desc_success, target_emoji = await self.build_emoji_description(target_emoji) desc_success, target_emoji = await self.build_emoji_description(target_emoji)
if not desc_success: if not desc_success:
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}") logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
return False return False
# 4. 构建情感标签
emo_success, target_emoji = await self.build_emoji_emotion(target_emoji)
if not emo_success:
logger.error(f"[注册表情包] 构建表情包情感标签失败: {file_full_path}")
return False
# 5. 检查容量并决定是否替换或者直接注册 # 4. 检查容量并决定是否替换或者直接注册
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace: if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包") logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
replaced = await self.replace_an_emoji_by_llm(target_emoji) replaced = await self.replace_an_emoji_by_llm(target_emoji)
@@ -1136,17 +1140,30 @@ class EmojiManager:
Args: Args:
text_emotion (str): 文本的情感标签 text_emotion (str): 文本的情感标签
Returns: Returns:
return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表 return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表
""" """
normalized_text_emotion = str(text_emotion or "").strip().lower()
if not normalized_text_emotion:
return []
similarity_list: List[Tuple[MaiEmoji, float]] = [] similarity_list: List[Tuple[MaiEmoji, float]] = []
for emoji in self.emojis: for emoji in self.emojis:
if not emoji.emotion: candidate_emotions = _get_emoji_emotions(emoji)
if not candidate_emotions:
continue continue
# 计算情感标签相似度,使用 Levenshtein 距离作为相似度指标
distance = Levenshtein.distance(text_emotion, emoji.emotion) emotion_similarities = [
max_len = max(len(text_emotion), len(emoji.emotion)) 1 - Levenshtein.distance(normalized_text_emotion, str(emotion).strip().lower()) / max(
similarity = 1 - (distance / max_len if max_len > 0 else 0) len(normalized_text_emotion),
similarity_list.append((emoji, similarity)) len(str(emotion).strip().lower()),
)
for emotion in candidate_emotions
if emotion
]
if not emotion_similarities:
continue
# 计算该表情包与输入标签的最接近匹配度
similarity_list.append((emoji, max(emotion_similarities)))
return similarity_list return similarity_list

View File

@@ -14,7 +14,12 @@ from src.common.logger import get_logger
from src.common.utils.utils_image import ImageUtils from src.common.utils.utils_image import ImageUtils
from src.services import send_service from src.services import send_service
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm from .emoji_manager import (
_normalize_emoji_tag_text,
_serialize_emoji_for_hook,
emoji_manager,
emoji_manager_emotion_judge_llm,
)
logger = get_logger("emoji_maisaka_tool") logger = get_logger("emoji_maisaka_tool")
@@ -113,8 +118,11 @@ def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
def _normalize_emotions(emoji: MaiEmoji) -> list[str]: def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
"""提取并清洗单个表情的情绪标签。""" """提取并清洗单个表情的情绪标签。"""
if emoji.description:
return [str(item).strip() for item in emoji.emotion if str(item).strip()] return _normalize_emoji_tag_text(emoji.description)
if emoji.emotion:
return _normalize_emoji_tag_text(emoji.emotion)
return []
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str: def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:

View File

@@ -177,7 +177,7 @@ class MaisakaReplyGenerator:
return f"{system_prompt}\n\n" + "\n\n".join(sections) return f"{system_prompt}\n\n" + "\n\n".join(sections)
def _build_reply_instruction(self) -> str: def _build_reply_instruction(self) -> str:
return "基于以上上下文,自然地继续回复。直接输出你要说的话,不需要额外解释" return "请自然地回复。请注意不要输出多余内容(包括不必要的前后缀冒号括号表情包at或 @等 ),只输出发言内容就好"
def _build_multimodal_user_message( def _build_multimodal_user_message(
self, self,

View File

@@ -152,22 +152,30 @@ class MaiEmoji(BaseImageDataModel):
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象") raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
obj = cls(db_record.full_path) obj = cls(db_record.full_path)
obj.file_hash = db_record.image_hash obj.file_hash = db_record.image_hash
obj.description = db_record.description description = db_record.description or db_record.emotion or ""
if db_record.emotion: obj.description = description
obj.emotion = db_record.emotion.split(",") normalized_tags = [
str(item).strip()
for item in str(description).replace("", ",").replace("", ",").replace("", ",").split(",")
if str(item).strip()
]
deduped_tags: List[str] = []
for item in normalized_tags:
if item not in deduped_tags:
deduped_tags.append(item)
obj.emotion = deduped_tags
obj.query_count = db_record.query_count obj.query_count = db_record.query_count
obj.last_used_time = db_record.last_used_time obj.last_used_time = db_record.last_used_time
obj.register_time = db_record.register_time obj.register_time = db_record.register_time
return obj return obj
def to_db_instance(self) -> Images: def to_db_instance(self) -> Images:
emotion_str = ",".join(self.emotion) if self.emotion else None
return Images( return Images(
image_hash=self.file_hash, image_hash=self.file_hash,
description=self.description, description=self.description,
full_path=str(self.full_path), full_path=str(self.full_path),
image_type=ImageType.EMOJI, image_type=ImageType.EMOJI,
emotion=emotion_str, emotion=None,
query_count=self.query_count, query_count=self.query_count,
last_used_time=self.last_used_time, last_used_time=self.last_used_time,
register_time=self.register_time, register_time=self.register_time,

View File

@@ -792,14 +792,18 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
image_hash = _normalize_required_text(row.get("emoji_hash")) image_hash = _normalize_required_text(row.get("emoji_hash"))
dedupe_key = (full_path, image_hash, "EMOJI") dedupe_key = (full_path, image_hash, "EMOJI")
if full_path and dedupe_key not in existing_keys: if full_path and dedupe_key not in existing_keys:
migrated_description = _normalize_required_text(row.get("description"))
migrated_emotion = _normalize_optional_text(row.get("emotion"))
if not migrated_description and migrated_emotion:
migrated_description = migrated_emotion
connection.execute( connection.execute(
insert_sql, insert_sql,
{ {
"image_hash": image_hash, "image_hash": image_hash,
"description": _normalize_required_text(row.get("description")), "description": migrated_description,
"full_path": full_path, "full_path": full_path,
"image_type": "EMOJI", "image_type": "EMOJI",
"emotion": _normalize_optional_text(row.get("emotion")), "emotion": None,
"query_count": _normalize_int(row.get("query_count"), default=0), "query_count": _normalize_int(row.get("query_count"), default=0),
"is_registered": _normalize_bool(row.get("is_registered"), default=False), "is_registered": _normalize_bool(row.get("is_registered"), default=False),
"is_banned": _normalize_bool(row.get("is_banned"), default=False), "is_banned": _normalize_bool(row.get("is_banned"), default=False),

View File

@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
MMC_VERSION: str = "1.0.0" MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.3.1" CONFIG_VERSION: str = "8.3.2"
MODEL_CONFIG_VERSION: str = "1.13.1" MODEL_CONFIG_VERSION: str = "1.13.1"
logger = get_logger("config") logger = get_logger("config")
@@ -115,9 +115,6 @@ class Config(ConfigBase):
maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig) maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
"""maim_message配置类""" """maim_message配置类"""
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig, repr=False)
"""LPMM知识库配置类"""
webui: WebUIConfig = Field(default_factory=WebUIConfig) webui: WebUIConfig = Field(default_factory=WebUIConfig)
"""WebUI配置类""" """WebUI配置类"""

View File

@@ -255,8 +255,9 @@ class ChatConfig(ConfigBase):
) )
"""_wrap_私聊说话规则行为风格""" """_wrap_私聊说话规则行为风格"""
group_chat_prompt: str = Field( group_chat_prompt: str = Field(
default="不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的", default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
json_schema_extra={ json_schema_extra={
"x-widget": "textarea", "x-widget": "textarea",
"x-icon": "users", "x-icon": "users",
@@ -265,7 +266,7 @@ class ChatConfig(ConfigBase):
"""_wrap_群聊通用注意事项""" """_wrap_群聊通用注意事项"""
private_chat_prompts: str = Field( private_chat_prompts: str = Field(
default="", default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
json_schema_extra={ json_schema_extra={
"x-widget": "textarea", "x-widget": "textarea",
"x-icon": "user", "x-icon": "user",
@@ -1549,6 +1550,16 @@ class MaiSakaConfig(ConfigBase):
) )
"""每个入站消息的最大内部规划轮数""" """每个入站消息的最大内部规划轮数"""
planner_interrupt_max_consecutive_count: int = Field(
default=2,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "pause-circle",
},
)
"""Planner 连续被新消息打断的最大次数0 表示不启用打断"""
enable_memory_query_tool: bool = Field( enable_memory_query_tool: bool = Field(
default=True, default=True,
json_schema_extra={ json_schema_extra={

View File

@@ -249,9 +249,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
if message.role == RoleType.Tool: if message.role == RoleType.Tool:
if not message.tool_call_id: if not message.tool_call_id:
raise ValueError("Gemini 工具结果消息缺少 tool_call_id") raise ValueError("Gemini 工具结果消息缺少 tool_call_id")
tool_name = tool_name_by_call_id.get(message.tool_call_id) tool_name = (message.tool_name or tool_name_by_call_id.get(message.tool_call_id, "")).strip()
if not tool_name: if not tool_name:
raise ValueError(f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称") raise ValueError(
f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称,"
"且消息中未携带 tool_name"
)
tool_name_by_call_id[message.tool_call_id] = tool_name
function_response_part = Part.from_function_response( function_response_part = Part.from_function_response(
name=tool_name, name=tool_name,
response=_normalize_function_response_payload(message), response=_normalize_function_response_payload(message),

View File

@@ -75,6 +75,7 @@ class Message:
role: RoleType role: RoleType
parts: List[MessagePart] = field(default_factory=list) parts: List[MessagePart] = field(default_factory=list)
tool_call_id: str | None = None tool_call_id: str | None = None
tool_name: str | None = None
tool_calls: List[ToolCall] | None = None tool_calls: List[ToolCall] | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
@@ -87,6 +88,8 @@ class Message:
raise ValueError("消息内容不能为空") raise ValueError("消息内容不能为空")
if self.role == RoleType.Tool and not self.tool_call_id: if self.role == RoleType.Tool and not self.tool_call_id:
raise ValueError("Tool 角色的工具调用 ID 不能为空") raise ValueError("Tool 角色的工具调用 ID 不能为空")
if self.tool_name and self.role != RoleType.Tool:
raise ValueError("仅当角色为 Tool 时才能设置工具名称")
@property @property
def content(self) -> str | List[Tuple[str, str] | str]: def content(self) -> str | List[Tuple[str, str] | str]:
@@ -122,7 +125,7 @@ class Message:
""" """
return ( return (
f"Role: {self.role}, Parts: {self.parts}, " f"Role: {self.role}, Parts: {self.parts}, "
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}" f"Tool Call ID: {self.tool_call_id}, Tool Name: {self.tool_name}, Tool Calls: {self.tool_calls}"
) )
@@ -134,6 +137,7 @@ class MessageBuilder:
self.__role: RoleType = RoleType.User self.__role: RoleType = RoleType.User
self.__parts: List[MessagePart] = [] self.__parts: List[MessagePart] = []
self.__tool_call_id: str | None = None self.__tool_call_id: str | None = None
self.__tool_name: str | None = None
self.__tool_calls: List[ToolCall] | None = None self.__tool_calls: List[ToolCall] | None = None
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder": def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
@@ -247,6 +251,15 @@ class MessageBuilder:
""" """
return self.set_tool_call_id(tool_call_id) return self.set_tool_call_id(tool_call_id)
def set_tool_name(self, tool_name: str) -> "MessageBuilder":
"""设置 Tool 消息对应的工具名称。"""
if self.__role != RoleType.Tool:
raise ValueError("仅当角色为 Tool 时才能设置工具名称")
if not tool_name:
raise ValueError("工具名称不能为空")
self.__tool_name = tool_name
return self
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder": def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
"""设置助手消息中的工具调用列表。 """设置助手消息中的工具调用列表。
@@ -276,5 +289,6 @@ class MessageBuilder:
role=self.__role, role=self.__role,
parts=list(self.__parts), parts=list(self.__parts),
tool_call_id=self.__tool_call_id, tool_call_id=self.__tool_call_id,
tool_name=self.__tool_name,
tool_calls=list(self.__tool_calls) if self.__tool_calls else None, tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
) )

View File

@@ -1,15 +1,16 @@
from datetime import datetime
import base64 import base64
import io import io
from PIL import Image from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import ModelUsage, ModelUser from src.common.database.database_model import ModelUsage, ModelUser
from src.common.logger import get_logger
from src.config.model_configs import ModelInfo from src.config.model_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord from .model_client.base_client import UsageRecord
from .payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
logger = get_logger("消息压缩工具") logger = get_logger("消息压缩工具")
@@ -131,25 +132,32 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
return base64_data return base64_data
compressed_messages = [] def rebuild_message_with_compressed_images(message: Message) -> Message:
for message in messages: """重建消息并压缩其中的图片,同时保留角色与工具元信息。"""
if isinstance(message.content, list): if not any(isinstance(part, ImageMessagePart) for part in message.parts):
# 检查content如有图片则压缩 return message
message_builder = MessageBuilder()
for content_item in message.content:
if isinstance(content_item, tuple):
# 图片,进行压缩
message_builder.add_image_content(
content_item[0],
compress_base64_image(content_item[1], target_size=img_target_size),
)
else:
message_builder.add_text_content(content_item)
compressed_messages.append(message_builder.build())
else:
compressed_messages.append(message)
return compressed_messages message_builder = MessageBuilder().set_role(message.role)
if message.role == RoleType.Assistant and message.tool_calls:
message_builder.set_tool_calls(message.tool_calls)
if message.role == RoleType.Tool and message.tool_call_id:
message_builder.set_tool_call_id(message.tool_call_id)
if message.role == RoleType.Tool and message.tool_name:
message_builder.set_tool_name(message.tool_name)
for message_part in message.parts:
if isinstance(message_part, ImageMessagePart):
message_builder.add_image_content(
message_part.image_format,
compress_base64_image(message_part.image_base64, target_size=img_target_size),
)
continue
if isinstance(message_part, TextMessagePart):
message_builder.add_text_content(message_part.text)
return message_builder.build()
return [rebuild_message_with_compressed_images(message) for message in messages]
class LLMUsageRecorder: class LLMUsageRecorder:

View File

@@ -0,0 +1,138 @@
"""Maisaka 聊天历史视觉占位刷新器。"""
from typing import Awaitable, Callable, Optional
from sqlmodel import select
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import EmojiComponent, ForwardNodeComponent, ImageComponent
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from .context_messages import LLMContextMessage, SessionBackedMessage
logger = get_logger("maisaka_chat_history_visual_refresher")
BuildHistoryMessage = Callable[[SessionMessage, str], Awaitable[Optional[LLMContextMessage]]]
BuildVisibleText = Callable[[SessionMessage], str]
async def refresh_chat_history_visual_placeholders(
*,
chat_history: list[LLMContextMessage],
build_history_message: BuildHistoryMessage,
build_visible_text: BuildVisibleText,
) -> int:
"""在进入新一轮规划前,尝试用已完成的识图结果刷新历史占位。"""
refreshed_count = 0
for index, history_message in enumerate(chat_history):
if not isinstance(history_message, SessionBackedMessage):
continue
original_message = history_message.original_message
if original_message is None:
continue
visual_components_updated = _refresh_pending_visual_components(original_message.raw_message.components)
if visual_components_updated:
await original_message.process(
enable_heavy_media_analysis=False,
enable_voice_transcription=False,
)
refreshed_visible_text = build_visible_text(original_message)
if not visual_components_updated and refreshed_visible_text == history_message.visible_text:
continue
rebuilt_history_message = await build_history_message(original_message, history_message.source_kind)
if rebuilt_history_message is None:
continue
chat_history[index] = rebuilt_history_message
refreshed_count += 1
return refreshed_count
def _refresh_pending_visual_components(components: list[object]) -> bool:
"""用缓存中的描述更新尚未补全文本的图片与表情组件。"""
refreshed = False
for component in components:
if isinstance(component, ImageComponent):
if _should_refresh_image_component(component):
image_description = _lookup_cached_image_description(component.binary_hash)
if image_description:
component.content = f"[图片:{image_description}]"
refreshed = True
continue
if isinstance(component, EmojiComponent):
if _should_refresh_emoji_component(component):
emoji_description = _lookup_cached_emoji_description(component.binary_hash)
if emoji_description:
component.content = f"[表情包: {emoji_description}]"
refreshed = True
continue
if not isinstance(component, ForwardNodeComponent):
continue
for forward_component in component.forward_components:
if _refresh_pending_visual_components(forward_component.content):
refreshed = True
return refreshed
def _should_refresh_image_component(component: ImageComponent) -> bool:
"""判断图片组件当前是否仍处于待补全文本的占位状态。"""
return not component.content or component.content == "[图片]"
def _should_refresh_emoji_component(component: EmojiComponent) -> bool:
"""判断表情组件当前是否仍处于待补全文本的占位状态。"""
return not component.content or component.content == "[表情包]"
def _lookup_cached_image_description(image_hash: str) -> str:
"""从数据库读取已完成的图片描述,不触发新的识图请求。"""
if not image_hash:
return ""
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
if image_record := session.exec(statement).first():
if image_record.no_file_flag:
return ""
if image_record.vlm_processed and image_record.description:
return str(image_record.description).strip()
except Exception as exc:
logger.warning(f"读取图片缓存描述失败image_hash={image_hash}: {exc}")
return ""
def _lookup_cached_emoji_description(emoji_hash: str) -> str:
"""从数据库读取已完成的表情描述,不触发新的识别请求。"""
if not emoji_hash:
return ""
try:
with get_db_session() as session:
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
if image_record := session.exec(statement).first():
if image_record.no_file_flag or not image_record.description:
return ""
return str(image_record.description).strip()
except Exception as exc:
logger.warning(f"读取表情缓存描述失败emoji_hash={emoji_hash}: {exc}")
return ""

View File

@@ -196,6 +196,7 @@ def _build_message_from_sequence(
fallback_text: str, fallback_text: str,
*, *,
tool_call_id: Optional[str] = None, tool_call_id: Optional[str] = None,
tool_name: Optional[str] = None,
tool_calls: Optional[list[ToolCall]] = None, tool_calls: Optional[list[ToolCall]] = None,
) -> Optional[Message]: ) -> Optional[Message]:
"""根据消息片段构造统一 LLM 消息。""" """根据消息片段构造统一 LLM 消息。"""
@@ -204,6 +205,8 @@ def _build_message_from_sequence(
builder.set_tool_calls(tool_calls) builder.set_tool_calls(tool_calls)
if role == RoleType.Tool and tool_call_id: if role == RoleType.Tool and tool_call_id:
builder.add_tool_call(tool_call_id) builder.add_tool_call(tool_call_id)
if role == RoleType.Tool and tool_name:
builder.set_tool_name(tool_name)
has_content = False has_content = False
for component in message_sequence.components: for component in message_sequence.components:
@@ -481,4 +484,5 @@ class ToolResultMessage(LLMContextMessage):
message_sequence, message_sequence,
self.content, self.content,
tool_call_id=self.tool_call_id, tool_call_id=self.tool_call_id,
tool_name=self.tool_name,
) )

View File

@@ -66,11 +66,11 @@ def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str:
continue continue
if isinstance(component, EmojiComponent): if isinstance(component, EmojiComponent):
parts.append("[表情包]") parts.append(component.content or "[表情包]")
continue continue
if isinstance(component, ImageComponent): if isinstance(component, ImageComponent):
parts.append("[图片]") parts.append(component.content or "[图片]")
continue continue
if isinstance(component, ReplyComponent): if isinstance(component, ReplyComponent):

View File

@@ -24,6 +24,7 @@ from src.services import database_service as database_api
from .builtin_tool import get_action_tool_specs from .builtin_tool import get_action_tool_specs
from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers
from .builtin_tool import get_timing_tools from .builtin_tool import get_timing_tools
from .chat_history_visual_refresher import refresh_chat_history_visual_placeholders
from .builtin_tool.context import BuiltinToolRuntimeContext from .builtin_tool.context import BuiltinToolRuntimeContext
from .context_messages import ( from .context_messages import (
AssistantMessage, AssistantMessage,
@@ -103,16 +104,22 @@ class MaisakaReasoningEngine:
"""运行一轮可被新消息打断的主 planner 请求。""" """运行一轮可被新消息打断的主 planner 请求。"""
interrupt_flag = asyncio.Event() interrupt_flag = asyncio.Event()
self._runtime._planner_interrupt_flag = interrupt_flag interrupted = False
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
self._runtime._chat_loop_service.set_interrupt_flag(interrupt_flag) self._runtime._chat_loop_service.set_interrupt_flag(interrupt_flag)
try: try:
return await self._runtime._chat_loop_service.chat_loop_step( return await self._runtime._chat_loop_service.chat_loop_step(
self._runtime._chat_history, self._runtime._chat_history,
tool_definitions=tool_definitions, tool_definitions=tool_definitions,
) )
except ReqAbortException:
interrupted = True
raise
finally: finally:
if self._runtime._planner_interrupt_flag is interrupt_flag: self._runtime._unbind_planner_interrupt_flag(
self._runtime._planner_interrupt_flag = None interrupt_flag,
interrupted=interrupted,
)
self._runtime._chat_loop_service.set_interrupt_flag(None) self._runtime._chat_loop_service.set_interrupt_flag(None)
async def _run_interruptible_sub_agent( async def _run_interruptible_sub_agent(
@@ -125,7 +132,8 @@ class MaisakaReasoningEngine:
"""运行一轮可被新消息打断的临时子代理请求。""" """运行一轮可被新消息打断的临时子代理请求。"""
interrupt_flag = asyncio.Event() interrupt_flag = asyncio.Event()
self._runtime._planner_interrupt_flag = interrupt_flag interrupted = False
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
try: try:
return await self._runtime.run_sub_agent( return await self._runtime.run_sub_agent(
context_message_limit=context_message_limit, context_message_limit=context_message_limit,
@@ -136,9 +144,14 @@ class MaisakaReasoningEngine:
temperature=0.1, temperature=0.1,
tool_definitions=tool_definitions, tool_definitions=tool_definitions,
) )
except ReqAbortException:
interrupted = True
raise
finally: finally:
if self._runtime._planner_interrupt_flag is interrupt_flag: self._runtime._unbind_planner_interrupt_flag(
self._runtime._planner_interrupt_flag = None interrupt_flag,
interrupted=interrupted,
)
@staticmethod @staticmethod
def _build_timing_gate_fallback_prompt() -> str: def _build_timing_gate_fallback_prompt() -> str:
@@ -313,6 +326,14 @@ class MaisakaReasoningEngine:
) )
planner_started_at = 0.0 planner_started_at = 0.0
try: try:
visual_refresh_started_at = time.time()
refreshed_message_count = await self._refresh_chat_history_visual_placeholders()
cycle_detail.time_records["visual_refresh"] = time.time() - visual_refresh_started_at
if refreshed_message_count > 0:
logger.info(
f"{self._runtime.log_prefix} 本轮思考前已刷新 {refreshed_message_count} 条视觉占位历史消息"
)
timing_started_at = time.time() timing_started_at = time.time()
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message) timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message)
timing_duration_ms = (time.time() - timing_started_at) * 1000 timing_duration_ms = (time.time() - timing_started_at) * 1000
@@ -526,7 +547,12 @@ class MaisakaReasoningEngine:
timestamp=message.timestamp.timestamp(), timestamp=message.timestamp.timestamp(),
) )
async def _build_history_message(self, message: SessionMessage) -> Optional[LLMContextMessage]: async def _build_history_message(
self,
message: SessionMessage,
*,
source_kind: str = "user",
) -> Optional[LLMContextMessage]:
"""根据真实消息构造对应的上下文消息。""" """根据真实消息构造对应的上下文消息。"""
source_sequence = message.raw_message source_sequence = message.raw_message
@@ -537,7 +563,7 @@ class MaisakaReasoningEngine:
message, message,
planner_prefix=planner_prefix, planner_prefix=planner_prefix,
visible_text=visible_text, visible_text=visible_text,
source_kind="user", source_kind=source_kind,
) )
user_sequence = await self._build_message_sequence(message, planner_prefix=planner_prefix) user_sequence = await self._build_message_sequence(message, planner_prefix=planner_prefix)
@@ -548,7 +574,7 @@ class MaisakaReasoningEngine:
message, message,
raw_message=user_sequence, raw_message=user_sequence,
visible_text=visible_text, visible_text=visible_text,
source_kind="user", source_kind=source_kind,
) )
async def _build_message_sequence( async def _build_message_sequence(
@@ -601,6 +627,18 @@ class MaisakaReasoningEngine:
if isinstance(result, Exception): if isinstance(result, Exception):
logger.warning(f"{self._runtime.log_prefix} 回填图片或表情二进制数据失败Maisaka 将退化为文本占位: {result}") logger.warning(f"{self._runtime.log_prefix} 回填图片或表情二进制数据失败Maisaka 将退化为文本占位: {result}")
async def _refresh_chat_history_visual_placeholders(self) -> int:
"""在进入新一轮规划前,尝试用已完成的识图结果刷新历史占位。"""
return await refresh_chat_history_visual_placeholders(
chat_history=self._runtime._chat_history,
build_history_message=lambda message, source_kind: self._build_history_message(
message,
source_kind=source_kind,
),
build_visible_text=lambda message: self._build_legacy_visible_text(message, message.raw_message),
)
def _build_legacy_visible_text(self, message: SessionMessage, source_sequence: MessageSequence) -> str: def _build_legacy_visible_text(self, message: SessionMessage, source_sequence: MessageSequence) -> str:
user_info = message.message_info.user_info user_info = message.message_info.user_info
speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id

View File

@@ -84,6 +84,12 @@ class MaisakaHeartFlowChatting:
self._wait_until: Optional[float] = None self._wait_until: Optional[float] = None
self._pending_wait_tool_call_id: Optional[str] = None self._pending_wait_tool_call_id: Optional[str] = None
self._planner_interrupt_flag: Optional[asyncio.Event] = None self._planner_interrupt_flag: Optional[asyncio.Event] = None
self._planner_interrupt_requested = False
self._planner_interrupt_consecutive_count = 0
self._planner_interrupt_max_consecutive_count = max(
0,
int(global_config.maisaka.planner_interrupt_max_consecutive_count),
)
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id) expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
self._enable_expression_use = expr_use self._enable_expression_use = expr_use
@@ -167,14 +173,51 @@ class MaisakaHeartFlowChatting:
if self._agent_state == self._STATE_RUNNING: if self._agent_state == self._STATE_RUNNING:
self._message_debounce_required = True self._message_debounce_required = True
if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None: if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None:
logger.info( if self._planner_interrupt_requested:
f"{self.log_prefix} 收到新消息,发起规划器打断; " logger.info(
f"消息编号={message.message_id} 缓存条数={len(self.message_cache)} " f"{self.log_prefix} 收到新消息,但当前请求已发起过一次规划器打断,"
f"时间戳={time.time():.3f}" f"本次不重复打断; 消息编号={message.message_id} "
) f"连续打断次数={self._planner_interrupt_consecutive_count}/"
self._planner_interrupt_flag.set() f"{self._planner_interrupt_max_consecutive_count}"
)
elif self._planner_interrupt_consecutive_count >= self._planner_interrupt_max_consecutive_count:
logger.info(
f"{self.log_prefix} 收到新消息,但已达到规划器连续打断上限,"
f"将等待当前请求自然完成; 消息编号={message.message_id} "
f"连续打断次数={self._planner_interrupt_consecutive_count}/"
f"{self._planner_interrupt_max_consecutive_count}"
)
else:
self._planner_interrupt_requested = True
self._planner_interrupt_consecutive_count += 1
logger.info(
f"{self.log_prefix} 收到新消息,发起规划器打断; "
f"消息编号={message.message_id} 缓存条数={len(self.message_cache)} "
f"时间戳={time.time():.3f} "
f"连续打断次数={self._planner_interrupt_consecutive_count}/"
f"{self._planner_interrupt_max_consecutive_count}"
)
self._planner_interrupt_flag.set()
self._new_message_event.set() self._new_message_event.set()
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
"""绑定当前可打断请求使用的中断标记。"""
self._planner_interrupt_flag = interrupt_flag
self._planner_interrupt_requested = False
def _unbind_planner_interrupt_flag(
self,
interrupt_flag: asyncio.Event,
*,
interrupted: bool,
) -> None:
"""解绑当前可打断请求的中断标记,并维护连续打断计数。"""
if self._planner_interrupt_flag is interrupt_flag:
self._planner_interrupt_flag = None
self._planner_interrupt_requested = False
if not interrupted:
self._planner_interrupt_consecutive_count = 0
def _ensure_background_tasks_running(self) -> None: def _ensure_background_tasks_running(self) -> None:
"""确保后台任务仍在运行,若崩溃则自动拉起。""" """确保后台任务仍在运行,若崩溃则自动拉起。"""
if not self._running: if not self._running:
@@ -513,7 +556,6 @@ class MaisakaHeartFlowChatting:
if not global_config.debug.show_maisaka_thinking: if not global_config.debug.show_maisaka_thinking:
return return
session_name = chat_manager.get_session_name(self.session_id) or self.session_id
body_lines = [ body_lines = [
f"上下文占用:{selected_history_count}/{self._max_context_size}", f"上下文占用:{selected_history_count}/{self._max_context_size}",
f"本次请求token消耗{self._format_token_count(prompt_tokens)}", f"本次请求token消耗{self._format_token_count(prompt_tokens)}",

View File

@@ -19,13 +19,48 @@ class RuntimeDataCapabilityMixin:
if not emoji_base64: if not emoji_base64:
return None return None
matched_emotion = emoji.emotion[0] if emoji.emotion else "" matched_emotion = RuntimeDataCapabilityMixin._normalize_emoji_tags(emoji)
return { return {
"base64": emoji_base64, "base64": emoji_base64,
"description": emoji.description, "description": emoji.description,
"emotion": matched_emotion, "emotion": matched_emotion,
} }
@staticmethod
def _normalize_emoji_tag_text(raw_value: Any) -> List[str]:
"""将文本或标签列表转为去重情绪标签列表。"""
if raw_value is None:
return []
if isinstance(raw_value, list):
values = raw_value
else:
values = [raw_value]
tags: List[str] = []
for value in values:
raw_text = str(value) if value is not None else ""
if not raw_text:
continue
tags.extend(
item.strip() for item in raw_text.replace("", ",").replace("", ",").replace("", ",").split(",")
)
deduped_tags: List[str] = []
for tag in tags:
tag_text = str(tag).strip()
if not tag_text:
continue
if tag_text not in deduped_tags:
deduped_tags.append(tag_text)
return deduped_tags
@staticmethod
def _normalize_emoji_tags(emoji: MaiEmoji) -> str:
"""从表情包对象提取兼容旧数据的情绪标签文本。"""
tags = RuntimeDataCapabilityMixin._normalize_emoji_tag_text(emoji.description or emoji.emotion)
return tags[0] if tags else ""
@staticmethod @staticmethod
def _build_emoji_temp_path() -> Path: def _build_emoji_temp_path() -> Path:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR from src.chat.emoji_system.emoji_manager import EMOJI_DIR
@@ -488,7 +523,16 @@ class RuntimeDataCapabilityMixin:
try: try:
from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.emoji_system.emoji_manager import emoji_manager
emotions = sorted({emotion for emoji in emoji_manager.emojis for emotion in emoji.emotion}) emotions = sorted(
{
str(emotion).strip()
for emoji in emoji_manager.emojis
for emotion in RuntimeDataCapabilityMixin._normalize_emoji_tag_text(
emoji.description or emoji.emotion
)
if str(emotion).strip()
}
)
return {"success": True, "emotions": emotions} return {"success": True, "emotions": emotions}
except Exception as e: except Exception as e:
logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True) logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True)
@@ -568,7 +612,9 @@ class RuntimeDataCapabilityMixin:
"success": True, "success": True,
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}", "message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
"description": None if new_emoji is None else new_emoji.description, "description": None if new_emoji is None else new_emoji.description,
"emotions": None if new_emoji is None else new_emoji.emotion, "emotions": None
if new_emoji is None
else RuntimeDataCapabilityMixin._normalize_emoji_tag_text(new_emoji.description or new_emoji.emotion),
"replaced": replaced, "replaced": replaced,
"hash": None if new_emoji is None else new_emoji.file_hash, "hash": None if new_emoji is None else new_emoji.file_hash,
} }

View File

@@ -6,6 +6,7 @@ import io
import os import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
import re
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Cookie, HTTPException, Query from fastapi import APIRouter, Cookie, HTTPException, Query
@@ -55,6 +56,19 @@ from .support import (
router = APIRouter(prefix="/emoji", tags=["Emoji"]) router = APIRouter(prefix="/emoji", tags=["Emoji"])
def _normalize_emoji_description(description: str = "", emotion: str = "") -> str:
"""将上传参数中的描述/情绪标签归一化为可存储 description。"""
normalized_description = str(description or "").strip()
normalized_emotion = str(emotion or "").strip()
if normalized_description:
return normalized_description
if not normalized_emotion:
return ""
tags = re.split(r"[,,、;\s]+", normalized_emotion)
return ",".join(item.strip() for item in tags if item.strip())
@router.get("/list", response_model=EmojiListResponse) @router.get("/list", response_model=EmojiListResponse)
async def get_emoji_list( async def get_emoji_list(
page: int = Query(1, ge=1, description="页码"), page: int = Query(1, ge=1, description="页码"),
@@ -173,6 +187,14 @@ async def update_emoji(
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered: if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = datetime.now() update_data["register_time"] = datetime.now()
if "emotion" in update_data:
normalized_description = _normalize_emoji_description(
description=update_data.get("description", ""),
emotion=update_data.get("emotion", ""),
)
update_data["description"] = normalized_description
update_data.pop("emotion", None)
for field, value in update_data.items(): for field, value in update_data.items():
setattr(emoji, field, value) setattr(emoji, field, value)
@@ -543,7 +565,7 @@ async def upload_emoji(
_ = output_file.write(file_content) _ = output_file.write(file_content)
logger.info(f"表情包文件已保存: {full_path}") logger.info(f"表情包文件已保存: {full_path}")
emotion_str = ",".join(item.strip() for item in emotion.split(",") if item.strip()) if emotion else "" final_description = _normalize_emoji_description(description=description, emotion=emotion)
current_time = datetime.now() current_time = datetime.now()
with get_db_session() as session: with get_db_session() as session:
@@ -551,8 +573,8 @@ async def upload_emoji(
image_type=ImageType.EMOJI, image_type=ImageType.EMOJI,
full_path=full_path, full_path=full_path,
image_hash=emoji_hash, image_hash=emoji_hash,
description=description, description=final_description,
emotion=emotion_str or None, emotion=None,
query_count=0, query_count=0,
is_registered=is_registered, is_registered=is_registered,
is_banned=False, is_banned=False,
@@ -654,16 +676,16 @@ async def batch_upload_emoji(
with open(full_path, "wb") as output_file: with open(full_path, "wb") as output_file:
_ = output_file.write(file_content) _ = output_file.write(file_content)
emotion_str = ",".join(item.strip() for item in emotion.split(",") if item.strip()) if emotion else ""
current_time = datetime.now() current_time = datetime.now()
final_description = _normalize_emoji_description(emotion=emotion)
with get_db_session() as session: with get_db_session() as session:
emoji = Images( emoji = Images(
image_type=ImageType.EMOJI, image_type=ImageType.EMOJI,
full_path=full_path, full_path=full_path,
image_hash=emoji_hash, image_hash=emoji_hash,
description="", description=final_description,
emotion=emotion_str or None, emotion=None,
query_count=0, query_count=0,
is_registered=is_registered, is_registered=is_registered,
is_banned=False, is_banned=False,

View File

@@ -1,3 +1,4 @@
import re
from typing import Annotated, List, Optional from typing import Annotated, List, Optional
from fastapi import File, Form, UploadFile from fastapi import File, Form, UploadFile
@@ -5,15 +6,15 @@ from pydantic import BaseModel
from src.common.database.database_model import Images from src.common.database.database_model import Images
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")] EmojiFile = Annotated[UploadFile, File(description="表情包上传文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")] EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包上传文件")]
DescriptionForm = Annotated[str, Form(description="表情包描述")] DescriptionForm = Annotated[str, Form(description="表情包描述")]
EmotionForm = Annotated[str, Form(description="标签,多个用逗号分隔")] EmotionForm = Annotated[str, Form(description="标签,多个使用逗号分隔")]
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")] IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
class EmojiResponse(BaseModel): class EmojiResponse(BaseModel):
"""表情包响应""" """表情包响应结构"""
id: int id: int
full_path: str full_path: str
@@ -124,7 +125,20 @@ class ThumbnailPreheatResponse(BaseModel):
def emoji_to_response(image: Images) -> EmojiResponse: def emoji_to_response(image: Images) -> EmojiResponse:
"""将数据库表情包模型转换为响应对象。""" emotions: list[str] = []
if image.description:
emotions.extend(
item.strip() for item in re.split(r"[,,、;\s]+", image.description) if item and item.strip()
)
if not emotions and image.emotion:
emotions.extend(item.strip() for item in re.split(r"[,,、;\s]+", image.emotion) if item and item.strip())
deduped_emotions: list[str] = []
for item in emotions:
if item not in deduped_emotions:
deduped_emotions.append(item)
emotion = ",".join(deduped_emotions) if deduped_emotions else None
return EmojiResponse( return EmojiResponse(
id=image.id if image.id is not None else 0, id=image.id if image.id is not None else 0,
full_path=image.full_path, full_path=image.full_path,
@@ -133,7 +147,7 @@ def emoji_to_response(image: Images) -> EmojiResponse:
query_count=image.query_count, query_count=image.query_count,
is_registered=image.is_registered, is_registered=image.is_registered,
is_banned=image.is_banned, is_banned=image.is_banned,
emotion=image.emotion, emotion=emotion,
record_time=image.record_time.timestamp() if image.record_time else 0.0, record_time=image.record_time.timestamp() if image.record_time else 0.0,
register_time=image.register_time.timestamp() if image.register_time else None, register_time=image.register_time.timestamp() if image.register_time else None,
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None, last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,