注释掉pfc内容,暂时恢复部分代码保证可启动性
This commit is contained in:
committed by
DrSmoothl
parent
568685758b
commit
272d0368b8
@@ -1,14 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Dict, TypedDict
|
from typing import Optional, Dict, TypedDict
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CyclePlanInfo(TypedDict): ... # TODO: 根据实际需要补充字段
|
class CyclePlanInfo(TypedDict): ... # TODO: 根据实际需要补充字段
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CycleActionInfo(TypedDict): ... # TODO: 根据实际需要补充字段
|
class CycleActionInfo(TypedDict): ... # TODO: 根据实际需要补充字段
|
||||||
|
|
||||||
|
|
||||||
@@ -19,11 +17,11 @@ class CycleDetail:
|
|||||||
cycle_id: int
|
cycle_id: int
|
||||||
thinking_id: str = ""
|
thinking_id: str = ""
|
||||||
"""思考ID"""
|
"""思考ID"""
|
||||||
start_time: float = time.time()
|
start_time: float = field(default_factory=time.time)
|
||||||
"""开始时间,单位为秒"""
|
"""开始时间,单位为秒"""
|
||||||
end_time: Optional[float] = None
|
end_time: Optional[float] = None
|
||||||
"""结束时间,单位为秒,None表示未结束"""
|
"""结束时间,单位为秒,None表示未结束"""
|
||||||
time_records: Dict[str, float] = {}
|
time_records: Dict[str, float] = field(default_factory=dict)
|
||||||
"""计时器记录,key为计时器名称,value为用时,单位为秒"""
|
"""计时器记录,key为计时器名称,value为用时,单位为秒"""
|
||||||
loop_plan_info: Optional[CyclePlanInfo] = None
|
loop_plan_info: Optional[CyclePlanInfo] = None
|
||||||
"""循环计划记录"""
|
"""循环计划记录"""
|
||||||
|
|||||||
@@ -5,18 +5,19 @@ import traceback
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_manager import chat_manager
|
from src.chat.message_receive.chat_manager import chat_manager
|
||||||
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
from src.chat.heart_flow.heartFC_chat import HeartFChatting
|
||||||
from src.chat.brain_chat.brain_chat import BrainChatting
|
# from src.chat.brain_chat.brain_chat import BrainChatting
|
||||||
|
|
||||||
logger = get_logger("heartflow")
|
logger = get_logger("heartflow")
|
||||||
|
|
||||||
|
# TODO: 恢复PFC,现在暂时禁用
|
||||||
class HeartflowManager:
|
class HeartflowManager:
|
||||||
"""主心流协调器,负责初始化并协调聊天,控制聊天属性"""
|
"""主心流协调器,负责初始化并协调聊天,控制聊天属性"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
# self.heartflow_chat_list: Dict[str, HeartFChatting | BrainChatting] = {}
|
||||||
|
self.heartflow_chat_list: Dict[str, HeartFChatting] = {}
|
||||||
|
|
||||||
async def get_or_create_heartflow_chat(self, session_id: str) -> Optional[HeartFChatting | BrainChatting]:
|
async def get_or_create_heartflow_chat(self, session_id: str): # -> Optional[HeartFChatting | BrainChatting]:
|
||||||
"""获取或创建一个新的HeartFChatting实例"""
|
"""获取或创建一个新的HeartFChatting实例"""
|
||||||
try:
|
try:
|
||||||
if chat := self.heartflow_chat_list.get(session_id):
|
if chat := self.heartflow_chat_list.get(session_id):
|
||||||
@@ -24,16 +25,17 @@ class HeartflowManager:
|
|||||||
chat_session = chat_manager.get_session_by_session_id(session_id)
|
chat_session = chat_manager.get_session_by_session_id(session_id)
|
||||||
if not chat_session:
|
if not chat_session:
|
||||||
raise ValueError(f"未找到 session_id={session_id} 的聊天流")
|
raise ValueError(f"未找到 session_id={session_id} 的聊天流")
|
||||||
new_chat = (
|
# new_chat = (
|
||||||
HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id)
|
# HeartFChatting(session_id=session_id) if chat_session.group_id else BrainChatting(session_id=session_id)
|
||||||
)
|
# )
|
||||||
|
new_chat = HeartFChatting(session_id=session_id)
|
||||||
await new_chat.start()
|
await new_chat.start()
|
||||||
self.heartflow_chat_list[session_id] = new_chat
|
self.heartflow_chat_list[session_id] = new_chat
|
||||||
return new_chat
|
return new_chat
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True)
|
logger.error(f"创建心流聊天 {session_id} 失败: {e}", exc_info=True)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
raise e
|
||||||
|
|
||||||
def adjust_talk_frequency(self, session_id: str, frequency: float):
|
def adjust_talk_frequency(self, session_id: str, frequency: float):
|
||||||
"""调整指定聊天流的说话频率"""
|
"""调整指定聊天流的说话频率"""
|
||||||
|
|||||||
@@ -57,8 +57,11 @@ class HeartFCMessageReceiver:
|
|||||||
# message.is_at = is_at
|
# message.is_at = is_at
|
||||||
|
|
||||||
MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
||||||
|
try:
|
||||||
await heartflow_manager.get_or_create_heartflow_chat(message.session_id)
|
chat = await heartflow_manager.get_or_create_heartflow_chat(message.session_id)
|
||||||
|
await chat.register_message(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"出现错误: {e}")
|
||||||
|
|
||||||
# 3. 日志记录
|
# 3. 日志记录
|
||||||
mes_name = group_info.group_name if group_info else "私聊"
|
mes_name = group_info.group_name if group_info else "私聊"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from src.common.logger import get_logger
|
|||||||
from src.common.utils.utils_message import MessageUtils
|
from src.common.utils.utils_message import MessageUtils
|
||||||
from src.common.utils.utils_session import SessionUtils
|
from src.common.utils.utils_session import SessionUtils
|
||||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||||
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
# from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.core.announcement_manager import global_announcement_manager
|
from src.core.announcement_manager import global_announcement_manager
|
||||||
from src.core.component_registry import component_registry
|
from src.core.component_registry import component_registry
|
||||||
@@ -32,7 +32,7 @@ class ChatBot:
|
|||||||
self.bot = None # bot 实例引用
|
self.bot = None # bot 实例引用
|
||||||
self._started = False
|
self._started = False
|
||||||
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
|
||||||
self.pfc_manager = PFCManager.get_instance() # PFC管理器
|
# self.pfc_manager = PFCManager.get_instance() # PFC管理器 # TODO: PFC恢复
|
||||||
|
|
||||||
async def _ensure_started(self):
|
async def _ensure_started(self):
|
||||||
"""确保所有任务已启动"""
|
"""确保所有任务已启动"""
|
||||||
@@ -374,12 +374,14 @@ class ChatBot:
|
|||||||
# await preprocess()
|
# await preprocess()
|
||||||
async def preprocess():
|
async def preprocess():
|
||||||
if group_info is None:
|
if group_info is None:
|
||||||
logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
# logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
|
||||||
MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
# MessageUtils.store_message_to_db(message) # 存储消息到数据库
|
||||||
await self._create_pfc_chat(message)
|
# await self._create_pfc_chat(message)
|
||||||
|
logger.critical("暂时禁用私聊")
|
||||||
else:
|
else:
|
||||||
logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
|
await preprocess()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"预处理消息失败: {e}")
|
logger.error(f"预处理消息失败: {e}")
|
||||||
|
|||||||
@@ -23,266 +23,266 @@ _webui_chat_broadcaster = None
|
|||||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||||
|
|
||||||
# TODO: 重构完成后完成webui相关
|
# TODO: 重构完成后完成webui相关
|
||||||
# def get_webui_chat_broadcaster():
|
def get_webui_chat_broadcaster():
|
||||||
# """获取 WebUI 聊天室广播器"""
|
"""获取 WebUI 聊天室广播器"""
|
||||||
# global _webui_chat_broadcaster
|
global _webui_chat_broadcaster
|
||||||
# if _webui_chat_broadcaster is None:
|
if _webui_chat_broadcaster is None:
|
||||||
# try:
|
try:
|
||||||
# from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
||||||
|
|
||||||
# _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||||
# except ImportError:
|
except ImportError:
|
||||||
# _webui_chat_broadcaster = (None, None)
|
_webui_chat_broadcaster = (None, None)
|
||||||
# return _webui_chat_broadcaster
|
return _webui_chat_broadcaster
|
||||||
|
|
||||||
|
|
||||||
# def is_webui_virtual_group(group_id: str) -> bool:
|
def is_webui_virtual_group(group_id: str) -> bool:
|
||||||
# """检查是否是 WebUI 虚拟群"""
|
"""检查是否是 WebUI 虚拟群"""
|
||||||
# return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
|
||||||
|
|
||||||
|
|
||||||
# def parse_message_segments(segment) -> list:
|
def parse_message_segments(segment) -> list:
|
||||||
# """解析消息段,转换为 WebUI 可用的格式
|
"""解析消息段,转换为 WebUI 可用的格式
|
||||||
|
|
||||||
# 参考 NapCat 适配器的消息解析逻辑
|
参考 NapCat 适配器的消息解析逻辑
|
||||||
|
|
||||||
# Args:
|
Args:
|
||||||
# segment: Seg 消息段对象
|
segment: Seg 消息段对象
|
||||||
|
|
||||||
# Returns:
|
Returns:
|
||||||
# list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
list: 消息段列表,每个元素为 {"type": "...", "data": ...}
|
||||||
# """
|
"""
|
||||||
|
|
||||||
# result = []
|
result = []
|
||||||
|
|
||||||
# if segment is None:
|
if segment is None:
|
||||||
# return result
|
return result
|
||||||
|
|
||||||
# if segment.type == "seglist":
|
if segment.type == "seglist":
|
||||||
# # 处理消息段列表
|
# 处理消息段列表
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# for seg in segment.data:
|
for seg in segment.data:
|
||||||
# result.extend(parse_message_segments(seg))
|
result.extend(parse_message_segments(seg))
|
||||||
# elif segment.type == "text":
|
elif segment.type == "text":
|
||||||
# # 文本消息
|
# 文本消息
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "text", "data": segment.data})
|
result.append({"type": "text", "data": segment.data})
|
||||||
# elif segment.type == "image":
|
elif segment.type == "image":
|
||||||
# # 图片消息(base64)
|
# 图片消息(base64)
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
|
||||||
# elif segment.type == "emoji":
|
elif segment.type == "emoji":
|
||||||
# # 表情包消息(base64)
|
# 表情包消息(base64)
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
|
||||||
# elif segment.type == "imageurl":
|
elif segment.type == "imageurl":
|
||||||
# # 图片链接消息
|
# 图片链接消息
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "image", "data": segment.data})
|
result.append({"type": "image", "data": segment.data})
|
||||||
# elif segment.type == "face":
|
elif segment.type == "face":
|
||||||
# # 原生表情
|
# 原生表情
|
||||||
# result.append({"type": "face", "data": segment.data})
|
result.append({"type": "face", "data": segment.data})
|
||||||
# elif segment.type == "voice":
|
elif segment.type == "voice":
|
||||||
# # 语音消息(base64)
|
# 语音消息(base64)
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
|
||||||
# elif segment.type == "voiceurl":
|
elif segment.type == "voiceurl":
|
||||||
# # 语音链接
|
# 语音链接
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "voice", "data": segment.data})
|
result.append({"type": "voice", "data": segment.data})
|
||||||
# elif segment.type == "video":
|
elif segment.type == "video":
|
||||||
# # 视频消息(base64)
|
# 视频消息(base64)
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
|
||||||
# elif segment.type == "videourl":
|
elif segment.type == "videourl":
|
||||||
# # 视频链接
|
# 视频链接
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "video", "data": segment.data})
|
result.append({"type": "video", "data": segment.data})
|
||||||
# elif segment.type == "music":
|
elif segment.type == "music":
|
||||||
# # 音乐消息
|
# 音乐消息
|
||||||
# result.append({"type": "music", "data": segment.data})
|
result.append({"type": "music", "data": segment.data})
|
||||||
# elif segment.type == "file":
|
elif segment.type == "file":
|
||||||
# # 文件消息
|
# 文件消息
|
||||||
# result.append({"type": "file", "data": segment.data})
|
result.append({"type": "file", "data": segment.data})
|
||||||
# elif segment.type == "reply":
|
elif segment.type == "reply":
|
||||||
# # 回复消息
|
# 回复消息
|
||||||
# result.append({"type": "reply", "data": segment.data})
|
result.append({"type": "reply", "data": segment.data})
|
||||||
# elif segment.type == "forward":
|
elif segment.type == "forward":
|
||||||
# # 转发消息
|
# 转发消息
|
||||||
# forward_items = []
|
forward_items = []
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# for item in segment.data:
|
for item in segment.data:
|
||||||
# forward_items.append(
|
forward_items.append(
|
||||||
# {
|
{
|
||||||
# "content": parse_message_segments(item.get("message_segment", {}))
|
"content": parse_message_segments(item.get("message_segment", {}))
|
||||||
# if isinstance(item, dict)
|
if isinstance(item, dict)
|
||||||
# else []
|
else []
|
||||||
# }
|
}
|
||||||
# )
|
)
|
||||||
# result.append({"type": "forward", "data": forward_items})
|
result.append({"type": "forward", "data": forward_items})
|
||||||
# else:
|
else:
|
||||||
# # 未知类型,尝试作为文本处理
|
# 未知类型,尝试作为文本处理
|
||||||
# if segment.data:
|
if segment.data:
|
||||||
# result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
|
||||||
|
|
||||||
# return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# async def _send_message(message: MessageSending, show_log=True) -> bool:
|
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
# """合并后的消息发送函数,包含WS发送和日志记录"""
|
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||||
# message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||||
# platform = message.platform
|
platform = message.platform
|
||||||
# group_id = message.session.group_id
|
group_id = message.session.group_id
|
||||||
|
|
||||||
# try:
|
try:
|
||||||
# # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||||
# chat_manager, webui_platform = get_webui_chat_broadcaster()
|
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||||
# is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
|
||||||
|
|
||||||
# if is_webui_message and chat_manager is not None:
|
if is_webui_message and chat_manager is not None:
|
||||||
# # WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
|
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
|
||||||
# import time
|
import time
|
||||||
# from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
# # 解析消息段,获取富文本内容
|
# 解析消息段,获取富文本内容
|
||||||
# message_segments = parse_message_segments(message.message_segment)
|
message_segments = parse_message_segments(message.message_segment)
|
||||||
|
|
||||||
# # 判断消息类型
|
# 判断消息类型
|
||||||
# # 如果只有一个文本段,使用简单的 text 类型
|
# 如果只有一个文本段,使用简单的 text 类型
|
||||||
# # 否则使用 rich 类型,包含完整的消息段
|
# 否则使用 rich 类型,包含完整的消息段
|
||||||
# if len(message_segments) == 1 and message_segments[0].get("type") == "text":
|
if len(message_segments) == 1 and message_segments[0].get("type") == "text":
|
||||||
# message_type = "text"
|
message_type = "text"
|
||||||
# segments = None
|
segments = None
|
||||||
# else:
|
else:
|
||||||
# message_type = "rich"
|
message_type = "rich"
|
||||||
# segments = message_segments
|
segments = message_segments
|
||||||
|
|
||||||
# await chat_manager.broadcast(
|
await chat_manager.broadcast(
|
||||||
# {
|
{
|
||||||
# "type": "bot_message",
|
"type": "bot_message",
|
||||||
# "content": message.processed_plain_text,
|
"content": message.processed_plain_text,
|
||||||
# "message_type": message_type,
|
"message_type": message_type,
|
||||||
# "segments": segments, # 富文本消息段
|
"segments": segments, # 富文本消息段
|
||||||
# "timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
# "group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
|
"group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
|
||||||
# "sender": {
|
"sender": {
|
||||||
# "name": global_config.bot.nickname,
|
"name": global_config.bot.nickname,
|
||||||
# "avatar": None,
|
"avatar": None,
|
||||||
# "is_bot": True,
|
"is_bot": True,
|
||||||
# },
|
},
|
||||||
# }
|
}
|
||||||
# )
|
)
|
||||||
|
|
||||||
# # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||||
# # 无需手动保存
|
# 无需手动保存
|
||||||
|
|
||||||
# if show_log:
|
if show_log:
|
||||||
# if is_webui_virtual_group(group_id):
|
if is_webui_virtual_group(group_id):
|
||||||
# logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})")
|
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})")
|
||||||
# else:
|
else:
|
||||||
# logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
||||||
# return True
|
return True
|
||||||
|
|
||||||
# # Fallback 逻辑: 尝试通过 API Server 发送
|
# Fallback 逻辑: 尝试通过 API Server 发送
|
||||||
# async def send_with_new_api(legacy_exception=None):
|
async def send_with_new_api(legacy_exception=None):
|
||||||
# try:
|
try:
|
||||||
# from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
# # 如果未开启 API Server,直接跳过 Fallback
|
# 如果未开启 API Server,直接跳过 Fallback
|
||||||
# if not global_config.maim_message.enable_api_server:
|
if not global_config.maim_message.enable_api_server:
|
||||||
# logger.debug("[API Server Fallback] API Server未开启,跳过fallback")
|
logger.debug("[API Server Fallback] API Server未开启,跳过fallback")
|
||||||
# if legacy_exception:
|
if legacy_exception:
|
||||||
# raise legacy_exception
|
raise legacy_exception
|
||||||
# return False
|
return False
|
||||||
|
|
||||||
# global_api = get_global_api()
|
global_api = get_global_api()
|
||||||
# extra_server = getattr(global_api, "extra_server", None)
|
extra_server = getattr(global_api, "extra_server", None)
|
||||||
|
|
||||||
# if not extra_server:
|
if not extra_server:
|
||||||
# logger.warning("[API Server Fallback] extra_server不存在")
|
logger.warning("[API Server Fallback] extra_server不存在")
|
||||||
# if legacy_exception:
|
if legacy_exception:
|
||||||
# raise legacy_exception
|
raise legacy_exception
|
||||||
# return False
|
return False
|
||||||
|
|
||||||
# if not extra_server.is_running():
|
if not extra_server.is_running():
|
||||||
# logger.warning("[API Server Fallback] extra_server未运行")
|
logger.warning("[API Server Fallback] extra_server未运行")
|
||||||
# if legacy_exception:
|
if legacy_exception:
|
||||||
# raise legacy_exception
|
raise legacy_exception
|
||||||
# return False
|
return False
|
||||||
|
|
||||||
# # Fallback: 使用极其简单的 Platform -> API Key 映射
|
# Fallback: 使用极其简单的 Platform -> API Key 映射
|
||||||
# # 只有收到过该平台的消息,我们才知道该平台的 API Key,才能回传消息
|
# 只有收到过该平台的消息,我们才知道该平台的 API Key,才能回传消息
|
||||||
# platform_map = getattr(global_api, "platform_map", {})
|
platform_map = getattr(global_api, "platform_map", {})
|
||||||
# logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台: '{platform}'")
|
logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台: '{platform}'")
|
||||||
# target_api_key = platform_map.get(platform)
|
target_api_key = platform_map.get(platform)
|
||||||
|
|
||||||
# if not target_api_key:
|
if not target_api_key:
|
||||||
# logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的API Key映射")
|
logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的API Key映射")
|
||||||
# if legacy_exception:
|
if legacy_exception:
|
||||||
# raise legacy_exception
|
raise legacy_exception
|
||||||
# return False
|
return False
|
||||||
|
|
||||||
# # 使用 MessageConverter 转换为 API 消息
|
# 使用 MessageConverter 转换为 API 消息
|
||||||
# from maim_message import MessageConverter
|
from maim_message import MessageConverter
|
||||||
|
|
||||||
# # 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
|
# 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
|
||||||
# message_base = await message.to_maim_message()
|
message_base = await message.to_maim_message()
|
||||||
|
|
||||||
# api_message = MessageConverter.to_api_send(
|
api_message = MessageConverter.to_api_send(
|
||||||
# message=message_base,
|
message=message_base,
|
||||||
# api_key=target_api_key,
|
api_key=target_api_key,
|
||||||
# platform=platform,
|
platform=platform,
|
||||||
# )
|
)
|
||||||
|
|
||||||
# # 直接调用 Server 的 send_message 接口,它会自动处理路由
|
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||||
# logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
|
logger.debug("[API Server Fallback] 正在通过extra_server发送消息...")
|
||||||
# results = await extra_server.send_message(api_message)
|
results = await extra_server.send_message(api_message)
|
||||||
# logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
logger.debug(f"[API Server Fallback] 发送结果: {results}")
|
||||||
|
|
||||||
# # 检查是否有任何连接发送成功
|
# 检查是否有任何连接发送成功
|
||||||
# if any(results.values()):
|
if any(results.values()):
|
||||||
# if show_log:
|
if show_log:
|
||||||
# logger.info(
|
logger.info(
|
||||||
# f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
|
f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
|
||||||
# )
|
)
|
||||||
# return True
|
return True
|
||||||
# else:
|
else:
|
||||||
# logger.warning(f"[API Server Fallback] 没有连接发送成功, results={results}")
|
logger.warning(f"[API Server Fallback] 没有连接发送成功, results={results}")
|
||||||
# except Exception as e:
|
except Exception as e:
|
||||||
# logger.error(f"[API Server Fallback] 发生异常: {e}")
|
logger.error(f"[API Server Fallback] 发生异常: {e}")
|
||||||
# import traceback
|
import traceback
|
||||||
|
|
||||||
# logger.debug(traceback.format_exc())
|
logger.debug(traceback.format_exc())
|
||||||
|
|
||||||
# # 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
|
# 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
|
||||||
# if legacy_exception:
|
if legacy_exception:
|
||||||
# raise legacy_exception
|
raise legacy_exception
|
||||||
# return False
|
return False
|
||||||
|
|
||||||
# try:
|
try:
|
||||||
# message_base = await message.to_maim_message()
|
message_base = await message.to_maim_message()
|
||||||
# send_result = await get_global_api().send_message(message_base)
|
send_result = await get_global_api().send_message(message_base)
|
||||||
# if send_result:
|
if send_result:
|
||||||
# if show_log:
|
if show_log:
|
||||||
# logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
|
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
|
||||||
# return True
|
return True
|
||||||
# else:
|
else:
|
||||||
# # Legacy API 返回 False (发送失败但未报错),尝试 Fallback
|
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
|
||||||
# fallback_result = await send_with_new_api()
|
fallback_result = await send_with_new_api()
|
||||||
# if fallback_result and show_log:
|
if fallback_result and show_log:
|
||||||
# # Fallback成功的日志已在send_with_new_api中打印
|
# Fallback成功的日志已在send_with_new_api中打印
|
||||||
# pass
|
pass
|
||||||
# return fallback_result
|
return fallback_result
|
||||||
|
|
||||||
# except Exception as legacy_e:
|
except Exception as legacy_e:
|
||||||
# # Legacy API 抛出异常,尝试 Fallback
|
# Legacy API 抛出异常,尝试 Fallback
|
||||||
# # 如果 Fallback 也失败,将重新抛出 legacy_e
|
# 如果 Fallback 也失败,将重新抛出 legacy_e
|
||||||
# return await send_with_new_api(legacy_exception=legacy_e)
|
return await send_with_new_api(legacy_exception=legacy_e)
|
||||||
|
|
||||||
# except Exception as e:
|
except Exception as e:
|
||||||
# logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
|
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
|
||||||
# traceback.print_exc()
|
traceback.print_exc()
|
||||||
# raise e # 重新抛出其他异常
|
raise e # 重新抛出其他异常
|
||||||
|
|
||||||
|
|
||||||
class UniversalMessageSender:
|
class UniversalMessageSender:
|
||||||
@@ -363,9 +363,9 @@ class UniversalMessageSender:
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(typing_time)
|
await asyncio.sleep(typing_time)
|
||||||
|
|
||||||
# sent_msg = await _send_message(message, show_log=show_log)
|
sent_msg = await _send_message(message, show_log=show_log)
|
||||||
# if not sent_msg:
|
if not sent_msg:
|
||||||
# return False
|
return False
|
||||||
|
|
||||||
# _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
|
# _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id)
|
||||||
# continue_flag, modified_message = await event_bus.emit(EventType.AFTER_SEND, _event_msg)
|
# continue_flag, modified_message = await event_bus.emit(EventType.AFTER_SEND, _event_msg)
|
||||||
|
|||||||
@@ -817,6 +817,10 @@ def assign_message_ids(messages: List[SessionMessage]) -> List[Tuple[str, Sessio
|
|||||||
result.append((message_id, message))
|
result.append((message_id, message))
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
# break
|
||||||
|
# result.append((message_id, message))
|
||||||
|
|
||||||
|
# return result
|
||||||
|
|
||||||
|
|
||||||
def parse_keywords_string(keywords_input) -> list[str]:
|
def parse_keywords_string(keywords_input) -> list[str]:
|
||||||
|
|||||||
@@ -39,11 +39,11 @@ class MessageInfo:
|
|||||||
|
|
||||||
|
|
||||||
class MaiMessage(BaseDatabaseDataModel[Messages]):
|
class MaiMessage(BaseDatabaseDataModel[Messages]):
|
||||||
def __init__(self, message_id: str, timestamp: datetime):
|
def __init__(self, message_id: str, timestamp: datetime, platform: str):
|
||||||
self.message_id: str = message_id
|
self.message_id: str = message_id
|
||||||
self.timestamp: datetime = timestamp # 时间戳
|
self.timestamp: datetime = timestamp # 时间戳
|
||||||
self.initialized = False # 用于标记是否已初始化其他属性
|
self.initialized = False # 用于标记是否已初始化其他属性
|
||||||
self.platform: str # 初始化后赋值
|
self.platform: str = platform
|
||||||
|
|
||||||
# 定义其他属性
|
# 定义其他属性
|
||||||
self.message_info: MessageInfo # 初始化后赋值
|
self.message_info: MessageInfo # 初始化后赋值
|
||||||
@@ -72,7 +72,7 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_db_instance(cls, db_record: "Messages"):
|
def from_db_instance(cls, db_record: "Messages"):
|
||||||
obj = cls(message_id=db_record.message_id, timestamp=db_record.timestamp)
|
obj = cls(message_id=db_record.message_id, timestamp=db_record.timestamp, platform=db_record.platform)
|
||||||
|
|
||||||
user_info = UserInfo(db_record.user_id, db_record.user_nickname, db_record.user_cardname)
|
user_info = UserInfo(db_record.user_id, db_record.user_nickname, db_record.user_cardname)
|
||||||
if db_record.group_id and db_record.group_name:
|
if db_record.group_id and db_record.group_name:
|
||||||
@@ -130,12 +130,14 @@ class MaiMessage(BaseDatabaseDataModel[Messages]):
|
|||||||
"""从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息"""
|
"""从 maim_message.MessageBase 创建 MaiMessage 实例,解析消息内容并提取相关信息"""
|
||||||
msg_info = message.message_info
|
msg_info = message.message_info
|
||||||
assert msg_info, "MessageBase 的 message_info 不能为空"
|
assert msg_info, "MessageBase 的 message_info 不能为空"
|
||||||
msg_id = msg_info.message_id
|
platform = msg_info.platform
|
||||||
|
assert isinstance(platform, str)
|
||||||
|
msg_id = str(msg_info.message_id)
|
||||||
timestamp = msg_info.time
|
timestamp = msg_info.time
|
||||||
assert isinstance(msg_id, str)
|
assert isinstance(msg_id, str)
|
||||||
assert msg_id
|
assert msg_id
|
||||||
assert timestamp
|
assert timestamp
|
||||||
obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp))
|
obj = cls(message_id=msg_id, timestamp=datetime.fromtimestamp(timestamp), platform=platform)
|
||||||
obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message)
|
obj.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(message)
|
||||||
usr_info = msg_info.user_info
|
usr_info = msg_info.user_info
|
||||||
assert usr_info
|
assert usr_info
|
||||||
|
|||||||
217
src/common/message_server/universal_message_sender.py
Normal file
217
src/common/message_server/universal_message_sender.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.message_server.api import get_global_api
|
||||||
|
from src.common.utils.math_utils import calculate_typing_time
|
||||||
|
from src.common.data_models.message_component_data_model import ReplyComponent
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.message import SessionMessage
|
||||||
|
|
||||||
|
logger = get_logger("sender")
|
||||||
|
|
||||||
|
|
||||||
|
class UniversalMessageSender:
|
||||||
|
@staticmethod
|
||||||
|
async def send_message(
|
||||||
|
message: "SessionMessage",
|
||||||
|
*,
|
||||||
|
typing: bool = False,
|
||||||
|
storage_message: bool = True,
|
||||||
|
reply_message_id: Optional[str] = None,
|
||||||
|
show_log: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
处理、发送并存储一条消息。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: SessionMessage 对象,待发送的消息。
|
||||||
|
typing: 是否模拟打字等待。
|
||||||
|
storage_message: 是否存储消息到数据库。
|
||||||
|
reply_message_id: 回复消息的 ID。
|
||||||
|
show_log: 是否显示日志。
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 消息是否发送成功。
|
||||||
|
"""
|
||||||
|
if not message.message_id:
|
||||||
|
logger.error("消息缺少 message_id,无法发送")
|
||||||
|
raise ValueError("消息缺少 message_id,无法发送")
|
||||||
|
|
||||||
|
# 设置回复
|
||||||
|
if reply_message_id:
|
||||||
|
message.raw_message.components.insert(0, ReplyComponent(reply_message_id))
|
||||||
|
|
||||||
|
# 处理消息
|
||||||
|
await message.process()
|
||||||
|
|
||||||
|
# 模拟打字等待
|
||||||
|
if typing:
|
||||||
|
typing_time = calculate_typing_time(message.processed_plain_text or "")
|
||||||
|
await asyncio.sleep(typing_time)
|
||||||
|
|
||||||
|
# 广播消息到插件
|
||||||
|
await UniversalMessageSender._broadcast_message_to_plugins(message)
|
||||||
|
|
||||||
|
# 发送消息
|
||||||
|
sent_result = await UniversalMessageSender._send_message_via_maim_message(message, show_log=show_log)
|
||||||
|
if not sent_result:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 存储消息到数据库
|
||||||
|
try:
|
||||||
|
if storage_message:
|
||||||
|
with get_db_session() as db_session:
|
||||||
|
db_session.add(message.to_db_instance())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[{message.session_id}] 存储消息 {message.message_id} 时出错:{e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _broadcast_message_to_plugins(message: "SessionMessage"):
|
||||||
|
"""广播消息到所有注册的插件"""
|
||||||
|
# TODO: 实现消息广播逻辑
|
||||||
|
raise NotImplementedError("消息广播到插件的功能尚未实现")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_message_via_maim_message(message: "SessionMessage", show_log: bool = True) -> bool:
|
||||||
|
"""
|
||||||
|
通过 MAIM Message API 发送消息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: SessionMessage 对象
|
||||||
|
show_log: 是否显示日志
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 消息是否发送成功
|
||||||
|
"""
|
||||||
|
# TODO: 重构至新的发送模型
|
||||||
|
message_preview = (message.processed_plain_text or "")[:200]
|
||||||
|
platform = message.platform
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试通过主 API 发送
|
||||||
|
try:
|
||||||
|
message_base = await message.to_maim_message()
|
||||||
|
send_result = await get_global_api().send_message(message_base)
|
||||||
|
if not send_result:
|
||||||
|
# Legacy API 返回 False,尝试 Fallback
|
||||||
|
# return await self._send_with_fallback(message, message_preview, platform, show_log)
|
||||||
|
return False
|
||||||
|
if show_log:
|
||||||
|
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as legacy_e:
|
||||||
|
# # Legacy API 抛出异常,尝试 Fallback
|
||||||
|
# return await self._send_with_fallback(
|
||||||
|
# message, message_preview, platform, show_log, legacy_exception=legacy_e
|
||||||
|
# )
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败:{str(e)}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
traceback.print_exc()
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def _send_with_fallback(
|
||||||
|
self,
|
||||||
|
message: "SessionMessage",
|
||||||
|
message_preview: str,
|
||||||
|
platform: str,
|
||||||
|
show_log: bool,
|
||||||
|
legacy_exception: Optional[Exception] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Fallback 发送逻辑:通过 API Server 发送
|
||||||
|
|
||||||
|
参数:
|
||||||
|
message: SessionMessage 对象
|
||||||
|
message_preview: 消息预览
|
||||||
|
platform: 目标平台
|
||||||
|
show_log: 是否显示日志
|
||||||
|
legacy_exception: 遗留异常(如果 Fallback 失败则抛出)
|
||||||
|
|
||||||
|
返回:
|
||||||
|
bool: 消息是否发送成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
# 如果未开启 API Server,直接跳过 Fallback
|
||||||
|
if not global_config.maim_message.enable_api_server:
|
||||||
|
logger.debug("[API Server Fallback] API Server 未开启,跳过 fallback")
|
||||||
|
if legacy_exception:
|
||||||
|
raise legacy_exception
|
||||||
|
return False
|
||||||
|
|
||||||
|
global_api = get_global_api()
|
||||||
|
extra_server = getattr(global_api, "extra_server", None)
|
||||||
|
|
||||||
|
if not extra_server:
|
||||||
|
logger.warning("[API Server Fallback] extra_server 不存在")
|
||||||
|
if legacy_exception:
|
||||||
|
raise legacy_exception
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not extra_server.is_running():
|
||||||
|
logger.warning("[API Server Fallback] extra_server 未运行")
|
||||||
|
if legacy_exception:
|
||||||
|
raise legacy_exception
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Fallback: 使用 Platform -> API Key 映射
|
||||||
|
platform_map = getattr(global_api, "platform_map", {})
|
||||||
|
logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台:'{platform}'")
|
||||||
|
target_api_key = platform_map.get(platform)
|
||||||
|
|
||||||
|
if not target_api_key:
|
||||||
|
logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的 API Key 映射")
|
||||||
|
if legacy_exception:
|
||||||
|
raise legacy_exception
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 使用 MessageConverter 转换为 API 消息
|
||||||
|
from maim_message import MessageConverter
|
||||||
|
|
||||||
|
# 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
|
||||||
|
message_base = await message.to_maim_message()
|
||||||
|
|
||||||
|
api_message = MessageConverter.to_api_send(
|
||||||
|
message=message_base,
|
||||||
|
api_key=target_api_key,
|
||||||
|
platform=platform,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 直接调用 Server 的 send_message 接口,它会自动处理路由
|
||||||
|
logger.debug("[API Server Fallback] 正在通过 extra_server 发送消息...")
|
||||||
|
results = await extra_server.send_message(api_message)
|
||||||
|
logger.debug(f"[API Server Fallback] 发送结果:{results}")
|
||||||
|
|
||||||
|
# 检查是否有任何连接发送成功
|
||||||
|
if any(results.values()):
|
||||||
|
if show_log:
|
||||||
|
logger.info(
|
||||||
|
f"已通过 API Server Fallback 将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning(f"[API Server Fallback] 没有连接发送成功,results={results}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[API Server Fallback] 发生异常:{e}")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
logger.debug(traceback.format_exc())
|
||||||
|
|
||||||
|
# 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
|
||||||
|
if legacy_exception:
|
||||||
|
raise legacy_exception
|
||||||
|
return False
|
||||||
@@ -80,3 +80,43 @@ def translate_timestamp_to_human_readable(timestamp: float, mode: TimestampMode
|
|||||||
return time.strftime(TimestampMode.NORMAL.value, time.localtime(timestamp))
|
return time.strftime(TimestampMode.NORMAL.value, time.localtime(timestamp))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"不支持的时间戳转换模式: {mode}")
|
raise ValueError(f"不支持的时间戳转换模式: {mode}")
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_typing_time(
|
||||||
|
input_string: str,
|
||||||
|
chinese_time: float = 0.3,
|
||||||
|
english_time: float = 0.15,
|
||||||
|
line_break_time: float = 0.1,
|
||||||
|
is_emoji: bool = False,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
||||||
|
input_string (str): 输入的字符串
|
||||||
|
chinese_time (float): 中文字符的输入时间,默认为0.3秒
|
||||||
|
english_time (float): 英文字符的输入时间,默认为0.15秒
|
||||||
|
line_break_time (float): 换行符的输入时间,默认为0.1秒
|
||||||
|
is_emoji (bool): 是否为emoji,默认为False
|
||||||
|
|
||||||
|
特殊情况:
|
||||||
|
- 如果只有一个中文字符,将使用3倍的中文输入时间
|
||||||
|
- 在所有输入结束后,额外加上回车时间
|
||||||
|
- 如果is_emoji为True,将使用固定1秒的输入时间
|
||||||
|
"""
|
||||||
|
if is_emoji:
|
||||||
|
return 1.0 # 固定1秒的输入时间
|
||||||
|
|
||||||
|
# 正常计算所有字符的输入时间
|
||||||
|
total_time = 0.0
|
||||||
|
chinese_chars = 0
|
||||||
|
for char in input_string:
|
||||||
|
if "\u4e00" <= char <= "\u9fff":
|
||||||
|
total_time += chinese_time
|
||||||
|
chinese_chars += 1
|
||||||
|
else:
|
||||||
|
total_time += english_time
|
||||||
|
|
||||||
|
if chinese_chars == 1 and len(input_string.strip()) == 1:
|
||||||
|
# 如果只有一个中文字符,使用3倍时间
|
||||||
|
return chinese_time * 3 + line_break_time # 加上回车时间
|
||||||
|
|
||||||
|
return total_time + line_break_time # 加上回车时间
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ from maim_message import Seg
|
|||||||
|
|
||||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||||
from src.common.data_models.message_data_model import ReplyContentType as ReplyContentType
|
# from src.common.data_models.message_data_model import ReplyContentType as ReplyContentType
|
||||||
from src.common.data_models.message_data_model import ReplyContent as ReplyContent
|
# from src.common.data_models.message_data_model import ReplyContent as ReplyContent
|
||||||
from src.common.data_models.message_data_model import ForwardNode as ForwardNode
|
# from src.common.data_models.message_data_model import ForwardNode as ForwardNode
|
||||||
from src.common.data_models.message_data_model import ReplySetModel as ReplySetModel
|
# from src.common.data_models.message_data_model import ReplySetModel as ReplySetModel
|
||||||
|
|
||||||
|
|
||||||
# 组件类型枚举
|
# 组件类型枚举
|
||||||
|
|||||||
@@ -52,24 +52,24 @@ def is_person_known(
|
|||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
platform: Optional[str] = None,
|
platform: Optional[str] = None,
|
||||||
person_name: Optional[str] = None,
|
person_name: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool: # sourcery skip: extract-duplicate-method
|
||||||
if person_id:
|
if person_id:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
person = session.exec(statement).first()
|
person = session.exec(statement).first()
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
elif user_id and platform:
|
elif user_id and platform:
|
||||||
person_id = get_person_id(platform, user_id)
|
person_id = get_person_id(platform, user_id)
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
person = session.exec(statement).first()
|
person = session.exec(statement).first()
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
elif person_name:
|
elif person_name:
|
||||||
person_id = get_person_id_by_person_name(person_name)
|
person_id = get_person_id_by_person_name(person_name)
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
person = session.exec(statement).first()
|
person = session.exec(statement).first()
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -462,49 +462,49 @@ class Person:
|
|||||||
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
|
||||||
record = session.exec(statement).first()
|
record = session.exec(statement).first()
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
self.user_id = record.user_id or ""
|
self.user_id = record.user_id or ""
|
||||||
self.platform = record.platform or ""
|
self.platform = record.platform or ""
|
||||||
self.is_known = record.is_known or False
|
self.is_known = record.is_known or False
|
||||||
self.nickname = record.user_nickname or ""
|
self.nickname = record.user_nickname or ""
|
||||||
self.person_name = record.person_name or self.nickname
|
self.person_name = record.person_name or self.nickname
|
||||||
self.name_reason = record.name_reason or None
|
self.name_reason = record.name_reason or None
|
||||||
self.know_times = record.know_counts or 0
|
self.know_times = record.know_counts or 0
|
||||||
|
|
||||||
# 处理points字段(JSON格式的列表)
|
# 处理points字段(JSON格式的列表)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
try:
|
try:
|
||||||
loaded_points = json.loads(record.memory_points)
|
loaded_points = json.loads(record.memory_points)
|
||||||
# 过滤掉None值,确保数据质量
|
# 过滤掉None值,确保数据质量
|
||||||
if isinstance(loaded_points, list):
|
if isinstance(loaded_points, list):
|
||||||
self.memory_points = [point for point in loaded_points if point is not None]
|
self.memory_points = [point for point in loaded_points if point is not None]
|
||||||
else:
|
else:
|
||||||
|
self.memory_points = []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
logger.warning(f"解析用户 {self.person_id} 的points字段失败,使用默认值")
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
except (json.JSONDecodeError, TypeError):
|
else:
|
||||||
logger.warning(f"解析用户 {self.person_id} 的points字段失败,使用默认值")
|
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
else:
|
|
||||||
self.memory_points = []
|
|
||||||
|
|
||||||
# 处理group_nick_name字段(JSON格式的列表)
|
# 处理group_nick_name字段(JSON格式的列表)
|
||||||
if record.group_nickname:
|
if record.group_cardname:
|
||||||
try:
|
try:
|
||||||
loaded_group_nick_names = json.loads(record.group_nickname)
|
loaded_group_nick_names = json.loads(record.group_cardname)
|
||||||
# 确保是列表格式
|
# 确保是列表格式
|
||||||
if isinstance(loaded_group_nick_names, list):
|
if isinstance(loaded_group_nick_names, list):
|
||||||
self.group_nick_name = loaded_group_nick_names
|
self.group_nick_name = loaded_group_nick_names
|
||||||
else:
|
else:
|
||||||
|
self.group_nick_name = []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
logger.warning(f"解析用户 {self.person_id} 的group_cardname字段失败,使用默认值")
|
||||||
self.group_nick_name = []
|
self.group_nick_name = []
|
||||||
except (json.JSONDecodeError, TypeError):
|
else:
|
||||||
logger.warning(f"解析用户 {self.person_id} 的group_nickname字段失败,使用默认值")
|
|
||||||
self.group_nick_name = []
|
self.group_nick_name = []
|
||||||
else:
|
|
||||||
self.group_nick_name = []
|
|
||||||
|
|
||||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||||
else:
|
else:
|
||||||
self.sync_to_database()
|
self.sync_to_database()
|
||||||
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
|
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
|
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user