feat:修复gemini tool问题,简化表情包识别,修复非多模态plan图片识别
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
@@ -129,24 +129,6 @@ def register_emoji_hook_specs(registry: HookSpecRegistry) -> List[HookSpec]:
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
HookSpec(
|
||||
name="emoji.register.after_build_emotion",
|
||||
description="表情包情绪标签生成完成后触发,可改写标签列表或拒绝本次注册。",
|
||||
parameters_schema=build_object_schema(
|
||||
{
|
||||
"emoji": emoji_schema,
|
||||
"description": {"type": "string", "description": "当前表情包描述。"},
|
||||
"emotions": {
|
||||
**string_array_schema,
|
||||
"description": "当前生成出的情绪标签列表。",
|
||||
},
|
||||
},
|
||||
required=["emoji", "description", "emotions"],
|
||||
),
|
||||
default_timeout_ms=5000,
|
||||
allow_abort=True,
|
||||
allow_kwargs_mutation=True,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -181,7 +163,7 @@ def _serialize_emoji_for_hook(emoji: Optional[MaiEmoji]) -> Optional[Dict[str, A
|
||||
"file_name": emoji.file_name,
|
||||
"full_path": str(emoji.full_path),
|
||||
"description": emoji.description,
|
||||
"emotions": [str(item).strip() for item in emoji.emotion if str(item).strip()],
|
||||
"emotions": [str(item).strip() for item in _normalize_emoji_tag_text(emoji.description or emoji.emotion)],
|
||||
"query_count": int(emoji.query_count),
|
||||
}
|
||||
|
||||
@@ -201,6 +183,39 @@ def _normalize_string_list(raw_values: Any) -> List[str]:
|
||||
return [str(item).strip() for item in raw_values if str(item).strip()]
|
||||
|
||||
|
||||
def _normalize_emoji_tag_text(raw_values: Any) -> List[str]:
|
||||
"""将文本或标签列表转为去重的情绪标签列表。"""
|
||||
if isinstance(raw_values, str):
|
||||
if not raw_values:
|
||||
return []
|
||||
parts = re.split(r"[,,、;;\s]+", raw_values.strip())
|
||||
normalized_tags = [str(part).strip() for part in parts if str(part).strip()]
|
||||
elif isinstance(raw_values, list):
|
||||
normalized_tags: List[str] = []
|
||||
for value in raw_values:
|
||||
normalized_tags.extend(_normalize_emoji_tag_text(value))
|
||||
else:
|
||||
return []
|
||||
|
||||
deduped_tags: List[str] = []
|
||||
seen: Set[str] = set()
|
||||
for tag in normalized_tags:
|
||||
normalized_tag = tag.strip()
|
||||
if not normalized_tag:
|
||||
continue
|
||||
lowered = normalized_tag.lower()
|
||||
if lowered in seen:
|
||||
continue
|
||||
seen.add(lowered)
|
||||
deduped_tags.append(normalized_tag)
|
||||
return deduped_tags
|
||||
|
||||
|
||||
def _get_emoji_emotions(emoji: MaiEmoji) -> List[str]:
|
||||
"""获取兼容旧数据的表情包情绪标签。"""
|
||||
return _normalize_emoji_tag_text(emoji.description or emoji.emotion)
|
||||
|
||||
|
||||
def _ensure_directories() -> None:
|
||||
"""确保表情包相关目录存在"""
|
||||
EMOJI_DIR.mkdir(parents=True, exist_ok=True)
|
||||
@@ -269,20 +284,23 @@ class EmojiManager:
|
||||
Exception: 如果在缓存表情包的过程中发生错误,则抛出异常
|
||||
"""
|
||||
# 先查找
|
||||
if emoji_hash is None and emoji_bytes is not None:
|
||||
if emoji_hash is None:
|
||||
if emoji_bytes is None:
|
||||
raise ValueError("获取表情包描述失败: 既没有提供表情包字节数据,也没有提供表情包哈希值")
|
||||
emoji_hash = hashlib.sha256(emoji_bytes).hexdigest()
|
||||
else:
|
||||
emoji_hash = emoji_hash
|
||||
if not emoji_hash:
|
||||
raise ValueError("获取表情包描述失败: 既没有提供表情包字节数据,也没有提供表情包哈希值")
|
||||
|
||||
if emoji := self.get_emoji_by_hash(emoji_hash):
|
||||
return emoji.description, emoji.emotion or []
|
||||
return emoji.description, _normalize_emoji_tag_text(emoji.description or "")
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if result := session.exec(statement).first():
|
||||
return result.description, result.emotion.split(",") if result.emotion else []
|
||||
cached_description = result.description or result.emotion or ""
|
||||
cached_emotions = _normalize_emoji_tag_text(cached_description)
|
||||
return (
|
||||
cached_description,
|
||||
cached_emotions,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"从数据库查找表情包时出错: {e},将尝试构建表情包描述")
|
||||
|
||||
@@ -407,24 +425,19 @@ class EmojiManager:
|
||||
logger.error("Build emoji description failed")
|
||||
return None
|
||||
|
||||
success_emotion, new_emoji = await self.build_emoji_emotion(new_emoji)
|
||||
if not success_emotion:
|
||||
logger.error("Build emoji emotion labels failed")
|
||||
return None
|
||||
|
||||
# 情绪标签已在 build_emoji_description 内一次性生成,这里仅做兼容性兜底处理
|
||||
with get_db_session() as session:
|
||||
try:
|
||||
statement = select(Images).filter_by(image_hash=new_emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
image_record.full_path = str(new_emoji.full_path)
|
||||
image_record.description = new_emoji.description
|
||||
image_record.emotion = ",".join(new_emoji.emotion) if new_emoji.emotion else None
|
||||
image_record.no_file_flag = False
|
||||
image_record.is_banned = False
|
||||
session.add(image_record)
|
||||
except Exception as exc:
|
||||
logger.error(f"Update cached emoji description failed: {exc}")
|
||||
return new_emoji.description, new_emoji.emotion or []
|
||||
return new_emoji.description, _get_emoji_emotions(new_emoji)
|
||||
|
||||
def load_emojis_from_db(self) -> None:
|
||||
|
||||
@@ -512,7 +525,6 @@ class EmojiManager:
|
||||
existing_record.is_banned = False
|
||||
existing_record.full_path = str(emoji.full_path)
|
||||
existing_record.description = emoji.description
|
||||
existing_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
|
||||
existing_record.query_count = emoji.query_count
|
||||
existing_record.last_used_time = emoji.last_used_time
|
||||
existing_record.register_time = emoji.register_time
|
||||
@@ -639,7 +651,7 @@ class EmojiManager:
|
||||
statement = select(Images).filter_by(image_hash=emoji.file_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
image_record.description = emoji.description
|
||||
image_record.emotion = ",".join(emoji.emotion) if emoji.emotion else None
|
||||
image_record.emotion = None
|
||||
session.add(image_record)
|
||||
logger.info(f"[更新表情包] 成功更新表情包信息: {emoji.file_hash}")
|
||||
else:
|
||||
@@ -734,7 +746,11 @@ class EmojiManager:
|
||||
selected_emoji, similarity = random.choice(top_emojis)
|
||||
self.update_emoji_usage(selected_emoji)
|
||||
logger.info(
|
||||
f"[获取表情包] 为[{emotion_label}]选中表情包: {selected_emoji.file_name}({selected_emoji.emotion}),相似度: {similarity:.4f}"
|
||||
"[获取表情包] 为[%s]选中表情包: %s(%s),相似度: %.4f",
|
||||
emotion_label,
|
||||
selected_emoji.file_name,
|
||||
",".join(_get_emoji_emotions(selected_emoji)),
|
||||
similarity,
|
||||
)
|
||||
return selected_emoji
|
||||
|
||||
@@ -833,7 +849,11 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"[构建描述] 转换 GIF 图片时出错: {e}")
|
||||
return False, target_emoji
|
||||
prompt: str = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,简短描述一下表情包表达的情感和内容,从互联网梗、meme的角度去分析,精简回答"
|
||||
prompt: str = (
|
||||
"这是一个动态图表情包,每一张图代表了动态图的一帧。"
|
||||
"请只返回该表情包常见的情绪/场景标签,最多 5 个,"
|
||||
"使用逗号分隔,标签可为中文或英文,不要附带解释。"
|
||||
)
|
||||
image_base64 = ImageUtils.image_bytes_to_base64(image_bytes)
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
@@ -843,7 +863,10 @@ class EmojiManager:
|
||||
)
|
||||
description = description_result.response
|
||||
else:
|
||||
prompt: str = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,简短描述细节,从互联网梗、meme的角度去分析,精简回答"
|
||||
prompt: str = (
|
||||
"这是一个表情包图片,请提取该表情主要表达的情绪或语气标签,"
|
||||
"最多 5 个,使用逗号分隔,返回纯文本标签列表,不要解释,不要输出其他内容。"
|
||||
)
|
||||
description_result = await emoji_manager_vlm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
@@ -878,10 +901,14 @@ class EmojiManager:
|
||||
if "否" in llm_response:
|
||||
logger.warning(f"[表情包审查] 表情包内容不符合要求,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
normalized_description = str(description).strip()
|
||||
if not normalized_description:
|
||||
logger.warning(f"[构建描述] 视觉模型返回空标签,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_description",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
description=description,
|
||||
description=normalized_description,
|
||||
image_format=image_format,
|
||||
)
|
||||
if hook_result.aborted:
|
||||
@@ -893,9 +920,14 @@ class EmojiManager:
|
||||
logger.warning(f"[构建描述] Hook 返回空描述,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
description = normalized_description
|
||||
target_emoji.description = description
|
||||
logger.info(f"[构建描述] 成功为表情包构建描述: {target_emoji.description}")
|
||||
normalized_emotions = _normalize_emoji_tag_text(normalized_description)
|
||||
if not normalized_emotions:
|
||||
logger.warning(f"[构建描述] Hook 返回标签为空,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
target_emoji.description = ",".join(normalized_emotions)
|
||||
target_emoji.emotion = normalized_emotions
|
||||
logger.info(f"[构建描述] 成功为表情包构建情绪标签: {target_emoji.description}")
|
||||
return True, target_emoji
|
||||
|
||||
async def build_emoji_emotion(self, target_emoji: MaiEmoji) -> Tuple[bool, MaiEmoji]:
|
||||
@@ -911,34 +943,11 @@ class EmojiManager:
|
||||
logger.error("[构建情感标签] 表情包描述为空,无法构建情感标签")
|
||||
return False, target_emoji
|
||||
|
||||
# 获取Prompt
|
||||
emotion_prompt_template = prompt_manager.get_prompt("emoji_content_analysis")
|
||||
emotion_prompt_template.add_context("description", target_emoji.description)
|
||||
emotion_prompt = await prompt_manager.render_prompt(emotion_prompt_template)
|
||||
# 调用LLM生成情感标签
|
||||
try:
|
||||
emotion_generation_result = await emoji_manager_emotion_judge_llm.generate_response(
|
||||
emotion_prompt,
|
||||
options=LLMGenerationOptions(temperature=0.3, max_tokens=200),
|
||||
)
|
||||
emotion_result = emotion_generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"[构建情感标签] 调用模型生成情感标签时出错: {e}")
|
||||
emotions = _normalize_emoji_tag_text(target_emoji.description)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] 表情包标签为空,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
if not emotion_result:
|
||||
logger.warning(f"[构建情感标签] 情感标签结果为空,跳过注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
|
||||
# 解析情感标签结果
|
||||
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
|
||||
|
||||
# 根据情感标签数量随机选择 - 超过5个选3个,超过2个选2个
|
||||
if len(emotions) > 5:
|
||||
emotions = random.sample(emotions, 3)
|
||||
elif len(emotions) > 2:
|
||||
emotions = random.sample(emotions, 2)
|
||||
|
||||
hook_result = await _get_runtime_manager().invoke_hook(
|
||||
"emoji.register.after_build_emotion",
|
||||
emoji=_serialize_emoji_for_hook(target_emoji),
|
||||
@@ -951,7 +960,7 @@ class EmojiManager:
|
||||
|
||||
raw_emotions = hook_result.kwargs.get("emotions")
|
||||
if raw_emotions is not None:
|
||||
emotions = _normalize_string_list(raw_emotions)
|
||||
emotions = _normalize_emoji_tag_text(raw_emotions)
|
||||
if not emotions:
|
||||
logger.warning(f"[构建情感标签] Hook 返回空情绪标签,拒绝注册: {target_emoji.file_name}")
|
||||
return False, target_emoji
|
||||
@@ -1100,18 +1109,13 @@ class EmojiManager:
|
||||
if existing_emoji := self.get_emoji_by_hash(target_emoji.file_hash):
|
||||
logger.warning(f"[注册表情包] 表情包已存在,跳过注册: {existing_emoji.file_name}")
|
||||
return False
|
||||
# 3. 构建描述
|
||||
# 3. 构建描述(包含情绪标签)
|
||||
desc_success, target_emoji = await self.build_emoji_description(target_emoji)
|
||||
if not desc_success:
|
||||
logger.error(f"[注册表情包] 构建表情包描述失败: {file_full_path}")
|
||||
return False
|
||||
# 4. 构建情感标签
|
||||
emo_success, target_emoji = await self.build_emoji_emotion(target_emoji)
|
||||
if not emo_success:
|
||||
logger.error(f"[注册表情包] 构建表情包情感标签失败: {file_full_path}")
|
||||
return False
|
||||
|
||||
# 5. 检查容量并决定是否替换或者直接注册
|
||||
# 4. 检查容量并决定是否替换或者直接注册
|
||||
if self._emoji_num >= global_config.emoji.max_reg_num and global_config.emoji.do_replace:
|
||||
logger.warning(f"[注册表情包] 表情包数量已达上限{global_config.emoji.max_reg_num},尝试替换一个表情包")
|
||||
replaced = await self.replace_an_emoji_by_llm(target_emoji)
|
||||
@@ -1136,17 +1140,30 @@ class EmojiManager:
|
||||
Args:
|
||||
text_emotion (str): 文本的情感标签
|
||||
Returns:
|
||||
return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表
|
||||
return (List[Tuple[MaiEmoji, float]]): 返回表情包对象及其相似度的列表
|
||||
"""
|
||||
normalized_text_emotion = str(text_emotion or "").strip().lower()
|
||||
if not normalized_text_emotion:
|
||||
return []
|
||||
|
||||
similarity_list: List[Tuple[MaiEmoji, float]] = []
|
||||
for emoji in self.emojis:
|
||||
if not emoji.emotion:
|
||||
candidate_emotions = _get_emoji_emotions(emoji)
|
||||
if not candidate_emotions:
|
||||
continue
|
||||
# 计算情感标签相似度,使用 Levenshtein 距离作为相似度指标
|
||||
distance = Levenshtein.distance(text_emotion, emoji.emotion)
|
||||
max_len = max(len(text_emotion), len(emoji.emotion))
|
||||
similarity = 1 - (distance / max_len if max_len > 0 else 0)
|
||||
similarity_list.append((emoji, similarity))
|
||||
|
||||
emotion_similarities = [
|
||||
1 - Levenshtein.distance(normalized_text_emotion, str(emotion).strip().lower()) / max(
|
||||
len(normalized_text_emotion),
|
||||
len(str(emotion).strip().lower()),
|
||||
)
|
||||
for emotion in candidate_emotions
|
||||
if emotion
|
||||
]
|
||||
if not emotion_similarities:
|
||||
continue
|
||||
# 计算该表情包与输入标签的最接近匹配度
|
||||
similarity_list.append((emoji, max(emotion_similarities)))
|
||||
return similarity_list
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,12 @@ from src.common.logger import get_logger
|
||||
from src.common.utils.utils_image import ImageUtils
|
||||
from src.services import send_service
|
||||
|
||||
from .emoji_manager import _serialize_emoji_for_hook, emoji_manager, emoji_manager_emotion_judge_llm
|
||||
from .emoji_manager import (
|
||||
_normalize_emoji_tag_text,
|
||||
_serialize_emoji_for_hook,
|
||||
emoji_manager,
|
||||
emoji_manager_emotion_judge_llm,
|
||||
)
|
||||
|
||||
logger = get_logger("emoji_maisaka_tool")
|
||||
|
||||
@@ -113,8 +118,11 @@ def _resolve_selected_emoji(raw_value: Any) -> Optional[MaiEmoji]:
|
||||
|
||||
def _normalize_emotions(emoji: MaiEmoji) -> list[str]:
|
||||
"""提取并清洗单个表情的情绪标签。"""
|
||||
|
||||
return [str(item).strip() for item in emoji.emotion if str(item).strip()]
|
||||
if emoji.description:
|
||||
return _normalize_emoji_tag_text(emoji.description)
|
||||
if emoji.emotion:
|
||||
return _normalize_emoji_tag_text(emoji.emotion)
|
||||
return []
|
||||
|
||||
|
||||
def _build_recent_context_text(context_texts: Sequence[str], max_items: int = 5) -> str:
|
||||
|
||||
@@ -177,7 +177,7 @@ class MaisakaReplyGenerator:
|
||||
return f"{system_prompt}\n\n" + "\n\n".join(sections)
|
||||
|
||||
def _build_reply_instruction(self) -> str:
|
||||
return "请基于以上上下文,自然地继续回复。直接输出你要说的话,不需要额外解释。"
|
||||
return "请自然地回复。请注意不要输出多余内容(包括不必要的前后缀,冒号,括号,表情包,at或 @等 ),只输出发言内容就好。"
|
||||
|
||||
def _build_multimodal_user_message(
|
||||
self,
|
||||
|
||||
@@ -152,22 +152,30 @@ class MaiEmoji(BaseImageDataModel):
|
||||
raise ValueError(f"数据库记录 {db_record.image_hash} 标记为文件不存在,无法创建 MaiEmoji 对象")
|
||||
obj = cls(db_record.full_path)
|
||||
obj.file_hash = db_record.image_hash
|
||||
obj.description = db_record.description
|
||||
if db_record.emotion:
|
||||
obj.emotion = db_record.emotion.split(",")
|
||||
description = db_record.description or db_record.emotion or ""
|
||||
obj.description = description
|
||||
normalized_tags = [
|
||||
str(item).strip()
|
||||
for item in str(description).replace(",", ",").replace("、", ",").replace(";", ",").split(",")
|
||||
if str(item).strip()
|
||||
]
|
||||
deduped_tags: List[str] = []
|
||||
for item in normalized_tags:
|
||||
if item not in deduped_tags:
|
||||
deduped_tags.append(item)
|
||||
obj.emotion = deduped_tags
|
||||
obj.query_count = db_record.query_count
|
||||
obj.last_used_time = db_record.last_used_time
|
||||
obj.register_time = db_record.register_time
|
||||
return obj
|
||||
|
||||
def to_db_instance(self) -> Images:
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else None
|
||||
return Images(
|
||||
image_hash=self.file_hash,
|
||||
description=self.description,
|
||||
full_path=str(self.full_path),
|
||||
image_type=ImageType.EMOJI,
|
||||
emotion=emotion_str,
|
||||
emotion=None,
|
||||
query_count=self.query_count,
|
||||
last_used_time=self.last_used_time,
|
||||
register_time=self.register_time,
|
||||
|
||||
@@ -792,14 +792,18 @@ def _migrate_images(context: MigrationExecutionContext) -> int:
|
||||
image_hash = _normalize_required_text(row.get("emoji_hash"))
|
||||
dedupe_key = (full_path, image_hash, "EMOJI")
|
||||
if full_path and dedupe_key not in existing_keys:
|
||||
migrated_description = _normalize_required_text(row.get("description"))
|
||||
migrated_emotion = _normalize_optional_text(row.get("emotion"))
|
||||
if not migrated_description and migrated_emotion:
|
||||
migrated_description = migrated_emotion
|
||||
connection.execute(
|
||||
insert_sql,
|
||||
{
|
||||
"image_hash": image_hash,
|
||||
"description": _normalize_required_text(row.get("description")),
|
||||
"description": migrated_description,
|
||||
"full_path": full_path,
|
||||
"image_type": "EMOJI",
|
||||
"emotion": _normalize_optional_text(row.get("emotion")),
|
||||
"emotion": None,
|
||||
"query_count": _normalize_int(row.get("query_count"), default=0),
|
||||
"is_registered": _normalize_bool(row.get("is_registered"), default=False),
|
||||
"is_banned": _normalize_bool(row.get("is_banned"), default=False),
|
||||
|
||||
@@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config"
|
||||
BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute()
|
||||
MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute()
|
||||
MMC_VERSION: str = "1.0.0"
|
||||
CONFIG_VERSION: str = "8.3.1"
|
||||
CONFIG_VERSION: str = "8.3.2"
|
||||
MODEL_CONFIG_VERSION: str = "1.13.1"
|
||||
|
||||
logger = get_logger("config")
|
||||
@@ -115,9 +115,6 @@ class Config(ConfigBase):
|
||||
maim_message: MaimMessageConfig = Field(default_factory=MaimMessageConfig)
|
||||
"""maim_message配置类"""
|
||||
|
||||
lpmm_knowledge: LPMMKnowledgeConfig = Field(default_factory=LPMMKnowledgeConfig, repr=False)
|
||||
"""LPMM知识库配置类"""
|
||||
|
||||
webui: WebUIConfig = Field(default_factory=WebUIConfig)
|
||||
"""WebUI配置类"""
|
||||
|
||||
|
||||
@@ -255,8 +255,9 @@ class ChatConfig(ConfigBase):
|
||||
)
|
||||
"""_wrap_私聊说话规则,行为风格"""
|
||||
|
||||
|
||||
group_chat_prompt: str = Field(
|
||||
default="不要回复的太频繁!控制回复的频率,不要每个人的消息都回复,只回复你感兴趣的或者主动提及你的。",
|
||||
default="你需要控制自己发言的频率,如果是一对一聊天,可以以较均匀的频率发言;如果用户较多,不要每句都回复,控制回复频率,不要回复的太频繁!控制回复的频率,不要每个人的消息都回复。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "users",
|
||||
@@ -265,7 +266,7 @@ class ChatConfig(ConfigBase):
|
||||
"""_wrap_群聊通用注意事项"""
|
||||
|
||||
private_chat_prompts: str = Field(
|
||||
default="",
|
||||
default="你需要控制自己发言的频率,可以以较均匀的频率发言。",
|
||||
json_schema_extra={
|
||||
"x-widget": "textarea",
|
||||
"x-icon": "user",
|
||||
@@ -1549,6 +1550,16 @@ class MaiSakaConfig(ConfigBase):
|
||||
)
|
||||
"""每个入站消息的最大内部规划轮数"""
|
||||
|
||||
planner_interrupt_max_consecutive_count: int = Field(
|
||||
default=2,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "pause-circle",
|
||||
},
|
||||
)
|
||||
"""Planner 连续被新消息打断的最大次数,0 表示不启用打断"""
|
||||
|
||||
enable_memory_query_tool: bool = Field(
|
||||
default=True,
|
||||
json_schema_extra={
|
||||
|
||||
@@ -249,9 +249,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
||||
if message.role == RoleType.Tool:
|
||||
if not message.tool_call_id:
|
||||
raise ValueError("Gemini 工具结果消息缺少 tool_call_id")
|
||||
tool_name = tool_name_by_call_id.get(message.tool_call_id)
|
||||
tool_name = (message.tool_name or tool_name_by_call_id.get(message.tool_call_id, "")).strip()
|
||||
if not tool_name:
|
||||
raise ValueError(f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称")
|
||||
raise ValueError(
|
||||
f"Gemini 无法根据 tool_call_id={message.tool_call_id} 找到对应的工具名称,"
|
||||
"且消息中未携带 tool_name"
|
||||
)
|
||||
tool_name_by_call_id[message.tool_call_id] = tool_name
|
||||
function_response_part = Part.from_function_response(
|
||||
name=tool_name,
|
||||
response=_normalize_function_response_payload(message),
|
||||
|
||||
@@ -75,6 +75,7 @@ class Message:
|
||||
role: RoleType
|
||||
parts: List[MessagePart] = field(default_factory=list)
|
||||
tool_call_id: str | None = None
|
||||
tool_name: str | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
@@ -87,6 +88,8 @@ class Message:
|
||||
raise ValueError("消息内容不能为空")
|
||||
if self.role == RoleType.Tool and not self.tool_call_id:
|
||||
raise ValueError("Tool 角色的工具调用 ID 不能为空")
|
||||
if self.tool_name and self.role != RoleType.Tool:
|
||||
raise ValueError("仅当角色为 Tool 时才能设置工具名称")
|
||||
|
||||
@property
|
||||
def content(self) -> str | List[Tuple[str, str] | str]:
|
||||
@@ -122,7 +125,7 @@ class Message:
|
||||
"""
|
||||
return (
|
||||
f"Role: {self.role}, Parts: {self.parts}, "
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Name: {self.tool_name}, Tool Calls: {self.tool_calls}"
|
||||
)
|
||||
|
||||
|
||||
@@ -134,6 +137,7 @@ class MessageBuilder:
|
||||
self.__role: RoleType = RoleType.User
|
||||
self.__parts: List[MessagePart] = []
|
||||
self.__tool_call_id: str | None = None
|
||||
self.__tool_name: str | None = None
|
||||
self.__tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
|
||||
@@ -247,6 +251,15 @@ class MessageBuilder:
|
||||
"""
|
||||
return self.set_tool_call_id(tool_call_id)
|
||||
|
||||
def set_tool_name(self, tool_name: str) -> "MessageBuilder":
|
||||
"""设置 Tool 消息对应的工具名称。"""
|
||||
if self.__role != RoleType.Tool:
|
||||
raise ValueError("仅当角色为 Tool 时才能设置工具名称")
|
||||
if not tool_name:
|
||||
raise ValueError("工具名称不能为空")
|
||||
self.__tool_name = tool_name
|
||||
return self
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
"""设置助手消息中的工具调用列表。
|
||||
|
||||
@@ -276,5 +289,6 @@ class MessageBuilder:
|
||||
role=self.__role,
|
||||
parts=list(self.__parts),
|
||||
tool_call_id=self.__tool_call_id,
|
||||
tool_name=self.__tool_name,
|
||||
tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
from datetime import datetime
|
||||
import base64
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ModelUsage, ModelUser
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import ModelInfo
|
||||
from .payload_content.message import Message, MessageBuilder
|
||||
|
||||
from .model_client.base_client import UsageRecord
|
||||
from .payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
|
||||
logger = get_logger("消息压缩工具")
|
||||
|
||||
@@ -131,25 +132,32 @@ def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 *
|
||||
|
||||
return base64_data
|
||||
|
||||
compressed_messages = []
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
# 检查content,如有图片则压缩
|
||||
message_builder = MessageBuilder()
|
||||
for content_item in message.content:
|
||||
if isinstance(content_item, tuple):
|
||||
# 图片,进行压缩
|
||||
message_builder.add_image_content(
|
||||
content_item[0],
|
||||
compress_base64_image(content_item[1], target_size=img_target_size),
|
||||
)
|
||||
else:
|
||||
message_builder.add_text_content(content_item)
|
||||
compressed_messages.append(message_builder.build())
|
||||
else:
|
||||
compressed_messages.append(message)
|
||||
def rebuild_message_with_compressed_images(message: Message) -> Message:
|
||||
"""重建消息并压缩其中的图片,同时保留角色与工具元信息。"""
|
||||
if not any(isinstance(part, ImageMessagePart) for part in message.parts):
|
||||
return message
|
||||
|
||||
return compressed_messages
|
||||
message_builder = MessageBuilder().set_role(message.role)
|
||||
if message.role == RoleType.Assistant and message.tool_calls:
|
||||
message_builder.set_tool_calls(message.tool_calls)
|
||||
if message.role == RoleType.Tool and message.tool_call_id:
|
||||
message_builder.set_tool_call_id(message.tool_call_id)
|
||||
if message.role == RoleType.Tool and message.tool_name:
|
||||
message_builder.set_tool_name(message.tool_name)
|
||||
|
||||
for message_part in message.parts:
|
||||
if isinstance(message_part, ImageMessagePart):
|
||||
message_builder.add_image_content(
|
||||
message_part.image_format,
|
||||
compress_base64_image(message_part.image_base64, target_size=img_target_size),
|
||||
)
|
||||
continue
|
||||
if isinstance(message_part, TextMessagePart):
|
||||
message_builder.add_text_content(message_part.text)
|
||||
|
||||
return message_builder.build()
|
||||
|
||||
return [rebuild_message_with_compressed_images(message) for message in messages]
|
||||
|
||||
|
||||
class LLMUsageRecorder:
|
||||
|
||||
138
src/maisaka/chat_history_visual_refresher.py
Normal file
138
src/maisaka/chat_history_visual_refresher.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Maisaka 聊天历史视觉占位刷新器。"""
|
||||
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.data_models.message_component_data_model import EmojiComponent, ForwardNodeComponent, ImageComponent
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .context_messages import LLMContextMessage, SessionBackedMessage
|
||||
|
||||
logger = get_logger("maisaka_chat_history_visual_refresher")
|
||||
|
||||
BuildHistoryMessage = Callable[[SessionMessage, str], Awaitable[Optional[LLMContextMessage]]]
|
||||
BuildVisibleText = Callable[[SessionMessage], str]
|
||||
|
||||
|
||||
async def refresh_chat_history_visual_placeholders(
|
||||
*,
|
||||
chat_history: list[LLMContextMessage],
|
||||
build_history_message: BuildHistoryMessage,
|
||||
build_visible_text: BuildVisibleText,
|
||||
) -> int:
|
||||
"""在进入新一轮规划前,尝试用已完成的识图结果刷新历史占位。"""
|
||||
|
||||
refreshed_count = 0
|
||||
for index, history_message in enumerate(chat_history):
|
||||
if not isinstance(history_message, SessionBackedMessage):
|
||||
continue
|
||||
|
||||
original_message = history_message.original_message
|
||||
if original_message is None:
|
||||
continue
|
||||
|
||||
visual_components_updated = _refresh_pending_visual_components(original_message.raw_message.components)
|
||||
if visual_components_updated:
|
||||
await original_message.process(
|
||||
enable_heavy_media_analysis=False,
|
||||
enable_voice_transcription=False,
|
||||
)
|
||||
|
||||
refreshed_visible_text = build_visible_text(original_message)
|
||||
if not visual_components_updated and refreshed_visible_text == history_message.visible_text:
|
||||
continue
|
||||
|
||||
rebuilt_history_message = await build_history_message(original_message, history_message.source_kind)
|
||||
if rebuilt_history_message is None:
|
||||
continue
|
||||
|
||||
chat_history[index] = rebuilt_history_message
|
||||
refreshed_count += 1
|
||||
|
||||
return refreshed_count
|
||||
|
||||
|
||||
def _refresh_pending_visual_components(components: list[object]) -> bool:
|
||||
"""用缓存中的描述更新尚未补全文本的图片与表情组件。"""
|
||||
|
||||
refreshed = False
|
||||
for component in components:
|
||||
if isinstance(component, ImageComponent):
|
||||
if _should_refresh_image_component(component):
|
||||
image_description = _lookup_cached_image_description(component.binary_hash)
|
||||
if image_description:
|
||||
component.content = f"[图片:{image_description}]"
|
||||
refreshed = True
|
||||
continue
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
if _should_refresh_emoji_component(component):
|
||||
emoji_description = _lookup_cached_emoji_description(component.binary_hash)
|
||||
if emoji_description:
|
||||
component.content = f"[表情包: {emoji_description}]"
|
||||
refreshed = True
|
||||
continue
|
||||
|
||||
if not isinstance(component, ForwardNodeComponent):
|
||||
continue
|
||||
|
||||
for forward_component in component.forward_components:
|
||||
if _refresh_pending_visual_components(forward_component.content):
|
||||
refreshed = True
|
||||
|
||||
return refreshed
|
||||
|
||||
|
||||
def _should_refresh_image_component(component: ImageComponent) -> bool:
|
||||
"""判断图片组件当前是否仍处于待补全文本的占位状态。"""
|
||||
|
||||
return not component.content or component.content == "[图片]"
|
||||
|
||||
|
||||
def _should_refresh_emoji_component(component: EmojiComponent) -> bool:
|
||||
"""判断表情组件当前是否仍处于待补全文本的占位状态。"""
|
||||
|
||||
return not component.content or component.content == "[表情包]"
|
||||
|
||||
|
||||
def _lookup_cached_image_description(image_hash: str) -> str:
|
||||
"""从数据库读取已完成的图片描述,不触发新的识图请求。"""
|
||||
|
||||
if not image_hash:
|
||||
return ""
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=image_hash, image_type=ImageType.IMAGE).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
if image_record.no_file_flag:
|
||||
return ""
|
||||
if image_record.vlm_processed and image_record.description:
|
||||
return str(image_record.description).strip()
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取图片缓存描述失败,image_hash={image_hash}: {exc}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _lookup_cached_emoji_description(emoji_hash: str) -> str:
|
||||
"""从数据库读取已完成的表情描述,不触发新的识别请求。"""
|
||||
|
||||
if not emoji_hash:
|
||||
return ""
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).filter_by(image_hash=emoji_hash, image_type=ImageType.EMOJI).limit(1)
|
||||
if image_record := session.exec(statement).first():
|
||||
if image_record.no_file_flag or not image_record.description:
|
||||
return ""
|
||||
return str(image_record.description).strip()
|
||||
except Exception as exc:
|
||||
logger.warning(f"读取表情缓存描述失败,emoji_hash={emoji_hash}: {exc}")
|
||||
|
||||
return ""
|
||||
@@ -196,6 +196,7 @@ def _build_message_from_sequence(
|
||||
fallback_text: str,
|
||||
*,
|
||||
tool_call_id: Optional[str] = None,
|
||||
tool_name: Optional[str] = None,
|
||||
tool_calls: Optional[list[ToolCall]] = None,
|
||||
) -> Optional[Message]:
|
||||
"""根据消息片段构造统一 LLM 消息。"""
|
||||
@@ -204,6 +205,8 @@ def _build_message_from_sequence(
|
||||
builder.set_tool_calls(tool_calls)
|
||||
if role == RoleType.Tool and tool_call_id:
|
||||
builder.add_tool_call(tool_call_id)
|
||||
if role == RoleType.Tool and tool_name:
|
||||
builder.set_tool_name(tool_name)
|
||||
|
||||
has_content = False
|
||||
for component in message_sequence.components:
|
||||
@@ -481,4 +484,5 @@ class ToolResultMessage(LLMContextMessage):
|
||||
message_sequence,
|
||||
self.content,
|
||||
tool_call_id=self.tool_call_id,
|
||||
tool_name=self.tool_name,
|
||||
)
|
||||
|
||||
@@ -66,11 +66,11 @@ def build_visible_text_from_sequence(message_sequence: MessageSequence) -> str:
|
||||
continue
|
||||
|
||||
if isinstance(component, EmojiComponent):
|
||||
parts.append("[表情包]")
|
||||
parts.append(component.content or "[表情包]")
|
||||
continue
|
||||
|
||||
if isinstance(component, ImageComponent):
|
||||
parts.append("[图片]")
|
||||
parts.append(component.content or "[图片]")
|
||||
continue
|
||||
|
||||
if isinstance(component, ReplyComponent):
|
||||
|
||||
@@ -24,6 +24,7 @@ from src.services import database_service as database_api
|
||||
from .builtin_tool import get_action_tool_specs
|
||||
from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers
|
||||
from .builtin_tool import get_timing_tools
|
||||
from .chat_history_visual_refresher import refresh_chat_history_visual_placeholders
|
||||
from .builtin_tool.context import BuiltinToolRuntimeContext
|
||||
from .context_messages import (
|
||||
AssistantMessage,
|
||||
@@ -103,16 +104,22 @@ class MaisakaReasoningEngine:
|
||||
"""运行一轮可被新消息打断的主 planner 请求。"""
|
||||
|
||||
interrupt_flag = asyncio.Event()
|
||||
self._runtime._planner_interrupt_flag = interrupt_flag
|
||||
interrupted = False
|
||||
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
|
||||
self._runtime._chat_loop_service.set_interrupt_flag(interrupt_flag)
|
||||
try:
|
||||
return await self._runtime._chat_loop_service.chat_loop_step(
|
||||
self._runtime._chat_history,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
except ReqAbortException:
|
||||
interrupted = True
|
||||
raise
|
||||
finally:
|
||||
if self._runtime._planner_interrupt_flag is interrupt_flag:
|
||||
self._runtime._planner_interrupt_flag = None
|
||||
self._runtime._unbind_planner_interrupt_flag(
|
||||
interrupt_flag,
|
||||
interrupted=interrupted,
|
||||
)
|
||||
self._runtime._chat_loop_service.set_interrupt_flag(None)
|
||||
|
||||
async def _run_interruptible_sub_agent(
|
||||
@@ -125,7 +132,8 @@ class MaisakaReasoningEngine:
|
||||
"""运行一轮可被新消息打断的临时子代理请求。"""
|
||||
|
||||
interrupt_flag = asyncio.Event()
|
||||
self._runtime._planner_interrupt_flag = interrupt_flag
|
||||
interrupted = False
|
||||
self._runtime._bind_planner_interrupt_flag(interrupt_flag)
|
||||
try:
|
||||
return await self._runtime.run_sub_agent(
|
||||
context_message_limit=context_message_limit,
|
||||
@@ -136,9 +144,14 @@ class MaisakaReasoningEngine:
|
||||
temperature=0.1,
|
||||
tool_definitions=tool_definitions,
|
||||
)
|
||||
except ReqAbortException:
|
||||
interrupted = True
|
||||
raise
|
||||
finally:
|
||||
if self._runtime._planner_interrupt_flag is interrupt_flag:
|
||||
self._runtime._planner_interrupt_flag = None
|
||||
self._runtime._unbind_planner_interrupt_flag(
|
||||
interrupt_flag,
|
||||
interrupted=interrupted,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_timing_gate_fallback_prompt() -> str:
|
||||
@@ -313,6 +326,14 @@ class MaisakaReasoningEngine:
|
||||
)
|
||||
planner_started_at = 0.0
|
||||
try:
|
||||
visual_refresh_started_at = time.time()
|
||||
refreshed_message_count = await self._refresh_chat_history_visual_placeholders()
|
||||
cycle_detail.time_records["visual_refresh"] = time.time() - visual_refresh_started_at
|
||||
if refreshed_message_count > 0:
|
||||
logger.info(
|
||||
f"{self._runtime.log_prefix} 本轮思考前已刷新 {refreshed_message_count} 条视觉占位历史消息"
|
||||
)
|
||||
|
||||
timing_started_at = time.time()
|
||||
timing_action, timing_response, timing_tool_results = await self._run_timing_gate(anchor_message)
|
||||
timing_duration_ms = (time.time() - timing_started_at) * 1000
|
||||
@@ -526,7 +547,12 @@ class MaisakaReasoningEngine:
|
||||
timestamp=message.timestamp.timestamp(),
|
||||
)
|
||||
|
||||
async def _build_history_message(self, message: SessionMessage) -> Optional[LLMContextMessage]:
|
||||
async def _build_history_message(
|
||||
self,
|
||||
message: SessionMessage,
|
||||
*,
|
||||
source_kind: str = "user",
|
||||
) -> Optional[LLMContextMessage]:
|
||||
"""根据真实消息构造对应的上下文消息。"""
|
||||
|
||||
source_sequence = message.raw_message
|
||||
@@ -537,7 +563,7 @@ class MaisakaReasoningEngine:
|
||||
message,
|
||||
planner_prefix=planner_prefix,
|
||||
visible_text=visible_text,
|
||||
source_kind="user",
|
||||
source_kind=source_kind,
|
||||
)
|
||||
|
||||
user_sequence = await self._build_message_sequence(message, planner_prefix=planner_prefix)
|
||||
@@ -548,7 +574,7 @@ class MaisakaReasoningEngine:
|
||||
message,
|
||||
raw_message=user_sequence,
|
||||
visible_text=visible_text,
|
||||
source_kind="user",
|
||||
source_kind=source_kind,
|
||||
)
|
||||
|
||||
async def _build_message_sequence(
|
||||
@@ -601,6 +627,18 @@ class MaisakaReasoningEngine:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"{self._runtime.log_prefix} 回填图片或表情二进制数据失败,Maisaka 将退化为文本占位: {result}")
|
||||
|
||||
async def _refresh_chat_history_visual_placeholders(self) -> int:
|
||||
"""在进入新一轮规划前,尝试用已完成的识图结果刷新历史占位。"""
|
||||
|
||||
return await refresh_chat_history_visual_placeholders(
|
||||
chat_history=self._runtime._chat_history,
|
||||
build_history_message=lambda message, source_kind: self._build_history_message(
|
||||
message,
|
||||
source_kind=source_kind,
|
||||
),
|
||||
build_visible_text=lambda message: self._build_legacy_visible_text(message, message.raw_message),
|
||||
)
|
||||
|
||||
def _build_legacy_visible_text(self, message: SessionMessage, source_sequence: MessageSequence) -> str:
|
||||
user_info = message.message_info.user_info
|
||||
speaker_name = user_info.user_cardname or user_info.user_nickname or user_info.user_id
|
||||
|
||||
@@ -84,6 +84,12 @@ class MaisakaHeartFlowChatting:
|
||||
self._wait_until: Optional[float] = None
|
||||
self._pending_wait_tool_call_id: Optional[str] = None
|
||||
self._planner_interrupt_flag: Optional[asyncio.Event] = None
|
||||
self._planner_interrupt_requested = False
|
||||
self._planner_interrupt_consecutive_count = 0
|
||||
self._planner_interrupt_max_consecutive_count = max(
|
||||
0,
|
||||
int(global_config.maisaka.planner_interrupt_max_consecutive_count),
|
||||
)
|
||||
|
||||
expr_use, jargon_learn, expr_learn = ExpressionConfigUtils.get_expression_config_for_chat(session_id)
|
||||
self._enable_expression_use = expr_use
|
||||
@@ -167,14 +173,51 @@ class MaisakaHeartFlowChatting:
|
||||
if self._agent_state == self._STATE_RUNNING:
|
||||
self._message_debounce_required = True
|
||||
if self._agent_state == self._STATE_RUNNING and self._planner_interrupt_flag is not None:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 收到新消息,发起规划器打断; "
|
||||
f"消息编号={message.message_id} 缓存条数={len(self.message_cache)} "
|
||||
f"时间戳={time.time():.3f}"
|
||||
)
|
||||
self._planner_interrupt_flag.set()
|
||||
if self._planner_interrupt_requested:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 收到新消息,但当前请求已发起过一次规划器打断,"
|
||||
f"本次不重复打断; 消息编号={message.message_id} "
|
||||
f"连续打断次数={self._planner_interrupt_consecutive_count}/"
|
||||
f"{self._planner_interrupt_max_consecutive_count}"
|
||||
)
|
||||
elif self._planner_interrupt_consecutive_count >= self._planner_interrupt_max_consecutive_count:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 收到新消息,但已达到规划器连续打断上限,"
|
||||
f"将等待当前请求自然完成; 消息编号={message.message_id} "
|
||||
f"连续打断次数={self._planner_interrupt_consecutive_count}/"
|
||||
f"{self._planner_interrupt_max_consecutive_count}"
|
||||
)
|
||||
else:
|
||||
self._planner_interrupt_requested = True
|
||||
self._planner_interrupt_consecutive_count += 1
|
||||
logger.info(
|
||||
f"{self.log_prefix} 收到新消息,发起规划器打断; "
|
||||
f"消息编号={message.message_id} 缓存条数={len(self.message_cache)} "
|
||||
f"时间戳={time.time():.3f} "
|
||||
f"连续打断次数={self._planner_interrupt_consecutive_count}/"
|
||||
f"{self._planner_interrupt_max_consecutive_count}"
|
||||
)
|
||||
self._planner_interrupt_flag.set()
|
||||
self._new_message_event.set()
|
||||
|
||||
def _bind_planner_interrupt_flag(self, interrupt_flag: asyncio.Event) -> None:
|
||||
"""绑定当前可打断请求使用的中断标记。"""
|
||||
self._planner_interrupt_flag = interrupt_flag
|
||||
self._planner_interrupt_requested = False
|
||||
|
||||
def _unbind_planner_interrupt_flag(
|
||||
self,
|
||||
interrupt_flag: asyncio.Event,
|
||||
*,
|
||||
interrupted: bool,
|
||||
) -> None:
|
||||
"""解绑当前可打断请求的中断标记,并维护连续打断计数。"""
|
||||
if self._planner_interrupt_flag is interrupt_flag:
|
||||
self._planner_interrupt_flag = None
|
||||
self._planner_interrupt_requested = False
|
||||
if not interrupted:
|
||||
self._planner_interrupt_consecutive_count = 0
|
||||
|
||||
def _ensure_background_tasks_running(self) -> None:
|
||||
"""确保后台任务仍在运行,若崩溃则自动拉起。"""
|
||||
if not self._running:
|
||||
@@ -513,7 +556,6 @@ class MaisakaHeartFlowChatting:
|
||||
if not global_config.debug.show_maisaka_thinking:
|
||||
return
|
||||
|
||||
session_name = chat_manager.get_session_name(self.session_id) or self.session_id
|
||||
body_lines = [
|
||||
f"上下文占用:{selected_history_count}/{self._max_context_size} 条",
|
||||
f"本次请求token消耗:{self._format_token_count(prompt_tokens)}",
|
||||
|
||||
@@ -19,13 +19,48 @@ class RuntimeDataCapabilityMixin:
|
||||
if not emoji_base64:
|
||||
return None
|
||||
|
||||
matched_emotion = emoji.emotion[0] if emoji.emotion else ""
|
||||
matched_emotion = RuntimeDataCapabilityMixin._normalize_emoji_tags(emoji)
|
||||
return {
|
||||
"base64": emoji_base64,
|
||||
"description": emoji.description,
|
||||
"emotion": matched_emotion,
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _normalize_emoji_tag_text(raw_value: Any) -> List[str]:
|
||||
"""将文本或标签列表转为去重情绪标签列表。"""
|
||||
if raw_value is None:
|
||||
return []
|
||||
if isinstance(raw_value, list):
|
||||
values = raw_value
|
||||
else:
|
||||
values = [raw_value]
|
||||
|
||||
tags: List[str] = []
|
||||
for value in values:
|
||||
raw_text = str(value) if value is not None else ""
|
||||
if not raw_text:
|
||||
continue
|
||||
tags.extend(
|
||||
item.strip() for item in raw_text.replace(",", ",").replace("、", ",").replace(";", ",").split(",")
|
||||
)
|
||||
|
||||
deduped_tags: List[str] = []
|
||||
for tag in tags:
|
||||
tag_text = str(tag).strip()
|
||||
if not tag_text:
|
||||
continue
|
||||
if tag_text not in deduped_tags:
|
||||
deduped_tags.append(tag_text)
|
||||
return deduped_tags
|
||||
|
||||
@staticmethod
|
||||
def _normalize_emoji_tags(emoji: MaiEmoji) -> str:
|
||||
"""从表情包对象提取兼容旧数据的情绪标签文本。"""
|
||||
tags = RuntimeDataCapabilityMixin._normalize_emoji_tag_text(emoji.description or emoji.emotion)
|
||||
return tags[0] if tags else ""
|
||||
|
||||
@staticmethod
|
||||
def _build_emoji_temp_path() -> Path:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
@@ -488,7 +523,16 @@ class RuntimeDataCapabilityMixin:
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
|
||||
emotions = sorted({emotion for emoji in emoji_manager.emojis for emotion in emoji.emotion})
|
||||
emotions = sorted(
|
||||
{
|
||||
str(emotion).strip()
|
||||
for emoji in emoji_manager.emojis
|
||||
for emotion in RuntimeDataCapabilityMixin._normalize_emoji_tag_text(
|
||||
emoji.description or emoji.emotion
|
||||
)
|
||||
if str(emotion).strip()
|
||||
}
|
||||
)
|
||||
return {"success": True, "emotions": emotions}
|
||||
except Exception as e:
|
||||
logger.error(f"[cap.emoji.get_emotions] 执行失败: {e}", exc_info=True)
|
||||
@@ -568,7 +612,9 @@ class RuntimeDataCapabilityMixin:
|
||||
"success": True,
|
||||
"message": f"表情包注册成功 {'(替换旧表情包)' if replaced else '(新增表情包)'}",
|
||||
"description": None if new_emoji is None else new_emoji.description,
|
||||
"emotions": None if new_emoji is None else new_emoji.emotion,
|
||||
"emotions": None
|
||||
if new_emoji is None
|
||||
else RuntimeDataCapabilityMixin._normalize_emoji_tag_text(new_emoji.description or new_emoji.emotion),
|
||||
"replaced": replaced,
|
||||
"hash": None if new_emoji is None else new_emoji.file_hash,
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import io
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Cookie, HTTPException, Query
|
||||
@@ -55,6 +56,19 @@ from .support import (
|
||||
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
|
||||
def _normalize_emoji_description(description: str = "", emotion: str = "") -> str:
|
||||
"""将上传参数中的描述/情绪标签归一化为可存储 description。"""
|
||||
normalized_description = str(description or "").strip()
|
||||
normalized_emotion = str(emotion or "").strip()
|
||||
if normalized_description:
|
||||
return normalized_description
|
||||
if not normalized_emotion:
|
||||
return ""
|
||||
|
||||
tags = re.split(r"[,,、;;\s]+", normalized_emotion)
|
||||
return ",".join(item.strip() for item in tags if item.strip())
|
||||
|
||||
|
||||
@router.get("/list", response_model=EmojiListResponse)
|
||||
async def get_emoji_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
@@ -173,6 +187,14 @@ async def update_emoji(
|
||||
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||
update_data["register_time"] = datetime.now()
|
||||
|
||||
if "emotion" in update_data:
|
||||
normalized_description = _normalize_emoji_description(
|
||||
description=update_data.get("description", ""),
|
||||
emotion=update_data.get("emotion", ""),
|
||||
)
|
||||
update_data["description"] = normalized_description
|
||||
update_data.pop("emotion", None)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(emoji, field, value)
|
||||
|
||||
@@ -543,7 +565,7 @@ async def upload_emoji(
|
||||
_ = output_file.write(file_content)
|
||||
|
||||
logger.info(f"表情包文件已保存: {full_path}")
|
||||
emotion_str = ",".join(item.strip() for item in emotion.split(",") if item.strip()) if emotion else ""
|
||||
final_description = _normalize_emoji_description(description=description, emotion=emotion)
|
||||
|
||||
current_time = datetime.now()
|
||||
with get_db_session() as session:
|
||||
@@ -551,8 +573,8 @@ async def upload_emoji(
|
||||
image_type=ImageType.EMOJI,
|
||||
full_path=full_path,
|
||||
image_hash=emoji_hash,
|
||||
description=description,
|
||||
emotion=emotion_str or None,
|
||||
description=final_description,
|
||||
emotion=None,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
@@ -654,16 +676,16 @@ async def batch_upload_emoji(
|
||||
with open(full_path, "wb") as output_file:
|
||||
_ = output_file.write(file_content)
|
||||
|
||||
emotion_str = ",".join(item.strip() for item in emotion.split(",") if item.strip()) if emotion else ""
|
||||
current_time = datetime.now()
|
||||
final_description = _normalize_emoji_description(emotion=emotion)
|
||||
|
||||
with get_db_session() as session:
|
||||
emoji = Images(
|
||||
image_type=ImageType.EMOJI,
|
||||
full_path=full_path,
|
||||
image_hash=emoji_hash,
|
||||
description="",
|
||||
emotion=emotion_str or None,
|
||||
description=final_description,
|
||||
emotion=None,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import File, Form, UploadFile
|
||||
@@ -5,15 +6,15 @@ from pydantic import BaseModel
|
||||
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包上传文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包上传文件")]
|
||||
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||
EmotionForm = Annotated[str, Form(description="情绪标签,多个使用逗号分隔")]
|
||||
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||
|
||||
|
||||
class EmojiResponse(BaseModel):
|
||||
"""表情包响应"""
|
||||
"""表情包响应结构"""
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
@@ -124,7 +125,20 @@ class ThumbnailPreheatResponse(BaseModel):
|
||||
|
||||
|
||||
def emoji_to_response(image: Images) -> EmojiResponse:
|
||||
"""将数据库表情包模型转换为响应对象。"""
|
||||
emotions: list[str] = []
|
||||
if image.description:
|
||||
emotions.extend(
|
||||
item.strip() for item in re.split(r"[,,、;;\s]+", image.description) if item and item.strip()
|
||||
)
|
||||
if not emotions and image.emotion:
|
||||
emotions.extend(item.strip() for item in re.split(r"[,,、;;\s]+", image.emotion) if item and item.strip())
|
||||
|
||||
deduped_emotions: list[str] = []
|
||||
for item in emotions:
|
||||
if item not in deduped_emotions:
|
||||
deduped_emotions.append(item)
|
||||
emotion = ",".join(deduped_emotions) if deduped_emotions else None
|
||||
|
||||
return EmojiResponse(
|
||||
id=image.id if image.id is not None else 0,
|
||||
full_path=image.full_path,
|
||||
@@ -133,7 +147,7 @@ def emoji_to_response(image: Images) -> EmojiResponse:
|
||||
query_count=image.query_count,
|
||||
is_registered=image.is_registered,
|
||||
is_banned=image.is_banned,
|
||||
emotion=image.emotion,
|
||||
emotion=emotion,
|
||||
record_time=image.record_time.timestamp() if image.record_time else 0.0,
|
||||
register_time=image.register_time.timestamp() if image.register_time else None,
|
||||
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
|
||||
|
||||
Reference in New Issue
Block a user