feat:修复gemini tool问题,简化表情包识别,修复非多模态plan图片识别
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
@@ -129,24 +129,6 @@ def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_emotion",
|
||||
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前表情包描述。"},
|
||||
"emotions": {
|
||||
**string_array_schema,
|
||||
"description": "当前生成出的情绪标签列表。",
|
||||
},
|
||||
},
|
||||
required=["emoji", "description", "emotions"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -181,7 +163,7 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
|
||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
@@ -201,6 +183,39 @@ def _normalize_string_list(raw_values: Any) -> List[str]:
|
||||
return [str(item).strip() for item in raw_values if str(item).strip()]
|
||||
|
||||
|
||||
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
||||
"""将文本或标签列表转为去重的情绪标签列表。"""
|
||||
if isinstance(raw_values, str):
|
||||
if not raw_values:
|
||||
return []
|
||||
parts = re.split(r"[,,、;;\s]+", raw_values.strip())
|
||||
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
|
||||
elif isinstance(raw_values, list):
|
||||
normalized_tags: List[str] = []
|
||||
for value in raw_values:
|
||||
normalized_tags.extend(_normalize_emoji_tag_text(value))
|
||||
else:
|
||||
return []
|
||||
|
||||
deduped_tags: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
for tag in normalized_tags:
|
||||
normalized_tag = tag.strip()
|
||||
if not normalized_tag:
|
||||
continue
|
||||
lowered = normalized_tag.lower()
|
||||
if lowered in seen:
|
||||
continue
|
||||
seen.add(lowered)
|
||||
deduped_tags.append(normalized_tag)
|
||||
return deduped_tags
|
||||
|
||||
|
||||
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
|
||||
"""获取兼容旧数据的表情包情绪标签。"""
|
||||
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -269,20 +284,23 @@ class EmojiManager:
|
||||
Exception: 如果在缓存表情包的过程中发生错误,则抛出异常
|
||||
"""
|
||||
# 先查找
|
||||
if emoji_hash is None and emoji_bytes is not None:
|
||||
if emoji_hash is None:
|
||||
if emoji_bytes is None:
|
||||
raise ValueError("获取表情包描述失败: 既没有提供表情包字节数据,也没有提供表情包哈希值")
|
||||
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||
else:
|
||||
emoji_hash = emoji_hash
|
||||
if not emoji_hash:
|
||||
raise ValueError("获取表情包描述失败: 既没有提供表情包字节数据,也没有提供表情包哈希值")
|
||||
|
||||
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||
return emoji.description, emoji.emotion or []
|
||||
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if result := session.exec(statement).first():
|
||||
return result.description, result.emotion.split(",") if result.emotion else []
|
||||
cached_description = result.description or result.emotion or ""
|
||||
cached_emotions = _normalize_emoji_tag_text(cached_description)
|
||||
return (
|
||||
cached_description,
|
||||
cached_emotions,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
|
||||
|
||||
@@ -407,24 +425,19 @@ class EmojiManager:
|
||||
logger.error("Build emoji description failed")
|
||||
return None
|
||||
|
||||
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
|
||||
if not success_emotion:
|
||||
logger.error("Build emoji emotion labels failed")
|
||||
return None
|
||||
|
||||
# 情绪标签已在 build_emoji_description 内一次性生成,这里仅做兼容性兜底处理
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
statement = select(Images).filter_by(image_hash=new_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
image_record.full_path = str(new_emoji.full_path)
|
||||
image_record.description = new_emoji.description
|
||||
image_record.emotion = ",".join(new_emoji.emotion) if new_emoji.emotion else None
|
||||
image_record.no_file_flag = False
|
||||
image_record.is_banned = False
|
||||
session.add(image_record)
|
||||
except Exception as exc:
|
||||
logger.error(f"Update cached emoji description failed: {exc}")
|
||||
return new_emoji.description, new_emoji.emotion or []
|
||||
return new_emoji.description, _get_emoji_emotions(new_emoji)
|
||||
|
||||
def load_emojis_from_db(self) -> None:
|
||||
|
||||
@@ -512,7 +525,6 @@ class EmojiManager:
|
||||
existing_record.is_banned = False
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = emoji.description
|
||||
existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
|
||||
existing_record.query_count = emoji.query_count
|
||||
existing_record.last_used_time = emoji.last_used_time
|
||||
existing_record.register_time = emoji.register_time
|
||||
@@ -639,7 +651,7 @@ class EmojiManager:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
image_record.description = emoji.description
|
||||
image_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
|
||||
image_record.emotion = None
|
||||
session.add(image_record)
|
||||
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
|
||||
else:
|
||||
@@ -734,7 +746,11 @@ class EmojiManager:
|
||||
selected_emoji, similarity = random.choice(top_emojis)
|
||||
self.update_emoji_usage(selected_emoji)
|
||||
logger.info(
|
||||
f"[获取表情包] 为[{emotion_label}]选中表情包: {selected_emoji.file_name}({selected_emoji.emotion}),相似度: {similarity:.4f}"
|
||||
"[获取表情包] 为[%s]选中表情包: %s(%s),相似度: %.4f",
|
||||
emotion_label,
|
||||
selected_emoji.file_name,
|
||||
",".join(_get_emoji_emotions(selected_emoji)),
|
||||
similarity,
|
||||
)
|
||||
return selected_emoji
|
||||
|
||||
@@ -833,7 +849,11 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}")
|
||||
return False, target_emoji
|
||||
prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答"
|
||||
prompt: str = (
|
||||
"这是一个动态图表情包,每一张图代表了动态图的一帧。"
|
||||
"请只返回该表情包常见的情绪/场景标签,最多 5 个,"
|
||||
"使用逗号分隔,标签可为中文或英文,不要附带解释。"
|
||||
)
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
@@ -843,7 +863,10 @@ class EmojiManager:
|
||||
)
|
||||
description = description_result.response
|
||||
else:
|
||||
prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答"
|
||||
prompt: str = (
|
||||
"这是一个表情包图片,请提取该表情主要表达的情绪或语气标签,"
|
||||
"最多 5 个,使用逗号分隔,返回纯文本标签列表,不要解释,不要输出其他内容。"
|
||||
)
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
@@ -878,10 +901,14 @@ class EmojiManager:
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
normalized_description = str(description).strip()
|
||||
if not normalized_description:
|
||||
logger.warning(f"[构建描述] 视觉模型返回空标签,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_description",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=description,
|
||||
description=normalized_description,
|
||||
image_format=image_format,
|
||||
)
|
||||
if hook_result.aborted:
|
||||
@@ -893,9 +920,14 @@ class EmojiManager:
|
||||
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
description = normalized_description
|
||||
target_emoji.description = description
|
||||
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
|
||||
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
|
||||
if not normalized_emotions:
|
||||
logger.warning(f"[构建描述] Hook 返回标签为空,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
target_emoji.description = ",".join(normalized_emotions)
|
||||
target_emoji.emotion = normalized_emotions
|
||||
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
|
||||
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
||||
@@ -911,34 +943,11 @@ class EmojiManager:
|
||||
logger.error("[构建情感标签] 表情包描述为空,无法构建情感标签")
|
||||
return False, target_emoji
|
||||
|
||||
# 获取Prompt
|
||||
emotion_prompt_template = prompt_manager.get_prompt("emoji_content_analysis")
|
||||
emotion_prompt_template.add_context("description", target_emoji.description)
|
||||
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
|
||||
# 调用LLM生成情感标签
|
||||
try:
|
||||
emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
emotion_prompt,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
|
||||
)
|
||||
emotion_result = emotion_generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"[构建情感标签] 调用模型生成情感标签时出错: {e}")
|
||||
emotions = _normalize_emoji_tag_text(target_emoji.description)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] 表情包标签为空,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
if not emotion_result:
|
||||
logger.warning(f"[构建情感标签] 情感标签结果为空,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
# 解析情感标签结果
|
||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
emotions = random.sample(emotions, 3)
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_emotion",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
@@ -951,7 +960,7 @@ class EmojiManager:
|
||||
|
||||
raw_emotions = hook_result.kwargs.get("emotions")
|
||||
if raw_emotions is not None:
|
||||
emotions = _normalize_string_list(raw_emotions)
|
||||
emotions = _normalize_emoji_tag_text(raw_emotions)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
@@ -1100,18 +1109,13 @@ class EmojiManager:
|
||||
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
||||
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
|
||||
return False
|
||||
# 3. 构建描述
|
||||
# 3. 构建描述(包含情绪标签)
|
||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||
if not desc_success:
|
||||
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
|
||||
return False
|
||||
# 4. 构建情感标签
|
||||
emo_success, target_emoji = await self.build_emoji_emotion(target_emoji)
|
||||
if not emo_success:
|
||||
logger.error(f"[注册表情包] 构建表情包情感标签失败: {file_full_path}")
|
||||
return False
|
||||
|
||||
# 5. 检查容量并决定是否替换或者直接注册
|
||||
# 4. 检查容量并决定是否替换或者直接注册
|
||||
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
|
||||
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
|
||||
replaced = await self.replace_an_emoji_by_llm(target_emoji)
|
||||
@@ -1136,17 +1140,30 @@ class EmojiManager:
|
||||
Args:
|
||||
text_emotion (str): 文本的情感标签
|
||||
Returns:
|
||||
return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表
|
||||
return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表
|
||||
"""
|
||||
normalized_text_emotion = str(text_emotion or "").strip().lower()
|
||||
if not normalized_text_emotion:
|
||||
return []
|
||||
|
||||
similarity_list: List[Tuple[MaiEmoji, float]] = []
|
||||
for emoji in self.emojis:
|
||||
if not emoji.emotion:
|
||||
candidate_emotions = _get_emoji_emotions(emoji)
|
||||
if not candidate_emotions:
|
||||
continue
|
||||
# 计算情感标签相似度,使用 Levenshtein 距离作为相似度指标
|
||||
distance = Levenshtein.distance(text_emotion, emoji.emotion)
|
||||
max_len = max(len(text_emotion), len(emoji.emotion))
|
||||
similarity = 1 - (distance / max_len if max_len > 0 else 0)
|
||||
similarity_list.append((emoji, similarity))
|
||||
|
||||
emotion_similarities = [
|
||||
1 - Levenshtein.distance(normalized_text_emotion, str(emotion).strip().lower()) / max(
|
||||
len(normalized_text_emotion),
|
||||
len(str(emotion).strip().lower()),
|
||||
)
|
||||
for emotion in candidate_emotions
|
||||
if emotion
|
||||
]
|
||||
if not emotion_similarities:
|
||||
continue
|
||||
# 计算该表情包与输入标签的最接近匹配度
|
||||
similarity_list.append((emoji, max(emotion_similarities)))
|
||||
return similarity_list
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,12 @@ from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.services import send_service
|
||||
|
||||
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
|
||||
from .emoji_manager import (
|
||||
_normalize_emoji_tag_text,
|
||||
_serialize_emoji_for_hook,
|
||||
emoji_manager,
|
||||
emoji_manager_emotion_judge_llm,
|
||||
)
|
||||
|
||||
logger = get_logger("emoji_maisaka_tool")
|
||||
|
||||
@@ -113,8 +118,11 @@ def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
|
||||
|
||||
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
|
||||
return [str(item).strip() for item in emoji.emotion if str(item).strip()]
|
||||
if emoji.description:
|
||||
return _normalize_emoji_tag_text(emoji.description)
|
||||
if emoji.emotion:
|
||||
return _normalize_emoji_tag_text(emoji.emotion)
|
||||
return []
|
||||
|
||||
|
||||
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:
|
||||
|
||||
@@ -177,7 +177,7 @@ class MaisakaReplyGenerator:
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
return "请基于以上上下文,自然地继续回复。直接输出你要说的话,不需要额外解释。"
|
||||
return "请自然地回复。请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user