添加更多种类的发送类型
This commit is contained in:
@@ -12,6 +12,7 @@ import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
@@ -138,12 +139,11 @@ async def generate_reply(
|
||||
if not success:
|
||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||
return False, None
|
||||
reply_set: Optional[ReplySetModel] = None
|
||||
if content := llm_response.content:
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
else:
|
||||
reply_set = []
|
||||
llm_response.reply_set = reply_set
|
||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
||||
|
||||
return success, llm_response
|
||||
|
||||
@@ -159,6 +159,7 @@ async def generate_reply(
|
||||
logger.error(traceback.format_exc())
|
||||
return False, None
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
@@ -208,12 +209,12 @@ async def rewrite_reply(
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
reply_set = []
|
||||
reply_set: Optional[ReplySetModel] = None
|
||||
if success and llm_response and (content := llm_response.content):
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
llm_response.reply_set = reply_set
|
||||
if success:
|
||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
|
||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||
|
||||
@@ -227,7 +228,7 @@ async def rewrite_reply(
|
||||
return False, None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]:
|
||||
"""将文本处理为更拟人化的文本
|
||||
|
||||
Args:
|
||||
@@ -238,18 +239,17 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("content 必须是字符串类型")
|
||||
try:
|
||||
reply_set = ReplySetModel()
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
reply_set = []
|
||||
for text in processed_response:
|
||||
reply_seg = ("text", text)
|
||||
reply_set.append(reply_seg)
|
||||
reply_set.add_text_content(text)
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
|
||||
return []
|
||||
return None
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
|
||||
@@ -21,17 +21,19 @@
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING
|
||||
from typing import Optional, Union, Dict, List, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from maim_message import Seg, UserInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_data_model import ReplySetModel
|
||||
|
||||
logger = get_logger("send_api")
|
||||
|
||||
@@ -42,8 +44,7 @@ logger = get_logger("send_api")
|
||||
|
||||
|
||||
async def _send_to_target(
|
||||
message_type: str,
|
||||
content: Union[str, dict],
|
||||
message_segment: Seg,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
@@ -56,8 +57,7 @@ async def _send_to_target(
|
||||
"""向指定目标发送消息的内部实现
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
message_segment:
|
||||
stream_id: 目标流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否模拟打字等待。
|
||||
@@ -74,7 +74,7 @@ async def _send_to_target(
|
||||
return False
|
||||
|
||||
if show_log:
|
||||
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
|
||||
logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}")
|
||||
|
||||
# 查找目标聊天流
|
||||
target_stream = get_chat_manager().get_stream(stream_id)
|
||||
@@ -83,7 +83,7 @@ async def _send_to_target(
|
||||
return False
|
||||
|
||||
# 创建发送器
|
||||
heart_fc_sender = HeartFCSender()
|
||||
message_sender = UniversalMessageSender()
|
||||
|
||||
# 生成消息ID
|
||||
current_time = time.time()
|
||||
@@ -96,13 +96,11 @@ async def _send_to_target(
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content) # type: ignore
|
||||
|
||||
reply_to_platform_id = ""
|
||||
anchor_message: Union["MessageRecv", None] = None
|
||||
if reply_message:
|
||||
anchor_message = message_dict_to_message_recv(reply_message.flatten())
|
||||
anchor_message = db_message_to_message_recv(reply_message)
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") # type: ignore
|
||||
if anchor_message:
|
||||
anchor_message.update_chat_stream(target_stream)
|
||||
assert anchor_message.message_info.user_info, "用户信息缺失"
|
||||
@@ -120,14 +118,14 @@ async def _send_to_target(
|
||||
display_message=display_message,
|
||||
reply=anchor_message,
|
||||
is_head=True,
|
||||
is_emoji=(message_type == "emoji"),
|
||||
is_emoji=(message_segment.type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
reply_to=reply_to_platform_id,
|
||||
selected_expressions=selected_expressions,
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
sent_msg = await message_sender.send_message(
|
||||
bot_message,
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
@@ -148,7 +146,7 @@ async def _send_to_target(
|
||||
return False
|
||||
|
||||
|
||||
def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[MessageRecv]:
|
||||
def db_message_to_message_recv(message_obj: "DatabaseMessages") -> MessageRecv:
|
||||
"""将数据库dict重建为MessageRecv对象
|
||||
Args:
|
||||
message_dict: 消息字典
|
||||
@@ -158,44 +156,41 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
||||
"""
|
||||
# 构建MessageRecv对象
|
||||
user_info = {
|
||||
"platform": message_dict.get("user_platform", ""),
|
||||
"user_id": message_dict.get("user_id", ""),
|
||||
"user_nickname": message_dict.get("user_nickname", ""),
|
||||
"user_cardname": message_dict.get("user_cardname", ""),
|
||||
"platform": message_obj.user_info.platform or "",
|
||||
"user_id": message_obj.user_info.user_id or "",
|
||||
"user_nickname": message_obj.user_info.user_nickname or "",
|
||||
"user_cardname": message_obj.user_info.user_cardname or "",
|
||||
}
|
||||
|
||||
group_info = {}
|
||||
if message_dict.get("chat_info_group_id"):
|
||||
if message_obj.chat_info.group_info:
|
||||
group_info = {
|
||||
"platform": message_dict.get("chat_info_group_platform", ""),
|
||||
"group_id": message_dict.get("chat_info_group_id", ""),
|
||||
"group_name": message_dict.get("chat_info_group_name", ""),
|
||||
"platform": message_obj.chat_info.group_info.group_platform or "",
|
||||
"group_id": message_obj.chat_info.group_info.group_id or "",
|
||||
"group_name": message_obj.chat_info.group_info.group_name or "",
|
||||
}
|
||||
|
||||
format_info = {"content_format": "", "accept_format": ""}
|
||||
template_info = {"template_items": {}}
|
||||
|
||||
message_info = {
|
||||
"platform": message_dict.get("chat_info_platform", ""),
|
||||
"message_id": message_dict.get("message_id"),
|
||||
"time": message_dict.get("time"),
|
||||
"platform": message_obj.chat_info.platform or "",
|
||||
"message_id": message_obj.message_id,
|
||||
"time": message_obj.time,
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
"additional_config": message_dict.get("additional_config"),
|
||||
"additional_config": message_obj.additional_config,
|
||||
"format_info": format_info,
|
||||
"template_info": template_info,
|
||||
}
|
||||
|
||||
message_dict_recv = {
|
||||
"message_info": message_info,
|
||||
"raw_message": message_dict.get("processed_plain_text"),
|
||||
"processed_plain_text": message_dict.get("processed_plain_text"),
|
||||
"raw_message": message_obj.processed_plain_text,
|
||||
"processed_plain_text": message_obj.processed_plain_text,
|
||||
}
|
||||
|
||||
message_recv = MessageRecv(message_dict_recv)
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
||||
return message_recv
|
||||
return MessageRecv(message_dict_recv)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -225,11 +220,10 @@ async def text_to_stream(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"text",
|
||||
text,
|
||||
stream_id,
|
||||
"",
|
||||
typing,
|
||||
message_segment=Seg(type="text", data=text),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=typing,
|
||||
set_reply=set_reply,
|
||||
reply_message=reply_message,
|
||||
storage_message=storage_message,
|
||||
@@ -255,10 +249,9 @@ async def emoji_to_stream(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"emoji",
|
||||
emoji_base64,
|
||||
stream_id,
|
||||
"",
|
||||
message_segment=Seg(type="emoji", data=emoji_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
@@ -284,10 +277,9 @@ async def image_to_stream(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"image",
|
||||
image_base64,
|
||||
stream_id,
|
||||
"",
|
||||
message_segment=Seg(type="image", data=image_base64),
|
||||
stream_id=stream_id,
|
||||
display_message="",
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
@@ -314,10 +306,9 @@ async def command_to_stream(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"command",
|
||||
command,
|
||||
stream_id,
|
||||
display_message,
|
||||
message_segment=Seg(type="command", data=command), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=False,
|
||||
storage_message=storage_message,
|
||||
set_reply=set_reply,
|
||||
@@ -327,7 +318,7 @@ async def command_to_stream(
|
||||
|
||||
async def custom_to_stream(
|
||||
message_type: str,
|
||||
content: str | dict,
|
||||
content: str | Dict,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
@@ -351,8 +342,7 @@ async def custom_to_stream(
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
message_segment=Seg(type=message_type, data=content), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
@@ -361,3 +351,105 @@ async def custom_to_stream(
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
|
||||
async def custom_reply_set_to_stream(
|
||||
reply_set: "ReplySetModel",
|
||||
stream_id: str,
|
||||
display_message: str = "", # 基本没用
|
||||
typing: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
set_reply: bool = False,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送混合型消息集"""
|
||||
flag: bool = True
|
||||
for reply_content in reply_set.reply_data:
|
||||
status: bool = False
|
||||
content_type = reply_content.content_type
|
||||
message_data = reply_content.content
|
||||
if content_type == ReplyContentType.TEXT:
|
||||
status = await _send_to_target(
|
||||
message_segment=Seg(type="text", data=message_data), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
elif content_type in [
|
||||
ReplyContentType.IMAGE,
|
||||
ReplyContentType.EMOJI,
|
||||
ReplyContentType.COMMAND,
|
||||
ReplyContentType.VOICE,
|
||||
]:
|
||||
message_segment: Seg
|
||||
if ReplyContentType == ReplyContentType.IMAGE:
|
||||
message_segment = Seg(type="image", data=message_data) # type: ignore
|
||||
elif ReplyContentType == ReplyContentType.EMOJI:
|
||||
message_segment = Seg(type="emoji", data=message_data) # type: ignore
|
||||
elif ReplyContentType == ReplyContentType.COMMAND:
|
||||
message_segment = Seg(type="command", data=message_data) # type: ignore
|
||||
elif ReplyContentType == ReplyContentType.VOICE:
|
||||
message_segment = Seg(type="voice", data=message_data) # type: ignore
|
||||
status = await _send_to_target(
|
||||
message_segment=message_segment,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=False,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
elif content_type == ReplyContentType.HYBRID:
|
||||
assert isinstance(message_data, list), "混合类型内容必须是列表"
|
||||
sub_seg_list: List[Seg] = []
|
||||
for sub_content in message_data:
|
||||
sub_content_type = sub_content.content_type
|
||||
sub_content_data = sub_content.content
|
||||
|
||||
if sub_content_type == ReplyContentType.TEXT:
|
||||
sub_seg_list.append(Seg(type="text", data=sub_content_data)) # type: ignore
|
||||
elif sub_content_type == ReplyContentType.IMAGE:
|
||||
sub_seg_list.append(Seg(type="image", data=sub_content_data)) # type: ignore
|
||||
elif sub_content_type == ReplyContentType.EMOJI:
|
||||
sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore
|
||||
else:
|
||||
logger.warning(f"[SendAPI] 混合类型中不支持的子内容类型: {repr(sub_content_type)}")
|
||||
continue
|
||||
status = await _send_to_target(
|
||||
message_segment=Seg(type="seglist", data=sub_seg_list), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
elif content_type == ReplyContentType.FORWARD:
|
||||
assert isinstance(message_data, list), "转发类型内容必须是列表"
|
||||
# TODO: 完成转发消息的发送机制
|
||||
else:
|
||||
message_type_in_str = (
|
||||
content_type.value if isinstance(content_type, ReplyContentType) else str(content_type)
|
||||
)
|
||||
return await _send_to_target(
|
||||
message_segment=Seg(type=message_type_in_str, data=message_data), # type: ignore
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_message=reply_message,
|
||||
set_reply=set_reply,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
if not status:
|
||||
flag = False
|
||||
logger.error(f"[SendAPI] 发送{repr(content_type)}消息失败,消息内容:{str(message_data)[:100]}")
|
||||
|
||||
return flag
|
||||
|
||||
@@ -2,7 +2,7 @@ import time
|
||||
import asyncio
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, TYPE_CHECKING
|
||||
from typing import Tuple, Optional, TYPE_CHECKING, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -285,7 +285,7 @@ class BaseAction(ABC):
|
||||
async def send_custom(
|
||||
self,
|
||||
message_type: str,
|
||||
content: str,
|
||||
content: str | Dict,
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
reply_message: Optional["DatabaseMessages"] = None,
|
||||
|
||||
@@ -120,7 +120,7 @@ class BaseCommand(ABC):
|
||||
async def send_type(
|
||||
self,
|
||||
message_type: str,
|
||||
content: str,
|
||||
content: str | Dict,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
set_reply: bool = False,
|
||||
|
||||
Reference in New Issue
Block a user