把字典转换为数据模型并恢复全系统可用性,临时修复InstantMemory让大模型至少知道在聊什么

This commit is contained in:
UnCLAS-Prommer
2025-08-21 23:21:56 +08:00
parent f41a3076f6
commit c6f0c51825
12 changed files with 462 additions and 258 deletions

View File

@@ -9,7 +9,7 @@
"""
import traceback
from typing import Tuple, Any, Dict, List, Optional
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
from src.chat.replyer.default_generator import DefaultReplyer
@@ -18,6 +18,10 @@ from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.database_data_model import DatabaseMessages
install(extra_lines=3)
logger = get_logger("generator_api")
@@ -73,11 +77,11 @@ async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
@@ -85,7 +89,7 @@ async def generate_reply(
request_type: str = "generator_api",
from_plugin: bool = True,
return_expressions: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]:
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
"""生成回复
Args:
@@ -96,7 +100,7 @@ async def generate_reply(
extra_info: 额外信息,用于补充上下文
reply_reason: 回复原因
available_actions: 可用动作
choosen_actions: 已选动作
chosen_actions: 已选动作
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
@@ -110,16 +114,14 @@ async def generate_reply(
try:
# 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer(
chat_stream, chat_id, request_type=request_type
)
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
return False, [], None, None
if not extra_info and action_data:
extra_info = action_data.get("extra_info", "")
if not reply_reason and action_data:
reply_reason = action_data.get("reason", "")
@@ -127,7 +129,7 @@ async def generate_reply(
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
extra_info=extra_info,
available_actions=available_actions,
chosen_actions=choosen_actions,
chosen_actions=chosen_actions,
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
@@ -136,7 +138,7 @@ async def generate_reply(
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")
return False, [], None
return False, [], None, None
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
if content := llm_response_dict.get("content", ""):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
@@ -144,17 +146,23 @@ async def generate_reply(
reply_set = []
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
if return_prompt:
if return_expressions:
return success, reply_set, (prompt, selected_expressions)
else:
return success, reply_set, prompt
else:
if return_expressions:
return success, reply_set, (None, selected_expressions)
else:
return success, reply_set, None
# if return_prompt:
# if return_expressions:
# return success, reply_set, prompt, selected_expressions
# else:
# return success, reply_set, prompt, None
# else:
# if return_expressions:
# return success, reply_set, (None, selected_expressions)
# else:
# return success, reply_set, None
return (
success,
reply_set,
prompt if return_prompt else None,
selected_expressions if return_expressions else None,
)
except ValueError as ve:
raise ve

View File

@@ -21,15 +21,17 @@
import traceback
import time
from typing import Optional, Union, Dict, Any, List
from src.common.logger import get_logger
from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING
# 导入依赖
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from maim_message import Seg, UserInfo
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("send_api")
@@ -46,10 +48,10 @@ async def _send_to_target(
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions:List[int] = None,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定目标发送消息的内部实现
@@ -70,7 +72,7 @@ async def _send_to_target(
if set_reply and not reply_message:
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息")
return False
if show_log:
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
@@ -98,13 +100,13 @@ async def _send_to_target(
message_segment = Seg(type=message_type, data=content) # type: ignore
if reply_message:
anchor_message = message_dict_to_message_recv(reply_message)
anchor_message = message_dict_to_message_recv(reply_message.flatten())
if anchor_message:
anchor_message.update_chat_stream(target_stream)
assert anchor_message.message_info.user_info, "用户信息缺失"
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
)
)
else:
reply_to_platform_id = ""
anchor_message = None
@@ -192,12 +194,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
}
message_recv = MessageRecv(message_dict_recv)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
return message_recv
# =============================================================================
# 公共API函数 - 预定义类型的发送函数
# =============================================================================
@@ -208,9 +209,9 @@ async def text_to_stream(
stream_id: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
selected_expressions:List[int] = None,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定流发送文本消息
@@ -237,7 +238,13 @@ async def text_to_stream(
)
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def emoji_to_stream(
emoji_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送表情包
Args:
@@ -248,10 +255,25 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
Returns:
bool: 是否发送成功
"""
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message)
return await _send_to_target(
"emoji",
emoji_base64,
stream_id,
"",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def image_to_stream(
image_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送图片
Args:
@@ -262,11 +284,25 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
Returns:
bool: 是否发送成功
"""
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message)
return await _send_to_target(
"image",
image_base64,
stream_id,
"",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def command_to_stream(
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
command: Union[str, dict],
stream_id: str,
storage_message: bool = True,
display_message: str = "",
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送命令
@@ -279,7 +315,14 @@ async def command_to_stream(
bool: 是否发送成功
"""
return await _send_to_target(
"command", command, stream_id, display_message, typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message
"command",
command,
stream_id,
display_message,
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
@@ -289,7 +332,7 @@ async def custom_to_stream(
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,

View File

@@ -2,13 +2,15 @@ import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any
from typing import Tuple, Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_action")
@@ -206,7 +208,11 @@ class BaseAction(ABC):
return False, f"等待新消息失败: {str(e)}"
async def send_text(
self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
) -> bool:
"""发送文本消息
@@ -229,7 +235,9 @@ class BaseAction(ABC):
typing=typing,
)
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_emoji(
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送表情包
Args:
@@ -242,9 +250,13 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(emoji_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message)
return await send_api.emoji_to_stream(
emoji_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
)
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_image(
self, image_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送图片
Args:
@@ -257,9 +269,18 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(image_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message)
return await send_api.image_to_stream(
image_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
)
async def send_custom(self, message_type: str, content: str, typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_custom(
self,
message_type: str,
content: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送自定义类型消息
Args:
@@ -308,7 +329,13 @@ class BaseAction(ABC):
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送命令消息

View File

@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, Any
from typing import Dict, Tuple, Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv
from src.plugin_system.apis import send_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_command")
@@ -84,7 +87,13 @@ class BaseCommand(ABC):
return current
async def send_text(self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool:
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送回复消息
Args:
@@ -100,10 +109,22 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, set_reply=set_reply,reply_message=reply_message,storage_message=storage_message)
return await send_api.text_to_stream(
text=content,
stream_id=chat_stream.stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
self,
message_type: str,
content: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送指定类型的回复消息到当前聊天环境
@@ -134,7 +155,13 @@ class BaseCommand(ABC):
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送命令消息
@@ -177,7 +204,9 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_emoji(
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送表情包
Args:
@@ -191,9 +220,17 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message)
return await send_api.emoji_to_stream(
emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message
)
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool:
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
@@ -207,7 +244,13 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message,storage_message=storage_message)
return await send_api.image_to_stream(
image_base64,
chat_stream.stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
@classmethod
def get_command_info(cls) -> "CommandInfo":