重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||
from rich.traceback import install
|
||||
|
||||
from sqlmodel import select, col
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
|
||||
from src.common.data_models.message_data_model import MessageAndActionModel
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord, Images
|
||||
from src.person_info.person_info import Person, get_person_id
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
|
||||
|
||||
@@ -198,37 +198,38 @@ def get_actions_by_timestamp_with_chat(
|
||||
limit_mode: str = "latest",
|
||||
) -> List[DatabaseActionRecords]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ActionRecord)
|
||||
.where((col(ActionRecord.session_id) == chat_id))
|
||||
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start))
|
||||
.where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end))
|
||||
)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
actions.reverse()
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
|
||||
actions = list(query)
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
actions = list(reversed(actions))
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp))
|
||||
actions = session.exec(statement).all()
|
||||
return [
|
||||
DatabaseActionRecords(
|
||||
action_id=action.action_id,
|
||||
time=action.time,
|
||||
time=action.timestamp.timestamp(),
|
||||
action_name=action.action_name,
|
||||
action_data=action.action_data,
|
||||
action_done=action.action_done,
|
||||
action_build_into_prompt=action.action_build_into_prompt,
|
||||
action_prompt_display=action.action_prompt_display,
|
||||
chat_id=action.chat_id,
|
||||
chat_info_stream_id=action.chat_info_stream_id,
|
||||
chat_info_platform=action.chat_info_platform,
|
||||
action_reasoning=action.action_reasoning,
|
||||
action_data=action.action_data or "{}",
|
||||
action_done=True,
|
||||
action_build_into_prompt=bool(action.action_display_prompt),
|
||||
action_prompt_display=action.action_display_prompt or "",
|
||||
chat_id=action.session_id,
|
||||
chat_info_stream_id=action.session_id,
|
||||
chat_info_platform=global_config.bot.platform,
|
||||
action_reasoning=action.action_reasoning or "",
|
||||
)
|
||||
for action in actions
|
||||
]
|
||||
@@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||
query = ActionRecords.select().where(
|
||||
(ActionRecords.chat_id == chat_id)
|
||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(ActionRecord)
|
||||
.where((col(ActionRecord.session_id) == chat_id))
|
||||
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
|
||||
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
|
||||
)
|
||||
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in reversed(actions)]
|
||||
else: # earliest
|
||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
||||
else:
|
||||
query = query.order_by(ActionRecords.time.asc())
|
||||
if limit > 0:
|
||||
if limit_mode == "latest":
|
||||
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
actions = list(reversed(actions))
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
|
||||
actions = list(session.exec(statement).all())
|
||||
else:
|
||||
statement = statement.order_by(col(ActionRecord.timestamp))
|
||||
actions = session.exec(statement).all()
|
||||
|
||||
actions = list(query)
|
||||
return [action.__data__ for action in actions]
|
||||
return [action.model_dump() for action in actions]
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_random(
|
||||
@@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random(
|
||||
|
||||
|
||||
def get_raw_msg_by_timestamp_with_users(
|
||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
||||
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat(
|
||||
|
||||
|
||||
def get_raw_msg_before_timestamp_with_users(
|
||||
timestamp: float, person_ids: list, limit: int = 0
|
||||
timestamp: float, person_ids: List[str], limit: int = 0
|
||||
) -> List[DatabaseMessages]:
|
||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
@@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
|
||||
|
||||
|
||||
def num_new_messages_since_with_users(
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
|
||||
) -> int:
|
||||
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
||||
if not person_ids: # 保持空列表检查
|
||||
@@ -358,7 +361,7 @@ def num_new_messages_since_with_users(
|
||||
|
||||
|
||||
def _build_readable_messages_internal(
|
||||
messages: List[MessageAndActionModel],
|
||||
messages: List[DatabaseMessages],
|
||||
replace_bot_name: bool = True,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
@@ -413,7 +416,7 @@ def _build_readable_messages_internal(
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||
|
||||
def replace_pic_id(match: re.Match) -> str:
|
||||
def replace_pic_id(match: re.Match[str]) -> str:
|
||||
nonlocal current_pic_counter
|
||||
nonlocal pic_counter
|
||||
pic_id = match.group(1)
|
||||
@@ -421,7 +424,8 @@ def _build_readable_messages_internal(
|
||||
if pic_id not in pic_description_cache:
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
with get_db_session() as session:
|
||||
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -438,16 +442,11 @@ def _build_readable_messages_internal(
|
||||
|
||||
# 1: 获取发送者信息并提取消息组件
|
||||
for message in messages:
|
||||
if message.is_action_record:
|
||||
# 对于动作记录,也处理图片ID
|
||||
content = process_pic_ids(message.display_message)
|
||||
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
|
||||
continue
|
||||
|
||||
platform = message.user_platform
|
||||
user_id = message.user_id
|
||||
user_nickname = message.user_nickname
|
||||
user_cardname = message.user_cardname
|
||||
user_info = message.user_info
|
||||
platform = user_info.platform
|
||||
user_id = user_info.user_id
|
||||
user_nickname = user_info.user_nickname
|
||||
user_cardname = user_info.user_cardname
|
||||
|
||||
timestamp = message.time
|
||||
content = message.display_message or message.processed_plain_text or ""
|
||||
@@ -525,12 +524,12 @@ def _build_readable_messages_internal(
|
||||
if long_time_notice and prev_timestamp is not None:
|
||||
time_diff = timestamp - prev_timestamp
|
||||
time_diff_hours = time_diff / 3600
|
||||
|
||||
|
||||
# 检查是否跨天
|
||||
prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
|
||||
current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
|
||||
is_cross_day = prev_date != current_date
|
||||
|
||||
|
||||
# 如果间隔大于8小时或跨天,插入提示
|
||||
if time_diff_hours > 8 or is_cross_day:
|
||||
# 格式化日期为中文格式:xxxx年xx月xx日(去掉前导零)
|
||||
@@ -542,20 +541,15 @@ def _build_readable_messages_internal(
|
||||
hours_str = f"{int(time_diff_hours)}h"
|
||||
notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
|
||||
output_lines.append(notice)
|
||||
|
||||
|
||||
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
|
||||
|
||||
# 查找消息id(如果有)并构建id_prefix
|
||||
message_id = timestamp_to_id_mapping.get(timestamp, "")
|
||||
id_prefix = f"[{message_id}]" if message_id else ""
|
||||
|
||||
if is_action:
|
||||
# 对于动作记录,使用特殊格式
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {content}")
|
||||
else:
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
|
||||
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
|
||||
|
||||
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
|
||||
output_lines.append("\n")
|
||||
prev_timestamp = timestamp
|
||||
|
||||
formatted_string = "".join(output_lines).strip()
|
||||
@@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
# 从数据库中获取图片描述
|
||||
description = "内容正在阅读,请稍等"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
with get_db_session() as session:
|
||||
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception:
|
||||
@@ -663,7 +658,7 @@ async def build_readable_messages_with_list(
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
|
||||
messages,
|
||||
replace_bot_name,
|
||||
timestamp_mode,
|
||||
truncate,
|
||||
@@ -754,7 +749,7 @@ def build_readable_messages(
|
||||
filtered_messages = []
|
||||
for msg in messages:
|
||||
# 获取消息内容
|
||||
content = msg.processed_plain_text
|
||||
content = msg.processed_plain_text or ""
|
||||
# 移除表情包
|
||||
emoji_pattern = r"\[表情包:[^\]]+\]"
|
||||
content = re.sub(emoji_pattern, "", content)
|
||||
@@ -765,17 +760,14 @@ def build_readable_messages(
|
||||
|
||||
messages = filtered_messages
|
||||
|
||||
copy_messages: List[MessageAndActionModel] = []
|
||||
copy_messages: List[DatabaseMessages] = []
|
||||
for msg in messages:
|
||||
if remove_emoji_stickers:
|
||||
# 创建 MessageAndActionModel 但移除表情包
|
||||
model = MessageAndActionModel.from_DatabaseMessages(msg)
|
||||
# 移除表情包
|
||||
if model.processed_plain_text:
|
||||
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
|
||||
copy_messages.append(model)
|
||||
msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "")
|
||||
copy_messages.append(msg)
|
||||
else:
|
||||
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
|
||||
copy_messages.append(msg)
|
||||
|
||||
if show_actions and copy_messages:
|
||||
# 获取所有消息的时间范围
|
||||
@@ -786,40 +778,45 @@ def build_readable_messages(
|
||||
chat_id = messages[0].chat_id if messages else None
|
||||
|
||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||
actions_in_range = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
||||
)
|
||||
.order_by(ActionRecords.time)
|
||||
)
|
||||
with get_db_session() as session:
|
||||
actions_in_range = session.exec(
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
|
||||
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
|
||||
.where(col(ActionRecord.session_id) == chat_id)
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
).all()
|
||||
|
||||
# 获取最新消息之后的第一个动作记录
|
||||
action_after_latest = (
|
||||
ActionRecords.select()
|
||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
||||
.order_by(ActionRecords.time)
|
||||
.limit(1)
|
||||
)
|
||||
with get_db_session() as session:
|
||||
action_after_latest = session.exec(
|
||||
select(ActionRecord)
|
||||
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
|
||||
.where(col(ActionRecord.session_id) == chat_id)
|
||||
.order_by(col(ActionRecord.timestamp))
|
||||
.limit(1)
|
||||
).all()
|
||||
|
||||
# 合并两部分动作记录
|
||||
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
|
||||
actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest)
|
||||
|
||||
# 将动作记录转换为消息格式
|
||||
for action in actions:
|
||||
# 只有当build_into_prompt为True时才添加动作记录
|
||||
if action.action_build_into_prompt:
|
||||
action_msg = MessageAndActionModel(
|
||||
time=float(action.time), # type: ignore
|
||||
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
|
||||
user_platform=global_config.bot.platform, # 使用机器人的平台
|
||||
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
|
||||
user_cardname="", # 机器人没有群名片
|
||||
processed_plain_text=f"{action.action_prompt_display}",
|
||||
display_message=f"{action.action_prompt_display}",
|
||||
chat_info_platform=str(action.chat_info_platform),
|
||||
is_action_record=True, # 添加标识字段
|
||||
action_name=str(action.action_name), # 保存动作名称
|
||||
action_display_prompt = action.action_display_prompt or ""
|
||||
if action_display_prompt:
|
||||
action_msg = DatabaseMessages(
|
||||
message_id=f"action_{action.action_id}",
|
||||
time=float(action.timestamp.timestamp()),
|
||||
chat_id=chat_id or "",
|
||||
processed_plain_text=action_display_prompt,
|
||||
display_message=action_display_prompt,
|
||||
user_platform=global_config.bot.platform,
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
user_cardname="",
|
||||
chat_info_platform=str(global_config.bot.platform),
|
||||
chat_info_stream_id=chat_id or "",
|
||||
)
|
||||
copy_messages.append(action_msg)
|
||||
|
||||
@@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
person_ids_set = set() # 使用集合来自动去重
|
||||
|
||||
for msg in messages:
|
||||
platform: str = msg.get("user_platform") # type: ignore
|
||||
user_id: str = msg.get("user_id") # type: ignore
|
||||
platform = msg.get("user_platform") or ""
|
||||
user_id = msg.get("user_id") or ""
|
||||
|
||||
# 检查必要信息是否存在 且 不是机器人自己
|
||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||
continue
|
||||
|
||||
# 添加空值检查,防止 platform 为 None 时出错
|
||||
if platform is None:
|
||||
platform = "unknown"
|
||||
|
||||
if person_id := get_person_id(platform, user_id):
|
||||
person_ids_set.add(person_id)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,18 +1,21 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
from sqlmodel import select, col
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
@@ -38,11 +41,7 @@ class ImageManager:
|
||||
self._initialized = True
|
||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或表创建失败: {e}")
|
||||
get_db_session()
|
||||
|
||||
try:
|
||||
self._cleanup_invalid_descriptions()
|
||||
@@ -72,10 +71,12 @@ class ImageManager:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
record = ImageDescriptions.get_or_none(
|
||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
||||
)
|
||||
return record.description if record else None
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
return record.description if record else None
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||
return None
|
||||
@@ -90,15 +91,27 @@ class ImageManager:
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
defaults = {"description": description, "timestamp": current_timestamp}
|
||||
desc_obj, created = ImageDescriptions.get_or_create(
|
||||
image_description_hash=image_hash, type=description_type, defaults=defaults
|
||||
)
|
||||
if not created: # 如果记录已存在,则更新
|
||||
desc_obj.description = description
|
||||
desc_obj.timestamp = current_timestamp
|
||||
desc_obj.save()
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
|
||||
)
|
||||
record = session.exec(statement).first()
|
||||
if record:
|
||||
record.description = description
|
||||
session.add(record)
|
||||
return
|
||||
|
||||
new_record = Images(
|
||||
image_hash=image_hash,
|
||||
description=description,
|
||||
full_path="",
|
||||
image_type=ImageType(description_type),
|
||||
query_count=0,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
vlm_processed=True,
|
||||
)
|
||||
session.add(new_record)
|
||||
except Exception as e:
|
||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||
|
||||
@@ -107,20 +120,18 @@ class ImageManager:
|
||||
"""清理数据库中 description 为空或为 'None' 的记录"""
|
||||
invalid_values = ["", "None"]
|
||||
|
||||
# 清理 Images 表
|
||||
deleted_images = (
|
||||
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images)
|
||||
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
|
||||
.limit(1000)
|
||||
)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
session.delete(record)
|
||||
|
||||
# 清理 ImageDescriptions 表
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete()
|
||||
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
|
||||
.execute()
|
||||
)
|
||||
|
||||
if deleted_images or deleted_descriptions:
|
||||
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
||||
if records:
|
||||
logger.info(f"[清理完成] 删除 Images: {len(records)} 条")
|
||||
else:
|
||||
logger.info("[清理完成] 未发现无效描述记录")
|
||||
|
||||
@@ -128,19 +139,15 @@ class ImageManager:
|
||||
def _cleanup_emoji_from_image_descriptions():
|
||||
"""清理Images和ImageDescriptions表中type为emoji的记录(已迁移到EmojiDescriptionCache)"""
|
||||
try:
|
||||
# 清理Images表中type为emoji的记录
|
||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
session.delete(record)
|
||||
|
||||
# 清理ImageDescriptions表中type为emoji的记录
|
||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
|
||||
total_deleted = deleted_images + deleted_descriptions
|
||||
total_deleted = len(records)
|
||||
if total_deleted > 0:
|
||||
logger.info(
|
||||
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
|
||||
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
|
||||
f"共删除 {total_deleted} 条记录"
|
||||
)
|
||||
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
|
||||
else:
|
||||
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
|
||||
except Exception as e:
|
||||
@@ -148,14 +155,14 @@ class ImageManager:
|
||||
raise
|
||||
|
||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
emoji_manager = emoji_manager_instance
|
||||
if isinstance(image_base64, str):
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
||||
emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||
if not emoji:
|
||||
return "[表情包:未知]"
|
||||
emotion_list = emoji.emotion
|
||||
@@ -175,14 +182,14 @@ class ImageManager:
|
||||
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||
|
||||
# 检查是否已存在该表情包(通过哈希值)
|
||||
emoji_manager = get_emoji_manager()
|
||||
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
||||
emoji_manager = emoji_manager_instance
|
||||
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||
if existing_emoji:
|
||||
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
|
||||
return
|
||||
@@ -212,14 +219,15 @@ class ImageManager:
|
||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||
|
||||
# 优先使用EmojiManager查询已注册表情包的描述
|
||||
try:
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
||||
emoji_manager = emoji_manager_instance
|
||||
emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||
tags = emoji.emotion if emoji else None
|
||||
if tags:
|
||||
tag_str = ",".join(tags)
|
||||
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
|
||||
@@ -227,29 +235,26 @@ class ImageManager:
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||
|
||||
# 查询EmojiDescriptionCache表的缓存(包含描述和情感标签)
|
||||
try:
|
||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||
)
|
||||
cache_record = session.exec(statement).first()
|
||||
if cache_record:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||
)
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
if cache_record.emotion:
|
||||
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
|
||||
result_text = f"[表情包:{cache_record.emotion}]"
|
||||
elif cache_record.description:
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||
)
|
||||
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
return result_text
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
||||
logger.debug(f"查询Images缓存时出错: {e}")
|
||||
|
||||
# === 二步走识别流程 ===
|
||||
|
||||
@@ -309,33 +314,42 @@ class ImageManager:
|
||||
|
||||
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||
|
||||
# 再次检查缓存(防止并发情况下其他线程已经保存)
|
||||
try:
|
||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
||||
if cache_record and cache_record.emotion_tags:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
|
||||
return f"[表情包:{cache_record.emotion_tags}]"
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||
)
|
||||
cache_record = session.exec(statement).first()
|
||||
if cache_record and cache_record.emotion:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
|
||||
return f"[表情包:{cache_record.emotion}]"
|
||||
except Exception as e:
|
||||
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
|
||||
logger.debug(f"再次查询Images缓存时出错: {e}")
|
||||
|
||||
# 保存识别出的详细描述和情感标签到 emoji_description_cache
|
||||
try:
|
||||
current_timestamp = time.time()
|
||||
cache_record, created = EmojiDescriptionCache.get_or_create(
|
||||
emoji_hash=image_hash,
|
||||
defaults={
|
||||
"description": detailed_description,
|
||||
"emotion_tags": final_emotion,
|
||||
"timestamp": current_timestamp,
|
||||
},
|
||||
)
|
||||
if not created:
|
||||
# 更新已有记录
|
||||
cache_record.description = detailed_description
|
||||
cache_record.emotion_tags = final_emotion
|
||||
cache_record.timestamp = current_timestamp
|
||||
cache_record.save()
|
||||
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||
)
|
||||
cache_record = session.exec(statement).first()
|
||||
if cache_record:
|
||||
cache_record.description = detailed_description
|
||||
cache_record.emotion = final_emotion
|
||||
session.add(cache_record)
|
||||
else:
|
||||
cache_record = Images(
|
||||
image_hash=image_hash,
|
||||
description=detailed_description,
|
||||
full_path="",
|
||||
image_type=ImageType.EMOJI,
|
||||
emotion=final_emotion,
|
||||
query_count=0,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
vlm_processed=True,
|
||||
)
|
||||
session.add(cache_record)
|
||||
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
|
||||
|
||||
@@ -358,14 +372,13 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(col(Images.image_hash) == image_hash)
|
||||
existing_image = session.exec(statement).first()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
existing_image.count += 1
|
||||
else:
|
||||
existing_image.count = 1
|
||||
existing_image.save()
|
||||
existing_image.query_count += 1
|
||||
with get_db_session() as session:
|
||||
session.add(existing_image)
|
||||
|
||||
# 如果已有描述,直接返回
|
||||
if existing_image.description:
|
||||
@@ -377,7 +390,7 @@ class ImageManager:
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||
prompt = global_config.personality.visual_style
|
||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||
description, _ = await self.vlm.generate_response_for_image(
|
||||
@@ -402,26 +415,27 @@ class ImageManager:
|
||||
|
||||
# 保存到数据库,补充缺失字段
|
||||
if existing_image:
|
||||
existing_image.path = file_path
|
||||
existing_image.full_path = file_path
|
||||
existing_image.description = description
|
||||
existing_image.timestamp = current_timestamp
|
||||
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = True
|
||||
existing_image.save()
|
||||
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
|
||||
existing_image.vlm_processed = True
|
||||
with get_db_session() as session:
|
||||
session.add(existing_image)
|
||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||
else:
|
||||
Images.create(
|
||||
image_id=str(uuid.uuid4()),
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
description=description,
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=True,
|
||||
count=1,
|
||||
)
|
||||
with get_db_session() as session:
|
||||
new_record = Images(
|
||||
image_hash=image_hash,
|
||||
description=description,
|
||||
full_path=file_path,
|
||||
image_type=ImageType.IMAGE,
|
||||
query_count=1,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
record_time=datetime.fromtimestamp(current_timestamp),
|
||||
vlm_processed=True,
|
||||
)
|
||||
session.add(new_record)
|
||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||
@@ -575,30 +589,17 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
not hasattr(existing_image, "image_id")
|
||||
or not existing_image.image_id
|
||||
or not hasattr(existing_image, "count")
|
||||
or existing_image.count is None
|
||||
or not hasattr(existing_image, "vlm_processed")
|
||||
or existing_image.vlm_processed is None
|
||||
):
|
||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
||||
if not existing_image.image_id:
|
||||
existing_image.image_id = str(uuid.uuid4())
|
||||
if existing_image.count is None:
|
||||
existing_image.count = 0
|
||||
if existing_image.vlm_processed is None:
|
||||
existing_image.vlm_processed = False
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
|
||||
)
|
||||
existing_image = session.exec(statement).first()
|
||||
if existing_image:
|
||||
existing_image.query_count += 1
|
||||
session.add(existing_image)
|
||||
return str(existing_image.id), f"[picid:{existing_image.id}]"
|
||||
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# 保存新图片
|
||||
current_timestamp = time.time()
|
||||
@@ -612,15 +613,19 @@ class ImageManager:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
Images.create(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
type="image",
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=False,
|
||||
count=1,
|
||||
)
|
||||
with get_db_session() as session:
|
||||
new_record = Images(
|
||||
image_hash=image_hash,
|
||||
description="",
|
||||
full_path=file_path,
|
||||
image_type=ImageType.IMAGE,
|
||||
query_count=1,
|
||||
is_registered=False,
|
||||
is_banned=False,
|
||||
record_time=datetime.fromtimestamp(current_timestamp),
|
||||
vlm_processed=False,
|
||||
)
|
||||
session.add(new_record)
|
||||
|
||||
# 启动异步VLM处理
|
||||
await self._process_image_with_vlm(image_id, image_base64)
|
||||
@@ -647,17 +652,26 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 获取当前图片记录
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
with get_db_session() as session:
|
||||
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
|
||||
if image is None:
|
||||
logger.warning(f"未找到图片记录: {image_id}")
|
||||
return
|
||||
|
||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||
existing_with_description = Images.get_or_none(
|
||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
(col(Images.image_hash) == image_hash)
|
||||
& (col(Images.description).is_not(None))
|
||||
& (col(Images.description) != "")
|
||||
)
|
||||
existing_with_description = session.exec(statement).first()
|
||||
if existing_with_description and existing_with_description.id != image.id:
|
||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||
image.description = existing_with_description.description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
with get_db_session() as session:
|
||||
session.add(image)
|
||||
# 同时保存到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||
return
|
||||
@@ -667,11 +681,12 @@ class ImageManager:
|
||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
with get_db_session() as session:
|
||||
session.add(image)
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
||||
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||
|
||||
# 构建prompt
|
||||
prompt = global_config.personality.visual_style
|
||||
@@ -692,7 +707,8 @@ class ImageManager:
|
||||
# 更新数据库
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
with get_db_session() as session:
|
||||
session.add(image)
|
||||
|
||||
# 保存描述到ImageDescriptions表作为备用缓存
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
Reference in New Issue
Block a user