Final Commit Before Rdev

This commit is contained in:
UnCLAS-Prommer
2026-03-11 00:14:18 +08:00
committed by SengokuCola
parent e1e296491c
commit 8b9cda4296
10 changed files with 662 additions and 1348 deletions

9
Plan.md Normal file
View File

@@ -0,0 +1,9 @@
Context 在消息接收的时候就进行解析,不再放到 MaiMessage 里面,由消息注册的时候直接进去注册
- [ ] 实现`update_chat_context`方法,主要关注`format_info`
1. **预计不对发送的时候进行`accept_format`的格式判断**,希望所有消息适配器接收的时候做一下不兼容内容主动丢弃
2. 在发送消息的时候进行`accept_format`的判断,判断不兼容内容是否存在,如果存在则丢弃掉
- [ ] 实现 status_api

View File

@@ -1,18 +1,9 @@
from asyncio import Task
from datetime import datetime
from maim_message import (
MessageBase,
UserInfo as MaimUserInfo,
GroupInfo as MaimGroupInfo,
BaseMessageInfo as MaimBaseMessageInfo,
Seg,
)
from rich.traceback import install
from sqlmodel import select
from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING
from typing import List, Dict, Tuple, Sequence
import asyncio
import time
from src.common.logger import get_logger
from src.common.database.database import get_db_session
@@ -28,10 +19,6 @@ from src.common.data_models.message_component_data_model import (
ForwardNodeComponent,
StandardMessageComponents,
)
from src.common.utils.utils_message import MessageUtils
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3)
@@ -220,166 +207,3 @@ class SessionMessage(MaiMessage):
else:
processed_texts.append(result)
return " ".join(processed_texts)
class MessageSending(MaiMessage):
"""发送状态的消息类,继承 MaiMessage 基类。
用于构建、处理和发送机器人的回复消息。
复用 MaiMessage 的 to_maim_message() 和 to_db_instance() 方法,
额外管理发送专属的会话信息和控制字段。
"""
def __init__(
self,
message_id: str,
session: "BotChatSession",
bot_user_info: UserInfo,
message_segment: Seg,
sender_info: Optional[UserInfo] = None,
reply: Optional[MaiMessage] = None,
display_message: str = "",
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
reply_to: Optional[str] = None,
selected_expressions: Optional[List[int]] = None,
):
# 初始化 MaiMessage 基类
super().__init__(message_id=message_id, timestamp=datetime.now())
# 发送专属字段
self.session = session
self.sender_info = sender_info
self.message_segment = message_segment
self.reply = reply
self.is_head = is_head
self.thinking_start_time = thinking_start_time
self.selected_expressions = selected_expressions
self.reply_to_message_id: Optional[str] = reply.message_id if reply else None
self.interest_value: float = 0.0
# 填充 MaiMessage 标准字段
self.platform = session.platform
self.session_id = session.session_id
self.is_emoji = is_emoji
self.reply_to = reply_to
self.display_message = display_message
self.processed_plain_text = ""
# 构建 message_infoDB 存储时 user_info 始终为 bot 信息
# 私聊/群聊的 user_info 差异仅在 to_maim_message() 覆写中处理
group_info = self._resolve_group_info()
self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info)
# bot_user_info 单独保存to_maim_message 覆写时还需要
self.bot_user_info = bot_user_info
# 将 Seg 转换为 MessageSequence供基类的 to_db_instance / to_maim_message 使用
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
MessageBase(message_info=None, message_segment=message_segment)
)
self.initialized = True
def _resolve_group_info(self) -> Optional[GroupInfo]:
"""从 session 中解析群信息"""
if not self.session.group_id:
return None
group_name = ""
if (
self.session.context
and self.session.context.message
and self.session.context.message.message_info.group_info
):
group_name = self.session.context.message.message_info.group_info.group_name
return GroupInfo(group_id=self.session.group_id, group_name=group_name)
async def process(self) -> None:
"""处理消息段,生成 processed_plain_text使用 SessionMessage 的组件处理能力)"""
# 同步 message_segment → raw_message插件可能修改了 message_segment
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
MessageBase(message_info=None, message_segment=self.message_segment)
)
if self.raw_message and self.raw_message.components:
tasks = [self._process_component(c) for c in self.raw_message.components]
results = await asyncio.gather(*tasks, return_exceptions=True)
texts = []
for r in results:
if isinstance(r, BaseException):
logger.error(f"处理发送消息组件时发生错误: {r}")
elif r:
texts.append(r)
self.processed_plain_text = " ".join(texts)
async def _process_component(self, component: StandardMessageComponents) -> str:
"""简单处理单个标准组件为纯文本描述"""
if isinstance(component, TextComponent):
return component.text
elif isinstance(component, ImageComponent):
return "[图片]"
elif isinstance(component, EmojiComponent):
return "[表情包]"
elif isinstance(component, VoiceComponent):
return "[语音]"
elif isinstance(component, AtComponent):
return f"[@{component.target_user_id}]"
elif isinstance(component, ReplyComponent):
return ""
else:
return f"[{type(component).__name__}]"
def build_reply(self) -> None:
"""构建回复消息段,在 message_segment 前插入 reply 段"""
if self.reply:
self.reply_to_message_id = self.reply.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_id),
self.message_segment,
],
)
# 同步更新 raw_message
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
MessageBase(message_info=None, message_segment=self.message_segment)
)
async def to_maim_message(self) -> MessageBase:
"""覆写基类方法:发送消息需要特殊处理 user_info私聊/群聊差异)"""
maim_bot_user_info = MaimUserInfo(
user_id=self.bot_user_info.user_id,
user_nickname=self.bot_user_info.user_nickname,
user_cardname=self.bot_user_info.user_cardname,
platform=self.platform,
)
maim_group_info = None
if self.message_info.group_info:
maim_group_info = MaimGroupInfo(
group_id=self.message_info.group_info.group_id,
group_name=self.message_info.group_info.group_name,
platform=self.platform,
)
# 私聊时 user_info 填接收者信息sender_info群聊时填 bot
if maim_group_info is None and self.sender_info:
msg_user_info = MaimUserInfo(
user_id=self.sender_info.user_id,
user_nickname=self.sender_info.user_nickname,
user_cardname=self.sender_info.user_cardname,
platform=self.platform,
)
else:
msg_user_info = maim_bot_user_info
maim_msg_info = MaimBaseMessageInfo(
platform=self.platform,
message_id=self.message_id,
time=time.time(),
group_info=maim_group_info,
user_info=msg_user_info,
)
msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))

View File

@@ -1,561 +0,0 @@
import time
import asyncio
import urllib3
from abc import abstractmethod
from dataclasses import dataclass
from rich.traceback import install
from typing import Optional, Any, List
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.utils_image import get_image_manager
from src.common.utils.utils_voice import get_voice_text
from .chat_stream import ChatStream
install(extra_lines=3)
logger = get_logger("chat_message")
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# VLM 处理并发限制(避免同时处理太多图片导致卡死)
_vlm_semaphore = asyncio.Semaphore(3)
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message(MessageBase):
chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None
processed_plain_text: str = ""
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
user_info: UserInfo,
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
reply: Optional["MessageRecv"] = None,
processed_plain_text: str = "",
):
# 使用传入的时间戳或当前时间
current_timestamp = timestamp if timestamp is not None else round(time.time(), 3)
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=current_timestamp,
group_info=chat_stream.group_info,
user_info=user_info,
)
# 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = processed_plain_text
# 回复消息
self.reply = reply
# async def _process_message_segments(self, segment: Seg) -> str:
# # sourcery skip: remove-unnecessary-else, swap-if-else-branches
# """递归处理消息段,转换为文字描述
# Args:
# segment: 要处理的消息段
# Returns:
# str: 处理后的文本
# """
# if segment.type == "seglist":
# # 处理消息段列表 - 使用并行处理提升性能
# tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore
# results = await asyncio.gather(*tasks, return_exceptions=True)
# segments_text = []
# for result in results:
# if isinstance(result, Exception):
# logger.error(f"处理消息段时出错: {result}")
# continue
# if result:
# segments_text.append(result)
# return " ".join(segments_text)
# elif segment.type == "forward":
# # 处理转发消息 - 使用并行处理
# async def process_forward_node(node_dict):
# message = MessageBase.from_dict(node_dict) # type: ignore
# processed_text = await self._process_message_segments(message.message_segment)
# if processed_text:
# return f"{global_config.bot.nickname}: {processed_text}"
# return None
# tasks = [process_forward_node(node_dict) for node_dict in segment.data]
# results = await asyncio.gather(*tasks, return_exceptions=True)
# segments_text = []
# for result in results:
# if isinstance(result, Exception):
# logger.error(f"处理转发节点时出错: {result}")
# continue
# if result:
# segments_text.append(result)
# return "[合并消息]: " + "\n-- ".join(segments_text)
# else:
# # 处理单个消息段
# return await self._process_single_segment(segment) # type: ignore
# @abstractmethod
# async def _process_single_segment(self, segment) -> str:
# pass
@dataclass
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: dict[str, Any]):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message")
self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.is_emoji = False
self.has_emoji = False
self.is_picid = False
self.has_picid = False
self.is_voice = False
self.is_mentioned = None
self.is_at = False
self.reply_probability_boost = 0.0
self.is_notify = False
self.is_command = False
self.intercept_message_level = 0
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
self.key_words = []
self.key_words_lite = []
# 兼容适配器通过 additional_config 传入的 @ 标记
try:
msg_info_dict = message_dict.get("message_info", {})
add_cfg = msg_info_dict.get("additional_config") or {}
if isinstance(add_cfg, dict) and add_cfg.get("at_bot"):
# 标记为被提及,提高后续回复优先级
self.is_mentioned = True # type: ignore
except Exception:
pass
def update_chat_stream(self, chat_stream: "ChatStream"):
self.chat_stream = chat_stream
# async def process(self) -> None:
# """处理消息内容,生成纯文本和详细文本
# 这个方法必须在创建实例后显式调用,因为它包含异步操作。
# """
# # print(f"self.message_segment: {self.message_segment}")
# self.processed_plain_text = await self._process_message_segments(self.message_segment)
# async def _process_single_segment(self, segment: Seg) -> str:
# """处理单个消息段
# Args:
# segment: 消息段
# Returns:
# str: 处理后的文本
# """
# try:
# if segment.type == "text":
# self.is_picid = False
# self.is_emoji = False
# return segment.data # type: ignore
# elif segment.type == "image":
# # 如果是base64图片数据
# if isinstance(segment.data, str):
# self.has_picid = True
# self.is_picid = True
# self.is_emoji = False
# image_manager = get_image_manager()
# # 使用 semaphore 限制 VLM 并发,避免同时处理太多图片
# async with _vlm_semaphore:
# _, processed_text = await image_manager.process_image(segment.data)
# return processed_text
# return "[发了一张图片,网卡了加载不出来]"
# elif segment.type == "emoji":
# self.has_emoji = True
# self.is_emoji = True
# self.is_picid = False
# self.is_voice = False
# if isinstance(segment.data, str):
# # 使用 semaphore 限制 VLM 并发
# async with _vlm_semaphore:
# return await get_image_manager().get_emoji_description(segment.data)
# return "[发了一个表情包,网卡了加载不出来]"
# elif segment.type == "voice":
# self.is_picid = False
# self.is_emoji = False
# self.is_voice = True
# if isinstance(segment.data, str):
# return await get_voice_text(segment.data)
# return "[发了一段语音,网卡了加载不出来]"
# elif segment.type == "mention_bot":
# self.is_picid = False
# self.is_emoji = False
# self.is_voice = False
# self.is_mentioned = float(segment.data) # type: ignore
# return ""
# elif segment.type == "priority_info":
# self.is_picid = False
# self.is_emoji = False
# self.is_voice = False
# if isinstance(segment.data, dict):
# # 处理优先级信息
# self.priority_mode = "priority"
# self.priority_info = segment.data
# """
# {
# 'message_type': 'vip', # vip or normal
# 'message_priority': 1.0, # 优先级大为优先float
# }
# """
# return ""
# elif segment.type == "video_card":
# # 处理视频卡片消息
# self.is_picid = False
# self.is_emoji = False
# self.is_voice = False
# if isinstance(segment.data, dict):
# file_name = segment.data.get("file", "未知视频")
# file_size = segment.data.get("file_size", "")
# url = segment.data.get("url", "")
# text = f"[视频: {file_name}"
# if file_size:
# text += f", 大小: {file_size}字节"
# text += "]"
# if url:
# text += f" 链接: {url}"
# return text
# return "[视频]"
# elif segment.type == "music_card":
# # 处理音乐卡片消息
# self.is_picid = False
# self.is_emoji = False
# self.is_voice = False
# if isinstance(segment.data, dict):
# title = segment.data.get("title", "未知歌曲")
# singer = segment.data.get("singer", "")
# tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
# jump_url = segment.data.get("jump_url", "")
# music_url = segment.data.get("music_url", "")
# text = f"[音乐: {title}"
# if singer:
# text += f" - {singer}"
# if tag:
# text += f" ({tag})"
# text += "]"
# if jump_url:
# text += f" 跳转链接: {jump_url}"
# if music_url:
# text += f" 音乐链接: {music_url}"
# return text
# return "[音乐]"
# elif segment.type == "miniapp_card":
# # 处理小程序分享卡片如B站视频分享
# self.is_picid = False
# self.is_emoji = False
# self.is_voice = False
# if isinstance(segment.data, dict):
# title = segment.data.get("title", "") # 小程序名称
# desc = segment.data.get("desc", "") # 内容描述
# source_url = segment.data.get("source_url", "") # 原始链接
# url = segment.data.get("url", "") # 小程序链接
# text = "[小程序分享"
# if title:
# text += f" - {title}"
# text += "]"
# if desc:
# text += f" {desc}"
# if source_url:
# text += f" 链接: {source_url}"
# elif url:
# text += f" 链接: {url}"
# return text
# return "[小程序分享]"
# else:
# return ""
# except Exception as e:
# logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
# return f"[处理失败的{segment.type}消息]"
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0,
timestamp: Optional[float] = None,
):
# 调用父类初始化,传递时间戳
super().__init__(
message_id=message_id,
timestamp=timestamp,
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
)
# 处理状态相关属性
self.thinking_start_time = thinking_start_time
self.thinking_time = 0
# def update_thinking_time(self) -> float:
# """更新思考时间"""
# self.thinking_time = round(time.time() - self.thinking_start_time, 2)
# return self.thinking_time
# async def _process_single_segment(self, segment: Seg) -> str:
# """处理单个消息段
# Args:
# segment: 要处理的消息段
# Returns:
# str: 处理后的文本
# """
# try:
# if segment.type == "text":
# return segment.data # type: ignore
# elif segment.type == "image":
# # 如果是base64图片数据
# if isinstance(segment.data, str):
# return await get_image_manager().get_image_description(segment.data)
# return "[图片,网卡了加载不出来]"
# elif segment.type == "emoji":
# if isinstance(segment.data, str):
# return await get_image_manager().get_emoji_tag(segment.data)
# return "[表情,网卡了加载不出来]"
# elif segment.type == "voice":
# if isinstance(segment.data, str):
# return await get_voice_text(segment.data)
# return "[发了一段语音,网卡了加载不出来]"
# elif segment.type == "at":
# return f"[@{segment.data}]"
# elif segment.type == "reply":
# if self.reply and hasattr(self.reply, "processed_plain_text"):
# # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# # print(f"reply: {self.reply}")
# return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
# return ""
# else:
# return f"[{segment.type}:{str(segment.data)}]"
# except Exception as e:
# logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
# return f"[处理失败的{segment.type}消息]"
# def _generate_detailed_text(self) -> str:
# """生成详细文本,包含时间和用户信息"""
# # time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
# timestamp = self.message_info.time
# user_info = self.message_info.user_info
# name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
# return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
sender_info: UserInfo | None, # 用来记录发送者信息
message_segment: Seg,
display_message: str = "",
reply: Optional["MessageRecv"] = None,
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None,
selected_expressions: Optional[List[int]] = None,
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply,
thinking_start_time=thinking_start_time,
)
# 发送状态特有属性
self.sender_info = sender_info
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic
self.reply_to = reply_to
# 用于显示发送内容与显示不一致的情况
self.display_message = display_message
self.interest_value = 0.0
self.selected_expressions = selected_expressions
def build_reply(self):
"""设置回复消息"""
if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
# def to_dict(self):
# ret = super().to_dict()
# ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
# return ret
# def is_private_message(self) -> bool:
# """判断是否为私聊消息"""
# return self.message_info.group_info is None or self.message_info.group_info.group_id is None
# @dataclass
# class MessageSet:
# """消息集合类,可以存储多个发送消息"""
# def __init__(self, chat_stream: "ChatStream", message_id: str):
# self.chat_stream = chat_stream
# self.message_id = message_id
# self.messages: list[MessageSending] = []
# self.time = round(time.time(), 3) # 保留3位小数
# def add_message(self, message: MessageSending) -> None:
# """添加消息到集合"""
# if not isinstance(message, MessageSending):
# raise TypeError("MessageSet只能添加MessageSending类型的消息")
# self.messages.append(message)
# self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
# def get_message_by_index(self, index: int) -> Optional[MessageSending]:
# """通过索引获取消息"""
# return self.messages[index] if 0 <= index < len(self.messages) else None
# def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
# """获取最接近指定时间的消息"""
# if not self.messages:
# return None
# left, right = 0, len(self.messages) - 1
# while left < right:
# mid = (left + right) // 2
# if self.messages[mid].message_info.time < target_time: # type: ignore
# left = mid + 1
# else:
# right = mid
# return self.messages[left]
# def clear_messages(self) -> None:
# """清空所有消息"""
# self.messages.clear()
# def remove_message(self, message: MessageSending) -> bool:
# """移除指定消息"""
# if message in self.messages:
# self.messages.remove(message)
# return True
# return False
# def __str__(self) -> str:
# return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
# def __len__(self) -> int:
# return len(self.messages)
# def message_recv_from_dict(message_dict: dict) -> MessageRecv:
# return MessageRecv(message_dict)
# def message_from_db_dict(db_dict: dict) -> MessageRecv:
# """从数据库字典创建MessageRecv实例"""
# # 转换扁平的数据库字典为嵌套结构
# message_info_dict = {
# "platform": db_dict.get("chat_info_platform"),
# "message_id": db_dict.get("message_id"),
# "time": db_dict.get("time"),
# "group_info": {
# "platform": db_dict.get("chat_info_group_platform"),
# "group_id": db_dict.get("chat_info_group_id"),
# "group_name": db_dict.get("chat_info_group_name"),
# },
# "user_info": {
# "platform": db_dict.get("user_platform"),
# "user_id": db_dict.get("user_id"),
# "user_nickname": db_dict.get("user_nickname"),
# "user_cardname": db_dict.get("user_cardname"),
# },
# }
# processed_text = db_dict.get("processed_plain_text", "")
# # 构建 MessageRecv 需要的字典
# recv_dict = {
# "message_info": message_info_dict,
# "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
# "raw_message": None, # 数据库中未存储原始消息
# "processed_plain_text": processed_text,
# }
# # 创建 MessageRecv 实例
# msg = MessageRecv(recv_dict)
# # 从数据库字典中填充其他可选字段
# msg.interest_value = db_dict.get("interest_value", 0.0)
# msg.is_mentioned = db_dict.get("is_mentioned")
# msg.priority_mode = db_dict.get("priority_mode", "interest")
# msg.priority_info = db_dict.get("priority_info")
# msg.is_emoji = db_dict.get("is_emoji", False)
# msg.is_picid = db_dict.get("is_picid", False)
# return msg

View File

@@ -1,13 +1,14 @@
import asyncio
import traceback
from rich.traceback import install
from maim_message import Seg
from typing import Optional
import asyncio
from src.common.message_server.api import get_global_api
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.chat.message_receive.message_old import MessageSending
from src.chat.message_receive.message import SessionMessage
from src.common.data_models.message_component_data_model import ReplyComponent
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
@@ -21,267 +22,267 @@ _webui_chat_broadcaster = None
# 虚拟群 ID 前缀(与 chat_routes.py 保持一致)
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# TODO: 重构完成后完成webui相关
# def get_webui_chat_broadcaster():
# """获取 WebUI 聊天室广播器"""
# global _webui_chat_broadcaster
# if _webui_chat_broadcaster is None:
# try:
# from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
def get_webui_chat_broadcaster():
"""获取 WebUI 聊天室广播器"""
global _webui_chat_broadcaster
if _webui_chat_broadcaster is None:
try:
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
except ImportError:
_webui_chat_broadcaster = (None, None)
return _webui_chat_broadcaster
# _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
# except ImportError:
# _webui_chat_broadcaster = (None, None)
# return _webui_chat_broadcaster
def is_webui_virtual_group(group_id: str) -> bool:
"""检查是否是 WebUI 虚拟群"""
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
# def is_webui_virtual_group(group_id: str) -> bool:
# """检查是否是 WebUI 虚拟群"""
# return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式
# def parse_message_segments(segment) -> list:
# """解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑
# 参考 NapCat 适配器的消息解析逻辑
Args:
segment: Seg 消息段对象
# Args:
# segment: Seg 消息段对象
Returns:
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
"""
# Returns:
# list: 消息段列表,每个元素为 {"type": "...", "data": ...}
# """
result = []
# result = []
if segment is None:
return result
# if segment is None:
# return result
if segment.type == "seglist":
# 处理消息段列表
if segment.data:
for seg in segment.data:
result.extend(parse_message_segments(seg))
elif segment.type == "text":
# 文本消息
if segment.data:
result.append({"type": "text", "data": segment.data})
elif segment.type == "image":
# 图片消息base64
if segment.data:
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
elif segment.type == "emoji":
# 表情包消息base64
if segment.data:
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
elif segment.type == "imageurl":
# 图片链接消息
if segment.data:
result.append({"type": "image", "data": segment.data})
elif segment.type == "face":
# 原生表情
result.append({"type": "face", "data": segment.data})
elif segment.type == "voice":
# 语音消息base64
if segment.data:
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
elif segment.type == "voiceurl":
# 语音链接
if segment.data:
result.append({"type": "voice", "data": segment.data})
elif segment.type == "video":
# 视频消息base64
if segment.data:
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
elif segment.type == "videourl":
# 视频链接
if segment.data:
result.append({"type": "video", "data": segment.data})
elif segment.type == "music":
# 音乐消息
result.append({"type": "music", "data": segment.data})
elif segment.type == "file":
# 文件消息
result.append({"type": "file", "data": segment.data})
elif segment.type == "reply":
# 回复消息
result.append({"type": "reply", "data": segment.data})
elif segment.type == "forward":
# 转发消息
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理
if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
# if segment.type == "seglist":
# # 处理消息段列表
# if segment.data:
# for seg in segment.data:
# result.extend(parse_message_segments(seg))
# elif segment.type == "text":
# # 文本消息
# if segment.data:
# result.append({"type": "text", "data": segment.data})
# elif segment.type == "image":
# # 图片消息base64
# if segment.data:
# result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
# elif segment.type == "emoji":
# # 表情包消息base64
# if segment.data:
# result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
# elif segment.type == "imageurl":
# # 图片链接消息
# if segment.data:
# result.append({"type": "image", "data": segment.data})
# elif segment.type == "face":
# # 原生表情
# result.append({"type": "face", "data": segment.data})
# elif segment.type == "voice":
# # 语音消息base64
# if segment.data:
# result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
# elif segment.type == "voiceurl":
# # 语音链接
# if segment.data:
# result.append({"type": "voice", "data": segment.data})
# elif segment.type == "video":
# # 视频消息base64
# if segment.data:
# result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
# elif segment.type == "videourl":
# # 视频链接
# if segment.data:
# result.append({"type": "video", "data": segment.data})
# elif segment.type == "music":
# # 音乐消息
# result.append({"type": "music", "data": segment.data})
# elif segment.type == "file":
# # 文件消息
# result.append({"type": "file", "data": segment.data})
# elif segment.type == "reply":
# # 回复消息
# result.append({"type": "reply", "data": segment.data})
# elif segment.type == "forward":
# # 转发消息
# forward_items = []
# if segment.data:
# for item in segment.data:
# forward_items.append(
# {
# "content": parse_message_segments(item.get("message_segment", {}))
# if isinstance(item, dict)
# else []
# }
# )
# result.append({"type": "forward", "data": forward_items})
# else:
# # 未知类型,尝试作为文本处理
# if segment.data:
# result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result
# return result
async def _send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=200)
platform = message.platform
group_id = message.session.group_id
# async def _send_message(message: MessageSending, show_log=True) -> bool:
# """合并后的消息发送函数包含WS发送和日志记录"""
# message_preview = truncate_message(message.processed_plain_text, max_length=200)
# platform = message.platform
# group_id = message.session.group_id
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
# try:
# # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
# chat_manager, webui_platform = get_webui_chat_broadcaster()
# is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
import time
from src.config.config import global_config
# if is_webui_message and chat_manager is not None:
# # WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
# import time
# from src.config.config import global_config
# 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment)
# # 解析消息段,获取富文本内容
# message_segments = parse_message_segments(message.message_segment)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
# 否则使用 rich 类型,包含完整的消息段
if len(message_segments) == 1 and message_segments[0].get("type") == "text":
message_type = "text"
segments = None
else:
message_type = "rich"
segments = message_segments
# # 判断消息类型
# # 如果只有一个文本段,使用简单的 text 类型
# # 否则使用 rich 类型,包含完整的消息段
# if len(message_segments) == 1 and message_segments[0].get("type") == "text":
# message_type = "text"
# segments = None
# else:
# message_type = "rich"
# segments = message_segments
await chat_manager.broadcast(
{
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": message_type,
"segments": segments, # 富文本消息段
"timestamp": time.time(),
"group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
"sender": {
"name": global_config.bot.nickname,
"avatar": None,
"is_bot": True,
},
}
)
# await chat_manager.broadcast(
# {
# "type": "bot_message",
# "content": message.processed_plain_text,
# "message_type": message_type,
# "segments": segments, # 富文本消息段
# "timestamp": time.time(),
# "group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
# "sender": {
# "name": global_config.bot.nickname,
# "avatar": None,
# "is_bot": True,
# },
# }
# )
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
# 无需手动保存
# # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
# # 无需手动保存
if show_log:
if is_webui_virtual_group(group_id):
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})")
else:
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
return True
# if show_log:
# if is_webui_virtual_group(group_id):
# logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})")
# else:
# logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
# return True
# Fallback 逻辑: 尝试通过 API Server 发送
async def send_with_new_api(legacy_exception=None):
try:
from src.config.config import global_config
# # Fallback 逻辑: 尝试通过 API Server 发送
# async def send_with_new_api(legacy_exception=None):
# try:
# from src.config.config import global_config
# 如果未开启 API Server直接跳过 Fallback
if not global_config.maim_message.enable_api_server:
logger.debug("[API Server Fallback] API Server未开启跳过fallback")
if legacy_exception:
raise legacy_exception
return False
# # 如果未开启 API Server直接跳过 Fallback
# if not global_config.maim_message.enable_api_server:
# logger.debug("[API Server Fallback] API Server未开启跳过fallback")
# if legacy_exception:
# raise legacy_exception
# return False
global_api = get_global_api()
extra_server = getattr(global_api, "extra_server", None)
# global_api = get_global_api()
# extra_server = getattr(global_api, "extra_server", None)
if not extra_server:
logger.warning("[API Server Fallback] extra_server不存在")
if legacy_exception:
raise legacy_exception
return False
# if not extra_server:
# logger.warning("[API Server Fallback] extra_server不存在")
# if legacy_exception:
# raise legacy_exception
# return False
if not extra_server.is_running():
logger.warning("[API Server Fallback] extra_server未运行")
if legacy_exception:
raise legacy_exception
return False
# if not extra_server.is_running():
# logger.warning("[API Server Fallback] extra_server未运行")
# if legacy_exception:
# raise legacy_exception
# return False
# Fallback: 使用极其简单的 Platform -> API Key 映射
# 只有收到过该平台的消息,我们才知道该平台的 API Key才能回传消息
platform_map = getattr(global_api, "platform_map", {})
logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台: '{platform}'")
target_api_key = platform_map.get(platform)
# # Fallback: 使用极其简单的 Platform -> API Key 映射
# # 只有收到过该平台的消息,我们才知道该平台的 API Key才能回传消息
# platform_map = getattr(global_api, "platform_map", {})
# logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台: '{platform}'")
# target_api_key = platform_map.get(platform)
if not target_api_key:
logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的API Key映射")
if legacy_exception:
raise legacy_exception
return False
# if not target_api_key:
# logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的API Key映射")
# if legacy_exception:
# raise legacy_exception
# return False
# 使用 MessageConverter 转换为 API 消息
from maim_message import MessageConverter
# # 使用 MessageConverter 转换为 API 消息
# from maim_message import MessageConverter
# 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
message_base = await message.to_maim_message()
# # 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
# message_base = await message.to_maim_message()
api_message = MessageConverter.to_api_send(
message=message_base,
api_key=target_api_key,
platform=platform,
)
# api_message = MessageConverter.to_api_send(
# message=message_base,
# api_key=target_api_key,
# platform=platform,
# )
# 直接调用 Server 的 send_message 接口,它会自动处理路由
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
results = await extra_server.send_message(api_message)
logger.debug(f"[API Server Fallback] 发送结果: {results}")
# # 直接调用 Server 的 send_message 接口,它会自动处理路由
# logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
# results = await extra_server.send_message(api_message)
# logger.debug(f"[API Server Fallback] 发送结果: {results}")
# 检查是否有任何连接发送成功
if any(results.values()):
if show_log:
logger.info(
f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
)
return True
else:
logger.warning(f"[API Server Fallback] 没有连接发送成功, results={results}")
except Exception as e:
logger.error(f"[API Server Fallback] 发生异常: {e}")
import traceback
# # 检查是否有任何连接发送成功
# if any(results.values()):
# if show_log:
# logger.info(
# f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
# )
# return True
# else:
# logger.warning(f"[API Server Fallback] 没有连接发送成功, results={results}")
# except Exception as e:
# logger.error(f"[API Server Fallback] 发生异常: {e}")
# import traceback
logger.debug(traceback.format_exc())
# logger.debug(traceback.format_exc())
# 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
if legacy_exception:
raise legacy_exception
return False
# # 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
# if legacy_exception:
# raise legacy_exception
# return False
try:
message_base = await message.to_maim_message()
send_result = await get_global_api().send_message(message_base)
if send_result:
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
return True
else:
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
fallback_result = await send_with_new_api()
if fallback_result and show_log:
# Fallback成功的日志已在send_with_new_api中打印
pass
return fallback_result
# try:
# message_base = await message.to_maim_message()
# send_result = await get_global_api().send_message(message_base)
# if send_result:
# if show_log:
# logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
# return True
# else:
# # Legacy API 返回 False (发送失败但未报错),尝试 Fallback
# fallback_result = await send_with_new_api()
# if fallback_result and show_log:
# # Fallback成功的日志已在send_with_new_api中打印
# pass
# return fallback_result
except Exception as legacy_e:
# Legacy API 抛出异常,尝试 Fallback
# 如果 Fallback 也失败,将重新抛出 legacy_e
return await send_with_new_api(legacy_exception=legacy_e)
# except Exception as legacy_e:
# # Legacy API 抛出异常,尝试 Fallback
# # 如果 Fallback 也失败,将重新抛出 legacy_e
# return await send_with_new_api(legacy_exception=legacy_e)
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
traceback.print_exc()
raise e # 重新抛出其他异常
# except Exception as e:
# logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
# traceback.print_exc()
# raise e # 重新抛出其他异常
class UniversalMessageSender:
@@ -291,21 +292,26 @@ class UniversalMessageSender:
pass
async def send_message(
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
self,
message: "SessionMessage",
typing: bool = False,
set_reply: bool = False,
reply_message_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
):
"""
处理、发送并存储一条消息。
参数:
message: MessageSending 对象,待发送的消息。
message: MessageSession 对象,待发送的消息。
typing: 是否模拟打字等待。
set_reply: 是否构建回复引用消息。
用法:
- typing=True 时,发送前会有打字等待。
"""
if not message.session:
logger.error("消息缺少 session无法发送")
raise ValueError("消息缺少 session无法发送")
if not message.message_id:
logger.error("消息缺少 message_id无法发送")
raise ValueError("消息缺少 message_id无法发送")
@@ -315,66 +321,62 @@ class UniversalMessageSender:
try:
if set_reply:
message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
if not reply_message_id:
raise ValueError("set_reply=True 时必须提供 reply_message_id")
message.raw_message.components.insert(0, ReplyComponent(reply_message_id))
from src.core.event_bus import event_bus
from src.chat.event_helpers import build_event_message
from src.core.types import EventType
# TODO: fix
# from src.core.event_bus import event_bus
# from src.chat.event_helpers import build_event_message
# from src.core.types import EventType
_event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id)
continue_flag, modified_message = await event_bus.emit(
EventType.POST_SEND_PRE_PROCESS, _event_msg
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
return False
if modified_message:
if modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if modified_message._modify_flags.modify_plain_text:
logger.warning(f"[{chat_id}] 插件修改了消息的纯文本内容,可能导致此内容被覆盖。")
message.processed_plain_text = modified_message.plain_text
# _event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id)
# continue_flag, modified_message = await event_bus.emit(EventType.POST_SEND_PRE_PROCESS, _event_msg)
# if not continue_flag:
# logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
# return False
# if modified_message:
# if modified_message._modify_flags.modify_message_segments:
# message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
# if modified_message._modify_flags.modify_plain_text:
# logger.warning(f"[{chat_id}] 插件修改了消息的纯文本内容,可能导致此内容被覆盖。")
# message.processed_plain_text = modified_message.plain_text
await message.process()
_event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id)
continue_flag, modified_message = await event_bus.emit(
EventType.POST_SEND, _event_msg
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
return False
if modified_message:
if modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
# TODO: fix
# _event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id)
# continue_flag, modified_message = await event_bus.emit(EventType.POST_SEND, _event_msg)
# if not continue_flag:
# logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...")
# return False
# if modified_message:
# if modified_message._modify_flags.modify_message_segments:
# message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
# if modified_message._modify_flags.modify_plain_text:
# message.processed_plain_text = modified_message.plain_text
if typing:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
input_string=message.processed_plain_text, # type: ignore
is_emoji=message.is_emoji,
)
await asyncio.sleep(typing_time)
sent_msg = await _send_message(message, show_log=show_log)
if not sent_msg:
return False
# sent_msg = await _send_message(message, show_log=show_log)
# if not sent_msg:
# return False
_event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
continue_flag, modified_message = await event_bus.emit(
EventType.AFTER_SEND, _event_msg
)
if not continue_flag:
logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
return True
if modified_message:
if modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
if modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
# _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
# continue_flag, modified_message = await event_bus.emit(EventType.AFTER_SEND, _event_msg)
# if not continue_flag:
# logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...")
# return True
# if modified_message:
# if modified_message._modify_flags.modify_message_segments:
# message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
# if modified_message._modify_flags.modify_plain_text:
# message.processed_plain_text = modified_message.plain_text
if storage_message:
with get_db_session() as db_session:

View File

@@ -1,3 +1,4 @@
# TODO: 完全删除此文件,将所有方法该合并的合并。
import time
import random
import re
@@ -19,7 +20,7 @@ install(extra_lines=3)
logger = get_logger("chat_message_builder")
def replace_user_references(
def replace_user_references( # TODO: 整合此函数
content: Optional[str],
platform: str,
name_resolver: Optional[Callable[[str, str], str]] = None,
@@ -262,102 +263,103 @@ def get_actions_by_timestamp_with_chat_inclusive(
return [action.model_dump() for action in actions]
def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
"""
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
"""
# 获取所有消息只取chat_id字段
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
if not all_msgs:
return []
# 随机选一条
msg = random.choice(all_msgs)
chat_id = msg.chat_id
timestamp_start = msg.time
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
# TODO: 整合为统一函数由参数控制仿照build_readable_message
# def get_raw_msg_by_timestamp_random(
# timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
# ) -> List[DatabaseMessages]:
# """
# 先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
# """
# # 获取所有消息只取chat_id字段
# all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
# if not all_msgs:
# return []
# # 随机选一条
# msg = random.choice(all_msgs)
# chat_id = msg.chat_id
# timestamp_start = msg.time
# # 用 chat_id 获取该聊天在指定时间戳范围内的消息
# return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'
"""
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
# def get_raw_msg_by_timestamp_with_users(
# timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
# ) -> List[DatabaseMessages]:
# """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
# limit: 限制返回的消息数量0为不限制
# limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。
# """
# filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
# # 只有当 limit 为 0 时才应用外部 sort
# sort_order = [("time", 1)] if limit == 0 else None
# return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
# def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[DatabaseMessages]:
# """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
# limit: 限制返回的消息数量0为不限制
# """
# filter_query = {"time": {"$lt": timestamp}}
# sort_order = [("time", 1)]
# return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_chat(
chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)
# def get_raw_msg_before_timestamp_with_chat(
# chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None
# ) -> List[DatabaseMessages]:
# """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
# limit: 限制返回的消息数量0为不限制
# """
# filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
# sort_order = [("time", 1)]
# return find_messages(
# message_filter=filter_query,
# sort=sort_order,
# limit=limit,
# filter_intercept_message_level=filter_intercept_message_level,
# )
def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
# def get_raw_msg_before_timestamp_with_users(
# timestamp: float, person_ids: List[str], limit: int = 0
# ) -> List[DatabaseMessages]:
# """获取指定时间戳之前的消息,按时间升序排序,返回消息列表
# limit: 限制返回的消息数量0为不限制
# """
# filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
# sort_order = [("time", 1)]
# return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
"""
# 确定有效的结束时间戳
_timestamp_end = timestamp_end if timestamp_end is not None else time.time()
# def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
# """
# 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
# 如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
# """
# # 确定有效的结束时间戳
# _timestamp_end = timestamp_end if timestamp_end is not None else time.time()
# 确保 timestamp_start < _timestamp_end
if timestamp_start >= _timestamp_end:
# logger.warning(f"timestamp_start ({timestamp_start}) must be less than _timestamp_end ({_timestamp_end}). Returning 0.")
return 0 # 起始时间大于等于结束时间,没有新消息
# # 确保 timestamp_start < _timestamp_end
# if timestamp_start >= _timestamp_end:
# # logger.warning(f"timestamp_start ({timestamp_start}) must be less than _timestamp_end ({_timestamp_end}). Returning 0.")
# return 0 # 起始时间大于等于结束时间,没有新消息
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
return count_messages(message_filter=filter_query)
# filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
# return count_messages(message_filter=filter_query)
def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
if not person_ids: # 保持空列表检查
return 0
filter_query = {
"chat_id": chat_id,
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
"user_id": {"$in": person_ids},
}
return count_messages(message_filter=filter_query)
# def num_new_messages_since_with_users(
# chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
# ) -> int:
# """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
# if not person_ids: # 保持空列表检查
# return 0
# filter_query = {
# "chat_id": chat_id,
# "time": {"$gt": timestamp_start, "$lt": timestamp_end},
# "user_id": {"$in": person_ids},
# }
# return count_messages(message_filter=filter_query)
def _build_readable_messages_internal(
@@ -563,40 +565,41 @@ def _build_readable_messages_internal(
)
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# sourcery skip: use-contextlib-suppress
"""
构建图片映射信息字符串,显示图片的具体描述内容
# 由MessageUtils._extract_pictures_from_message替代
# def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# # sourcery skip: use-contextlib-suppress
# """
# 构建图片映射信息字符串,显示图片的具体描述内容
Args:
pic_id_mapping: 图片ID到显示名称的映射字典
# Args:
# pic_id_mapping: 图片ID到显示名称的映射字典
Returns:
格式化的映射信息字符串
"""
if not pic_id_mapping:
return ""
# Returns:
# 格式化的映射信息字符串
# """
# if not pic_id_mapping:
# return ""
mapping_lines = []
# mapping_lines = []
# 按图片编号排序
sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", "")))
# # 按图片编号排序
# sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", "")))
for pic_id, display_name in sorted_items:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
# 如果查询失败,保持默认描述
pass
# for pic_id, display_name in sorted_items:
# # 从数据库中获取图片描述
# description = "内容正在阅读,请稍等"
# try:
# with get_db_session() as session:
# image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
# if image and image.description:
# description = image.description
# except Exception:
# # 如果查询失败,保持默认描述
# pass
mapping_lines.append(f"[{display_name}] 的内容:{description}")
# mapping_lines.append(f"[{display_name}] 的内容:{description}")
return "\n".join(mapping_lines)
# return "\n".join(mapping_lines)
def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "relative") -> str:
@@ -646,68 +649,69 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re
return "\n".join(output_lines)
async def build_readable_messages_with_list(
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_single: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
messages,
replace_bot_name,
timestamp_mode,
truncate,
pic_id_mapping=None,
pic_counter=1,
show_pic=True,
message_id_list=None,
pic_single=pic_single,
long_time_notice=False,
)
# 由MessageUtils里面的build_readable_message替代
# async def build_readable_messages_with_list(
# messages: List[DatabaseMessages],
# replace_bot_name: bool = True,
# timestamp_mode: str = "relative",
# truncate: bool = False,
# pic_single: bool = False,
# ) -> Tuple[str, List[Tuple[float, str, str]]]:
# """
# 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
# 允许通过参数控制格式化行为。
# """
# formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
# messages,
# replace_bot_name,
# timestamp_mode,
# truncate,
# pic_id_mapping=None,
# pic_counter=1,
# show_pic=True,
# message_id_list=None,
# pic_single=pic_single,
# long_time_notice=False,
# )
if not pic_single:
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
# if not pic_single:
# if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
# formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
return formatted_string, details_list
# return formatted_string, details_list
# 由MessageUtils里面的build_readable_message替代
# def build_readable_messages_with_id(
# messages: List[DatabaseMessages],
# replace_bot_name: bool = True,
# timestamp_mode: str = "relative",
# read_mark: float = 0.0,
# truncate: bool = False,
# show_actions: bool = False,
# show_pic: bool = True,
# remove_emoji_stickers: bool = False,
# pic_single: bool = False,
# ) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]:
# """
# 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
# 允许通过参数控制格式化行为。
# """
# message_id_list = assign_message_ids(messages)
def build_readable_messages_with_id(
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
remove_emoji_stickers: bool = False,
pic_single: bool = False,
) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
"""
message_id_list = assign_message_ids(messages)
# formatted_string = build_readable_messages(
# messages=messages,
# replace_bot_name=replace_bot_name,
# timestamp_mode=timestamp_mode,
# truncate=truncate,
# show_actions=show_actions,
# show_pic=show_pic,
# read_mark=read_mark,
# message_id_list=message_id_list,
# remove_emoji_stickers=remove_emoji_stickers,
# pic_single=pic_single,
# )
formatted_string = build_readable_messages(
messages=messages,
replace_bot_name=replace_bot_name,
timestamp_mode=timestamp_mode,
truncate=truncate,
show_actions=show_actions,
show_pic=show_pic,
read_mark=read_mark,
message_id_list=message_id_list,
remove_emoji_stickers=remove_emoji_stickers,
pic_single=pic_single,
)
return formatted_string, message_id_list
# return formatted_string, message_id_list
def build_readable_messages(
@@ -903,111 +907,112 @@ def build_readable_messages(
return "".join(result_parts)
async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: bool = False) -> str:
"""
构建匿名可读消息将不同人的名称转为唯一占位符A、B、C...bot自己用SELF。
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段将bbb映射为匿名占位符。
"""
if not messages:
logger.warning("没有消息,无法构建匿名消息")
return ""
# 由MessageUtils里面的build_readable_message替代
# async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: bool = False) -> str:
# """
# 构建匿名可读消息将不同人的名称转为唯一占位符A、B、C...bot自己用SELF。
# 处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段将bbb映射为匿名占位符。
# """
# if not messages:
# logger.warning("没有消息,无法构建匿名消息")
# return ""
person_map = {}
current_char = ord("A")
output_lines = []
# person_map = {}
# current_char = ord("A")
# output_lines = []
# 图片ID映射字典
pic_id_mapping = {}
pic_counter = 1
# # 图片ID映射字典
# pic_id_mapping = {}
# pic_counter = 1
def process_pic_ids(content: str) -> str:
"""处理内容中的图片ID将其替换为[图片x]格式"""
nonlocal pic_counter
# def process_pic_ids(content: str) -> str:
# """处理内容中的图片ID将其替换为[图片x]格式"""
# nonlocal pic_counter
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
# # 匹配 [picid:xxxxx] 格式
# pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match):
nonlocal pic_counter
pic_id = match.group(1)
# def replace_pic_id(match):
# nonlocal pic_counter
# pic_id = match.group(1)
if pic_id not in pic_id_mapping:
pic_id_mapping[pic_id] = f"图片{pic_counter}"
pic_counter += 1
# if pic_id not in pic_id_mapping:
# pic_id_mapping[pic_id] = f"图片{pic_counter}"
# pic_counter += 1
return f"[{pic_id_mapping[pic_id]}]"
# return f"[{pic_id_mapping[pic_id]}]"
return re.sub(pic_pattern, replace_pic_id, content)
# return re.sub(pic_pattern, replace_pic_id, content)
def get_anon_name(platform, user_id):
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
# def get_anon_name(platform, user_id):
# # print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
# # print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
if (platform == "qq" and user_id == global_config.bot.qq_account) or (
platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "")
):
# print("SELF11111111111111")
return "SELF"
try:
person_id = get_person_id(platform, user_id)
except Exception as _e:
person_id = None
if not person_id:
return "?"
if person_id not in person_map:
nonlocal current_char
person_map[person_id] = chr(current_char)
current_char += 1
return person_map[person_id]
# if (platform == "qq" and user_id == global_config.bot.qq_account) or (
# platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "")
# ):
# # print("SELF11111111111111")
# return "SELF"
# try:
# person_id = get_person_id(platform, user_id)
# except Exception as _e:
# person_id = None
# if not person_id:
# return "?"
# if person_id not in person_map:
# nonlocal current_char
# person_map[person_id] = chr(current_char)
# current_char += 1
# return person_map[person_id]
for i, msg in enumerate(messages):
try:
platform = msg.chat_info.platform
user_id = msg.user_info.user_id
content = msg.display_message or msg.processed_plain_text or ""
# for i, msg in enumerate(messages):
# try:
# platform = msg.chat_info.platform
# user_id = msg.user_info.user_id
# content = msg.display_message or msg.processed_plain_text or ""
# 处理图片ID
content = process_pic_ids(content)
# # 处理图片ID
# content = process_pic_ids(content)
anon_name = get_anon_name(platform, user_id)
# print(f"anon_name:{anon_name}")
# anon_name = get_anon_name(platform, user_id)
# # print(f"anon_name:{anon_name}")
# 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器
def anon_name_resolver(platform: str, user_id: str) -> str:
try:
return get_anon_name(platform, user_id)
except Exception:
return "?"
# # 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器
# def anon_name_resolver(platform: str, user_id: str) -> str:
# try:
# return get_anon_name(platform, user_id)
# except Exception:
# return "?"
content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
# content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
# 构建消息头如果启用show_ids则添加序号
if show_ids:
header = f"[{i + 1}] {anon_name}"
else:
header = f"{anon_name}"
# # 构建消息头如果启用show_ids则添加序号
# if show_ids:
# header = f"[{i + 1}] {anon_name}说 "
# else:
# header = f"{anon_name}说 "
output_lines.append(header)
stripped_line = content.strip()
if stripped_line:
if stripped_line.endswith(""):
stripped_line = stripped_line[:-1]
output_lines.append(f"{stripped_line}")
# print(f"output_lines:{output_lines}")
output_lines.append("\n")
except Exception:
continue
# output_lines.append(header)
# stripped_line = content.strip()
# if stripped_line:
# if stripped_line.endswith("。"):
# stripped_line = stripped_line[:-1]
# output_lines.append(f"{stripped_line}")
# # print(f"output_lines:{output_lines}")
# output_lines.append("\n")
# except Exception:
# continue
# 在最前面添加图片映射信息
final_output_lines = []
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
if pic_mapping_info:
final_output_lines.append(pic_mapping_info)
final_output_lines.append("\n\n")
# # 在最前面添加图片映射信息
# final_output_lines = []
# pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
# if pic_mapping_info:
# final_output_lines.append(pic_mapping_info)
# final_output_lines.append("\n\n")
final_output_lines.extend(output_lines)
formatted_string = "".join(final_output_lines).strip()
return formatted_string
# final_output_lines.extend(output_lines)
# formatted_string = "".join(final_output_lines).strip()
# return formatted_string
async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:

View File

@@ -523,7 +523,7 @@ def process_llm_response(text: str, enable_splitter: bool = True, enable_chinese
def calculate_typing_time(
input_string: str,
thinking_start_time: float,
# thinking_start_time: float,
chinese_time: float = 0.3,
english_time: float = 0.15,
is_emoji: bool = False,
@@ -556,8 +556,8 @@ def calculate_typing_time(
if is_emoji:
total_time = 1
if time.time() - thinking_start_time > 10:
total_time = 1
# if time.time() - thinking_start_time > 10:
# total_time = 1
# print(f"thinking_start_time:{thinking_start_time}")
# print(f"nowtime:{time.time()}")

View File

@@ -1,7 +1,7 @@
import json
from dataclasses import dataclass
# import json
# from dataclasses import dataclass
from . import BaseDataModel
# from . import BaseDataModel
# @dataclass
@@ -208,33 +208,33 @@ from . import BaseDataModel
# }
@dataclass(init=False)
class DatabaseActionRecords(BaseDataModel):
def __init__(
self,
action_id: str,
time: float,
action_name: str,
action_data: str,
action_done: bool,
action_build_into_prompt: bool,
action_prompt_display: str,
chat_id: str,
chat_info_stream_id: str,
chat_info_platform: str,
action_reasoning: str,
):
self.action_id = action_id
self.time = time
self.action_name = action_name
if isinstance(action_data, str):
self.action_data = json.loads(action_data)
else:
raise ValueError("action_data must be a JSON string")
self.action_done = action_done
self.action_build_into_prompt = action_build_into_prompt
self.action_prompt_display = action_prompt_display
self.chat_id = chat_id
self.chat_info_stream_id = chat_info_stream_id
self.chat_info_platform = chat_info_platform
self.action_reasoning = action_reasoning
# @dataclass(init=False)
# class DatabaseActionRecords(BaseDataModel):
# def __init__(
# self,
# action_id: str,
# time: float,
# action_name: str,
# action_data: str,
# action_done: bool,
# action_build_into_prompt: bool,
# action_prompt_display: str,
# chat_id: str,
# chat_info_stream_id: str,
# chat_info_platform: str,
# action_reasoning: str,
# ):
# self.action_id = action_id
# self.time = time
# self.action_name = action_name
# if isinstance(action_data, str):
# self.action_data = json.loads(action_data)
# else:
# raise ValueError("action_data must be a JSON string")
# self.action_done = action_done
# self.action_build_into_prompt = action_build_into_prompt
# self.action_prompt_display = action_prompt_display
# self.chat_id = chat_id
# self.chat_info_stream_id = chat_info_stream_id
# self.chat_info_platform = chat_info_platform
# self.action_reasoning = action_reasoning

View File

@@ -1,27 +1,28 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, TYPE_CHECKING
from . import BaseDataModel
# from dataclasses import dataclass, field
# from typing import Optional, Dict, TYPE_CHECKING
# from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
from src.core.types import ActionInfo
# if TYPE_CHECKING:
# from .database_data_model import DatabaseMessages
# from src.core.types import ActionInfo
# # @dataclass
# # class TargetPersonInfo(BaseDataModel):
# # platform: str = field(default_factory=str)
# # user_id: str = field(default_factory=str)
# # user_nickname: str = field(default_factory=str)
# # person_id: Optional[str] = None
# # person_name: Optional[str] = None
# @dataclass
# class TargetPersonInfo(BaseDataModel):
# platform: str = field(default_factory=str)
# user_id: str = field(default_factory=str)
# user_nickname: str = field(default_factory=str)
# person_id: Optional[str] = None
# person_name: Optional[str] = None
@dataclass
class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None
loop_start_time: Optional[float] = None
action_reasoning: Optional[str] = None
# class ActionPlannerInfo(BaseDataModel):
# action_type: str = field(default_factory=str)
# reasoning: Optional[str] = None
# action_data: Optional[Dict] = None
# action_message: Optional["DatabaseMessages"] = None
# available_actions: Optional[Dict[str, "ActionInfo"]] = None
# loop_start_time: Optional[float] = None
# action_reasoning: Optional[str] = None
# TODO: 重构

View File

@@ -1,22 +1,23 @@
from dataclasses import dataclass
from typing import Optional, List, TYPE_CHECKING, Dict, Any
# from dataclasses import dataclass
# from typing import Optional, List, TYPE_CHECKING, Dict, Any
from . import BaseDataModel
# from . import BaseDataModel
if TYPE_CHECKING:
from src.common.data_models.message_data_model import ReplySetModel
from src.llm_models.payload_content.tool_option import ToolCall
# if TYPE_CHECKING:
# from src.common.data_models.message_data_model import ReplySetModel
# from src.llm_models.payload_content.tool_option import ToolCall
@dataclass
class LLMGenerationDataModel(BaseDataModel):
content: Optional[str] = None
reasoning: Optional[str] = None
model: Optional[str] = None
tool_calls: Optional[List["ToolCall"]] = None
prompt: Optional[str] = None
selected_expressions: Optional[List[int]] = None
reply_set: Optional["ReplySetModel"] = None
timing: Optional[Dict[str, Any]] = None
processed_output: Optional[List[str]] = None
timing_logs: Optional[List[str]] = None
# @dataclass
# class LLMGenerationDataModel(BaseDataModel):
# content: Optional[str] = None
# reasoning: Optional[str] = None
# model: Optional[str] = None
# tool_calls: Optional[List["ToolCall"]] = None
# prompt: Optional[str] = None
# selected_expressions: Optional[List[int]] = None
# reply_set: Optional["ReplySetModel"] = None
# timing: Optional[Dict[str, Any]] = None
# processed_output: Optional[List[str]] = None
# timing_logs: Optional[List[str]] = None
# TODO: 重构

33
代码备忘.md Normal file
View File

@@ -0,0 +1,33 @@
# 代码备忘
- [ ] 检查EmojiManager的replace_an_emoji_by_llm传入的emoji是否真的是没有注册到db的
- [ ] According to a comment, MaiMBot's check_types() accesses format_info.accept_format without None check
- [ ] 如果需要更多的消息格式支持,更新列表如下:
- [ ] `src/common/utils/utils_message.py`中的`_parse_maim_message_segment_to_component`函数
- [ ] `src/common/data_models/message_component_model.py`中:
- [ ] 增加新的消息组件
- [ ] 看情况修改`StandardMessageComponents`的内容
- [ ] `MessageSequence``_dict_2_item``_item_2_dict`函数
- [ ] **取消了从chat_manager获取ChatSession时候的deepcopy看看会不会有问题**
# 迁移脚本备忘
- [ ] 迁移env到新版的bot_config管理
- [ ] 对于旧的消息需要重新计算其Hashmd5 -> sha256做好映射防止消息丢失
- [ ] PersonInfo的group_nickname名字改为group_cardname做好映射防止数据丢失同时存储的方式从`[{"group_id": str, "group_nick_name": str}]` -> `[{"group_id": str, "group_cardname": str}]`
- [ ] Expression中的`up_content`被移除了
- [ ] Jargon现在chat_id(session_id_list格式为`[["session_id", session_count]]`) -> session_id_dict`{"session_id": session_count}`),做好映射防止数据丢失
# 插件开发备忘
- [ ] 求各位插件开发不要在Dict里面塞一堆乱七八糟的东西免得数据库存储的时候一团糟
# Hack备忘
- [ ] 对于不符合内容审查要求的表情包,无法注册到数据库内,因此面对相同的非法表情包时,会导致反复识别。有成功注册的可能。
- [ ] 考虑到数据库记录表情包不合规判定有大模型误判的风险,因此保留现有的无法注册的情况,在再次遇到的时候重新识别。
- [ ] 目前在匿名化build message的时候如果一个被回复的消息包含了一个转发消息组件那么这个转发消息组件中的用户信息是不会被匿名化的后续需要修复这个问题。有时候感觉用正则是对的
- [ ] 可以考虑将消息保存的时候就将消息中的用户信息匿名化这样在后续的处理过程中就不需要担心匿名化的问题了同时也可以避免在build message的时候进行复杂的递归处理同时还要保存匿名映射表。
# 计算备忘
- [ ] emoji的emotion比较是基于编辑距离的考虑更换为基于语义的比较比如使用emoji的embedding进行比较以提高准确性和鲁棒性
- [ ] expression的相似度比较是基于LCS的Ratcliff-Obershelp算法考虑更换为基于语义的比较比如使用embedding进行比较以提高准确性和鲁棒性
- [ ] 为了保持代码的简洁性HFC无论任何情况都将初始化ExpressionReflectorExpressionLearnerJargonMiner实例无论配置文件中是否在此聊天流启用了他们。
- [ ] 可优化方向将其置为Optional在不启用的情况下不进行初始化
- [ ] 当配置文件重载时重新分析所有启用判定所有HFC进行并行检查将启用的进行实例化。不启用的实例化移除引用释放内存。