把字典转换为数据模型并恢复全系统可用性,临时修复InstantMemory让大模型至少知道在聊什么

This commit is contained in:
UnCLAS-Prommer
2025-08-21 23:21:56 +08:00
parent f41a3076f6
commit c6f0c51825
12 changed files with 462 additions and 258 deletions

View File

@@ -3,13 +3,13 @@ import time
import traceback import traceback
import math import math
import random import random
from typing import List, Optional, Dict, Any, Tuple from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install from rich.traceback import install
from collections import deque from collections import deque
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
@@ -24,12 +24,15 @@ from src.chat.frequency_control.focus_value_control import focus_value_control
from src.chat.express.expression_learner import expression_learner_manager from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo
from src.plugin_system.core import events_manager from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager from src.mais4u.mai_think import mai_thinking_manager
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
ERROR_LOOP_INFO = { ERROR_LOOP_INFO = {
"loop_plan_info": { "loop_plan_info": {
@@ -141,7 +144,7 @@ class HeartFChatting:
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
def start_cycle(self): def start_cycle(self) -> Tuple[Dict[str, float], str]:
self._cycle_counter += 1 self._cycle_counter += 1
self._current_cycle_detail = CycleDetail(self._cycle_counter) self._current_cycle_detail = CycleDetail(self._cycle_counter)
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
@@ -172,7 +175,8 @@ class HeartFChatting:
action_type = action_result.get("action_type", "未知动作") action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result: elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表 # 新格式action_result是actions列表
action_type = action_result[0].get("action_type", "未知动作") # TODO: 把这里写明白
action_type = action_result[0].action_type or "未知动作"
elif isinstance(loop_plan_info, list) and loop_plan_info: elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况 # 直接是actions列表的情况
action_type = loop_plan_info[0].get("action_type", "未知动作") action_type = loop_plan_info[0].get("action_type", "未知动作")
@@ -207,7 +211,7 @@ class HeartFChatting:
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息") logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
self.focus_energy = 1 self.focus_energy = 1
async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]: async def _should_process_messages(self, new_message: List["DatabaseMessages"]) -> tuple[bool, float]:
""" """
判断是否应该处理消息 判断是否应该处理消息
@@ -290,11 +294,11 @@ class HeartFChatting:
async def _send_and_store_reply( async def _send_and_store_reply(
self, self,
response_set, response_set,
action_message, action_message: "DatabaseMessages",
cycle_timers: Dict[str, float], cycle_timers: Dict[str, float],
thinking_id, thinking_id,
actions, actions,
selected_expressions: List[int] = None, selected_expressions: Optional[List[int]] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]: ) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers): with Timer("回复发送", cycle_timers):
reply_text = await self._send_response( reply_text = await self._send_response(
@@ -304,11 +308,11 @@ class HeartFChatting:
) )
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 # 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.get("chat_info_platform") platform = action_message.chat_info.platform
if platform is None: if platform is None:
platform = getattr(self.chat_stream, "platform", "unknown") platform = getattr(self.chat_stream, "platform", "unknown")
person = Person(platform=platform, user_id=action_message.get("user_id", "")) person = Person(platform=platform, user_id=action_message.user_info.user_id)
person_name = person.person_name person_name = person.person_name
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@@ -354,7 +358,11 @@ class HeartFChatting:
x0 = 1.0 # 控制曲线中心点 x0 = 1.0 # 控制曲线中心点
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0))) return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency() normal_mode_probability = (
calculate_normal_mode_probability(interest_value)
* 2
* self.talk_frequency_control.get_current_talk_frequency()
)
# 根据概率决定使用哪种模式 # 根据概率决定使用哪种模式
if random.random() < normal_mode_probability: if random.random() < normal_mode_probability:
@@ -383,17 +391,17 @@ class HeartFChatting:
except Exception as e: except Exception as e:
logger.error(f"{self.log_prefix} 记忆构建失败: {e}") logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
available_actions: Dict[str, ActionInfo] = {}
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS: if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
# 如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考 # 如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考
actions = [ action_to_use_info = [
{ ActionPlannerInfo(
"action_type": "no_action", action_type="no_action",
"reasoning": "专注不足", reasoning="专注不足",
"action_data": {}, action_data={},
} )
] ]
else: else:
available_actions = {}
# 第一步:动作修改 # 第一步:动作修改
with Timer("动作修改", cycle_timers): with Timer("动作修改", cycle_timers):
try: try:
@@ -414,105 +422,19 @@ class HeartFChatting:
): ):
return False return False
with Timer("规划器", cycle_timers): with Timer("规划器", cycle_timers):
actions, _ = await self.action_planner.plan( action_to_use_info, _ = await self.action_planner.plan(
mode=mode, mode=mode,
loop_start_time=self.last_read_time, loop_start_time=self.last_read_time,
available_actions=available_actions, available_actions=available_actions,
) )
# 3. 并行执行所有动作 # 3. 并行执行所有动作
async def execute_action(action_info, actions): action_tasks = [
"""执行单个动作的通用函数""" asyncio.create_task(
try: self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
if action_info["action_type"] == "no_action": )
# 直接处理no_action逻辑不再通过动作系统 for action in action_to_use_info
reason = action_info.get("reasoning", "选择不回复") ]
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_info["action_type"] != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_info["action_type"],
action_info["reasoning"],
action_info["action_data"],
cycle_timers,
thinking_id,
action_info["action_message"],
)
return {
"action_type": action_info["action_type"],
"success": success,
"reply_text": reply_text,
"command": command,
}
else:
try:
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_info["action_message"],
available_actions=available_actions,
choosen_actions=actions,
reply_reason=action_info.get("reasoning", ""),
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
return_expressions=True,
)
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
_, selected_expressions = prompt_selected_expressions
else:
selected_expressions = []
if not success or not response_set:
logger.info(
f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
)
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
response_set=response_set,
action_message=action_info["action_message"],
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=actions,
selected_expressions=selected_expressions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_info["action_type"],
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e),
}
action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]
# 并行执行所有任务 # 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True) results = await asyncio.gather(*action_tasks, return_exceptions=True)
@@ -529,7 +451,7 @@ class HeartFChatting:
logger.error(f"{self.log_prefix} 动作执行异常: {result}") logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue continue
_cur_action = actions[i] _cur_action = action_to_use_info[i]
if result["action_type"] != "reply": if result["action_type"] != "reply":
action_success = result["success"] action_success = result["success"]
action_reply_text = result["reply_text"] action_reply_text = result["reply_text"]
@@ -558,7 +480,7 @@ class HeartFChatting:
# 没有回复信息构建纯动作的loop_info # 没有回复信息构建纯动作的loop_info
loop_info = { loop_info = {
"loop_plan_info": { "loop_plan_info": {
"action_result": actions, "action_result": action_to_use_info,
}, },
"loop_action_info": { "loop_action_info": {
"action_taken": action_success, "action_taken": action_success,
@@ -578,7 +500,7 @@ class HeartFChatting:
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", "")) # await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
action_type = actions[0]["action_type"] if actions else "no_action" action_type = action_to_use_info[0].action_type if action_to_use_info else "no_action"
# 管理no_action计数器当执行了非no_action动作时重置计数器 # 管理no_action计数器当执行了非no_action动作时重置计数器
if action_type != "no_action": if action_type != "no_action":
@@ -620,7 +542,7 @@ class HeartFChatting:
action_data: dict, action_data: dict,
cycle_timers: Dict[str, float], cycle_timers: Dict[str, float],
thinking_id: str, thinking_id: str,
action_message: dict, action_message: Optional["DatabaseMessages"] = None,
) -> tuple[bool, str, str]: ) -> tuple[bool, str, str]:
""" """
处理规划动作,使用动作工厂创建相应的动作处理器 处理规划动作,使用动作工厂创建相应的动作处理器
@@ -672,8 +594,8 @@ class HeartFChatting:
async def _send_response( async def _send_response(
self, self,
reply_set, reply_set,
message_data, message_data: "DatabaseMessages",
selected_expressions: List[int] = None, selected_expressions: Optional[List[int]] = None,
) -> str: ) -> str:
new_message_count = message_api.count_new_messages( new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
@@ -710,3 +632,97 @@ class HeartFChatting:
reply_text += data reply_text += data
return reply_text return reply_text
async def _execute_action(
self,
action_planner_info: ActionPlannerInfo,
chosen_action_plan_infos: List[ActionPlannerInfo],
thinking_id: str,
available_actions: Dict[str, ActionInfo],
cycle_timers: Dict[str, float],
):
"""执行单个动作的通用函数"""
try:
if action_planner_info.action_type == "no_action":
# 直接处理no_action逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
action_planner_info.reasoning or "",
action_planner_info.action_data or {},
cycle_timers,
thinking_id,
action_planner_info.action_message,
)
return {
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
}
else:
try:
success, response_set, prompt, selected_expressions = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=action_planner_info.reasoning or "",
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
return_expressions=True,
)
if not success or not response_set:
if action_planner_info.action_message:
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=action_planner_info.action_message, # type: ignore
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_planner_info.action_type,
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e),
}

View File

@@ -420,7 +420,7 @@ class MessageSending(MessageProcessBase):
thinking_start_time: float = 0, thinking_start_time: float = 0,
apply_set_reply_logic: bool = False, apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None, reply_to: Optional[str] = None,
selected_expressions:List[int] = None, selected_expressions: Optional[List[int]] = None,
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(

View File

@@ -2,6 +2,7 @@ from typing import Dict, Optional, Type
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.plugin_system.base.base_action import BaseAction from src.plugin_system.base.base_action import BaseAction
@@ -37,7 +38,7 @@ class ActionManager:
chat_stream: ChatStream, chat_stream: ChatStream,
log_prefix: str, log_prefix: str,
shutting_down: bool = False, shutting_down: bool = False,
action_message: Optional[dict] = None, action_message: Optional[DatabaseMessages] = None,
) -> Optional[BaseAction]: ) -> Optional[BaseAction]:
""" """
创建动作处理器实例 创建动作处理器实例
@@ -83,7 +84,7 @@ class ActionManager:
log_prefix=log_prefix, log_prefix=log_prefix,
shutting_down=shutting_down, shutting_down=shutting_down,
plugin_config=plugin_config, plugin_config=plugin_config,
action_message=action_message, action_message=action_message.flatten() if action_message else None,
) )
logger.debug(f"创建Action实例成功: {action_name}") logger.debug(f"创建Action实例成功: {action_name}")

View File

@@ -1,7 +1,7 @@
import json import json
import time import time
import traceback import traceback
from typing import Dict, Any, Optional, Tuple, List from typing import Dict, Optional, Tuple, List
from rich.traceback import install from rich.traceback import install
from datetime import datetime from datetime import datetime
from json_repair import repair_json from json_repair import repair_json
@@ -9,6 +9,8 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_actions, build_readable_actions,
@@ -97,7 +99,9 @@ class ActionPlanner:
self.plan_retry_count = 0 self.plan_retry_count = 0
self.max_plan_retries = 3 self.max_plan_retries = 3
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]: def find_message_by_id(
self, message_id: str, message_id_list: List[DatabaseMessages]
) -> Optional[DatabaseMessages]:
# sourcery skip: use-next # sourcery skip: use-next
""" """
根据message_id从message_id_list中查找对应的原始消息 根据message_id从message_id_list中查找对应的原始消息
@@ -110,11 +114,11 @@ class ActionPlanner:
找到的原始消息字典如果未找到则返回None 找到的原始消息字典如果未找到则返回None
""" """
for item in message_id_list: for item in message_id_list:
if item.get("id") == message_id: if item.message_id == message_id:
return item.get("message") return item
return None return None
def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]: def get_latest_message(self, message_id_list: List[DatabaseMessages]) -> Optional[DatabaseMessages]:
""" """
获取消息列表中的最新消息 获取消息列表中的最新消息
@@ -124,23 +128,23 @@ class ActionPlanner:
Returns: Returns:
最新的消息字典如果列表为空则返回None 最新的消息字典如果列表为空则返回None
""" """
return message_id_list[-1].get("message") if message_id_list else None return message_id_list[-1] if message_id_list else None
async def plan( async def plan(
self, self,
mode: ChatMode = ChatMode.FOCUS, mode: ChatMode = ChatMode.FOCUS,
loop_start_time:float = 0.0, loop_start_time: float = 0.0,
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]: ) -> Tuple[List[ActionPlannerInfo], Optional[DatabaseMessages]]:
""" """
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
""" """
action = "no_action" # 默认动作 action: str = "no_action" # 默认动作
reasoning = "规划器初始化默认" reasoning: str = "规划器初始化默认"
action_data = {} action_data = {}
current_available_actions: Dict[str, ActionInfo] = {} current_available_actions: Dict[str, ActionInfo] = {}
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量 target_message: Optional[DatabaseMessages] = None # 初始化target_message变量
prompt: str = "" prompt: str = ""
message_id_list: list = [] message_id_list: list = []
@@ -208,19 +212,21 @@ class ActionPlanner:
# 如果获取的target_message为None输出warning并重新plan # 如果获取的target_message为None输出warning并重新plan
if target_message is None: if target_message is None:
self.plan_retry_count += 1 self.plan_retry_count += 1
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}") logger.warning(
f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}"
)
# 仍有重试次数 # 仍有重试次数
if self.plan_retry_count < self.max_plan_retries: if self.plan_retry_count < self.max_plan_retries:
# 递归重新plan # 递归重新plan
return await self.plan(mode, loop_start_time, available_actions) return await self.plan(mode, loop_start_time, available_actions)
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message") logger.error(
f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message"
)
target_message = self.get_latest_message(message_id_list) target_message = self.get_latest_message(message_id_list)
self.plan_retry_count = 0 # 重置计数器 self.plan_retry_count = 0 # 重置计数器
else: else:
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id") logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
if action != "no_action" and action != "reply" and action not in current_available_actions: if action != "no_action" and action != "reply" and action not in current_available_actions:
logger.warning( logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'" f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'"
@@ -244,38 +250,37 @@ class ActionPlanner:
if mode == ChatMode.NORMAL and action in current_available_actions: if mode == ChatMode.NORMAL and action in current_available_actions:
is_parallel = current_available_actions[action].parallel_action is_parallel = current_available_actions[action].parallel_action
action_data["loop_start_time"] = loop_start_time action_data["loop_start_time"] = loop_start_time
actions = [ actions = [
{ ActionPlannerInfo(
"action_type": action, action_type=action,
"reasoning": reasoning, reasoning=reasoning,
"action_data": action_data, action_data=action_data,
"action_message": target_message, action_message=target_message,
"available_actions": available_actions, available_actions=available_actions,
} )
] ]
if action != "reply" and is_parallel: if action != "reply" and is_parallel:
actions.append({ actions.append(
"action_type": "reply", ActionPlannerInfo(
"action_message": target_message, action_type="reply",
"available_actions": available_actions action_message=target_message,
}) available_actions=available_actions,
)
return actions,target_message )
return actions, target_message
async def build_planner_prompt( async def build_planner_prompt(
self, self,
is_group_chat: bool, # Now passed as argument is_group_chat: bool, # Now passed as argument
chat_target_info: Optional[dict], # Now passed as argument chat_target_info: Optional[dict], # Now passed as argument
current_available_actions: Dict[str, ActionInfo], current_available_actions: Dict[str, ActionInfo],
refresh_time :bool = False, refresh_time: bool = False,
mode: ChatMode = ChatMode.FOCUS, mode: ChatMode = ChatMode.FOCUS,
) -> tuple[str, list]: # sourcery skip: use-join ) -> tuple[str, List[DatabaseMessages]]: # sourcery skip: use-join
"""构建 Planner LLM 的提示词 (获取模板并填充数据)""" """构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try: try:
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@@ -312,7 +317,6 @@ class ActionPlanner:
if global_config.chat.at_bot_inevitable_reply: if global_config.chat.at_bot_inevitable_reply:
mentioned_bonus = "\n- 有人提到你或者at你" mentioned_bonus = "\n- 有人提到你或者at你"
if mode == ChatMode.FOCUS: if mode == ChatMode.FOCUS:
no_action_block = """ no_action_block = """
动作no_action 动作no_action
@@ -388,7 +392,7 @@ class ActionPlanner:
action_options_text=action_options_block, action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block, moderation_prompt=moderation_prompt_block,
identity_block=identity_block, identity_block=identity_block,
plan_style = global_config.personality.plan_style plan_style=global_config.personality.plan_style,
) )
return prompt, message_id_list return prompt, message_id_list
except Exception as e: except Exception as e:

View File

@@ -9,6 +9,7 @@ from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@@ -157,12 +158,12 @@ class DefaultReplyer:
extra_info: str = "", extra_info: str = "",
reply_reason: str = "", reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None, chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True, enable_tool: bool = True,
from_plugin: bool = True, from_plugin: bool = True,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
reply_message: Optional[Dict[str, Any]] = None, reply_message: Optional[DatabaseMessages] = None,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]: ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[List[int]]]:
# sourcery skip: merge-nested-ifs # sourcery skip: merge-nested-ifs
""" """
回复器 (Replier): 负责生成回复文本的核心逻辑。 回复器 (Replier): 负责生成回复文本的核心逻辑。
@@ -181,7 +182,7 @@ class DefaultReplyer:
""" """
prompt = None prompt = None
selected_expressions = None selected_expressions: Optional[List[int]] = None
if available_actions is None: if available_actions is None:
available_actions = {} available_actions = {}
try: try:
@@ -374,7 +375,12 @@ class DefaultReplyer:
) )
if global_config.memory.enable_instant_memory: if global_config.memory.enable_instant_memory:
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history)) chat_history_str = build_readable_messages(
messages=chat_history,
replace_bot_name=True,
timestamp_mode="normal"
)
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history_str))
instant_memory = await self.instant_memory.get_memory(target) instant_memory = await self.instant_memory.get_memory(target)
logger.info(f"即时记忆:{instant_memory}") logger.info(f"即时记忆:{instant_memory}")
@@ -527,7 +533,7 @@ class DefaultReplyer:
Returns: Returns:
Tuple[str, str]: (核心对话prompt, 背景对话prompt) Tuple[str, str]: (核心对话prompt, 背景对话prompt)
""" """
core_dialogue_list = [] core_dialogue_list: List[DatabaseMessages] = []
bot_id = str(global_config.bot.qq_account) bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话 # 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
@@ -559,7 +565,7 @@ class DefaultReplyer:
if core_dialogue_list: if core_dialogue_list:
# 检查最新五条消息中是否包含bot自己说的消息 # 检查最新五条消息中是否包含bot自己说的消息
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages) has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages)
# logger.info(f"最新五条消息:{latest_5_messages}") # logger.info(f"最新五条消息:{latest_5_messages}")
# logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}") # logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}")
@@ -634,7 +640,7 @@ class DefaultReplyer:
return mai_think return mai_think
async def build_actions_prompt( async def build_actions_prompt(
self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
) -> str: ) -> str:
"""构建动作提示""" """构建动作提示"""
@@ -646,20 +652,21 @@ class DefaultReplyer:
action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n" action_descriptions += "\n"
choosen_action_descriptions = "" chosen_action_descriptions = ""
if choosen_actions: if chosen_actions_info:
for action in choosen_actions: for action_plan_info in chosen_actions_info:
action_name = action.get("action_type", "unknown_action") action_name = action_plan_info.action_type
if action_name == "reply": if action_name == "reply":
continue continue
action_description = action.get("reason", "无描述") if action := available_actions.get(action_name):
reasoning = action.get("reasoning", "原因") action_description = action.description or "描述"
reasoning = action_plan_info.reasoning or "无原因"
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n" chosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
if choosen_action_descriptions: if chosen_action_descriptions:
action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n" action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n"
action_descriptions += choosen_action_descriptions action_descriptions += chosen_action_descriptions
return action_descriptions return action_descriptions
@@ -668,9 +675,9 @@ class DefaultReplyer:
extra_info: str = "", extra_info: str = "",
reply_reason: str = "", reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None, chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True, enable_tool: bool = True,
reply_message: Optional[Dict[str, Any]] = None, reply_message: Optional[DatabaseMessages] = None,
) -> Tuple[str, List[int]]: ) -> Tuple[str, List[int]]:
""" """
构建回复器上下文 构建回复器上下文
@@ -694,11 +701,11 @@ class DefaultReplyer:
platform = chat_stream.platform platform = chat_stream.platform
if reply_message: if reply_message:
user_id = reply_message.get("user_id", "") user_id = reply_message.user_info.user_id
person = Person(platform=platform, user_id=user_id) person = Person(platform=platform, user_id=user_id)
person_name = person.person_name or user_id person_name = person.person_name or user_id
sender = person_name sender = person_name
target = reply_message.get("processed_plain_text") target = reply_message.processed_plain_text
else: else:
person_name = "用户" person_name = "用户"
sender = "用户" sender = "用户"
@@ -774,11 +781,13 @@ class DefaultReplyer:
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s") logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s")
expression_habits_block, selected_expressions = results_dict["expression_habits"] expression_habits_block, selected_expressions = results_dict["expression_habits"]
relation_info = results_dict["relation_info"] expression_habits_block: str
memory_block = results_dict["memory_block"] selected_expressions: List[int]
tool_info = results_dict["tool_info"] relation_info: str = results_dict["relation_info"]
prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果 memory_block: str = results_dict["memory_block"]
actions_info = results_dict["actions_info"] tool_info: str = results_dict["tool_info"]
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target) keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
if extra_info: if extra_info:

View File

@@ -11,7 +11,6 @@ from collections import Counter
from typing import Optional, Tuple, Dict, List, Any from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
@@ -641,6 +640,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
platform: str = chat_stream.platform platform: str = chat_stream.platform
user_id: str = user_info.user_id # type: ignore user_id: str = user_info.user_id # type: ignore
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
# Initialize target_info with basic info # Initialize target_info with basic info
target_info = TargetPersonInfo( target_info = TargetPersonInfo(
platform=platform, platform=platform,

View File

@@ -1,4 +1,4 @@
from typing import Optional, Any from typing import Optional, Any, Dict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from . import BaseDataModel from . import BaseDataModel
@@ -157,3 +157,42 @@ class DatabaseMessages(BaseDataModel):
# assert isinstance(self.interest_value, float) or self.interest_value is None, ( # assert isinstance(self.interest_value, float) or self.interest_value is None, (
# "interest_value must be a float or None" # "interest_value must be a float or None"
# ) # )
def flatten(self) -> Dict[str, Any]:
"""
将消息数据模型转换为字典格式,便于存储或传输
"""
return {
"message_id": self.message_id,
"time": self.time,
"chat_id": self.chat_id,
"reply_to": self.reply_to,
"interest_value": self.interest_value,
"key_words": self.key_words,
"key_words_lite": self.key_words_lite,
"is_mentioned": self.is_mentioned,
"processed_plain_text": self.processed_plain_text,
"display_message": self.display_message,
"priority_mode": self.priority_mode,
"priority_info": self.priority_info,
"additional_config": self.additional_config,
"is_emoji": self.is_emoji,
"is_picid": self.is_picid,
"is_command": self.is_command,
"is_notify": self.is_notify,
"selected_expressions": self.selected_expressions,
"user_id": self.user_info.user_id,
"user_nickname": self.user_info.user_nickname,
"user_cardname": self.user_info.user_cardname,
"user_platform": self.user_info.platform,
"chat_info_group_id": self.group_info.group_id if self.group_info else None,
"chat_info_group_name": self.group_info.group_name if self.group_info else None,
"chat_info_group_platform": self.group_info.group_platform if self.group_info else None,
"chat_info_stream_id": self.chat_info.stream_id,
"chat_info_platform": self.chat_info.platform,
"chat_info_create_time": self.chat_info.create_time,
"chat_info_last_active_time": self.chat_info.last_active_time,
"chat_info_user_platform": self.chat_info.user_info.platform,
"chat_info_user_id": self.chat_info.user_info.user_id,
"chat_info_user_nickname": self.chat_info.user_info.user_nickname,
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}

View File

@@ -1,8 +1,12 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Optional, Dict, TYPE_CHECKING
from . import BaseDataModel from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo
@dataclass @dataclass
class TargetPersonInfo(BaseDataModel): class TargetPersonInfo(BaseDataModel):
platform: str = field(default_factory=str) platform: str = field(default_factory=str)
@@ -10,3 +14,12 @@ class TargetPersonInfo(BaseDataModel):
user_nickname: str = field(default_factory=str) user_nickname: str = field(default_factory=str)
person_id: Optional[str] = None person_id: Optional[str] = None
person_name: Optional[str] = None person_name: Optional[str] = None
@dataclass
class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None

View File

@@ -9,7 +9,7 @@
""" """
import traceback import traceback
from typing import Tuple, Any, Dict, List, Optional from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
@@ -18,6 +18,10 @@ from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo from src.plugin_system.base.component_types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.database_data_model import DatabaseMessages
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("generator_api") logger = get_logger("generator_api")
@@ -73,11 +77,11 @@ async def generate_reply(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None, action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional[Dict[str, Any]] = None, reply_message: Optional["DatabaseMessages"] = None,
extra_info: str = "", extra_info: str = "",
reply_reason: str = "", reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None, chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
enable_tool: bool = False, enable_tool: bool = False,
enable_splitter: bool = True, enable_splitter: bool = True,
enable_chinese_typo: bool = True, enable_chinese_typo: bool = True,
@@ -85,7 +89,7 @@ async def generate_reply(
request_type: str = "generator_api", request_type: str = "generator_api",
from_plugin: bool = True, from_plugin: bool = True,
return_expressions: bool = False, return_expressions: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]: ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
"""生成回复 """生成回复
Args: Args:
@@ -96,7 +100,7 @@ async def generate_reply(
extra_info: 额外信息,用于补充上下文 extra_info: 额外信息,用于补充上下文
reply_reason: 回复原因 reply_reason: 回复原因
available_actions: 可用动作 available_actions: 可用动作
choosen_actions: 已选动作 chosen_actions: 已选动作
enable_tool: 是否启用工具调用 enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器 enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器 enable_chinese_typo: 是否启用错字生成器
@@ -110,12 +114,10 @@ async def generate_reply(
try: try:
# 获取回复器 # 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复") logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer( replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
chat_stream, chat_id, request_type=request_type
)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None return False, [], None, None
if not extra_info and action_data: if not extra_info and action_data:
extra_info = action_data.get("extra_info", "") extra_info = action_data.get("extra_info", "")
@@ -127,7 +129,7 @@ async def generate_reply(
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context( success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
extra_info=extra_info, extra_info=extra_info,
available_actions=available_actions, available_actions=available_actions,
chosen_actions=choosen_actions, chosen_actions=chosen_actions,
enable_tool=enable_tool, enable_tool=enable_tool,
reply_message=reply_message, reply_message=reply_message,
reply_reason=reply_reason, reply_reason=reply_reason,
@@ -136,7 +138,7 @@ async def generate_reply(
) )
if not success: if not success:
logger.warning("[GeneratorAPI] 回复生成失败") logger.warning("[GeneratorAPI] 回复生成失败")
return False, [], None return False, [], None, None
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况 assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
if content := llm_response_dict.get("content", ""): if content := llm_response_dict.get("content", ""):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
@@ -144,16 +146,22 @@ async def generate_reply(
reply_set = [] reply_set = []
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项") logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
if return_prompt: # if return_prompt:
if return_expressions: # if return_expressions:
return success, reply_set, (prompt, selected_expressions) # return success, reply_set, prompt, selected_expressions
else: # else:
return success, reply_set, prompt # return success, reply_set, prompt, None
else: # else:
if return_expressions: # if return_expressions:
return success, reply_set, (None, selected_expressions) # return success, reply_set, (None, selected_expressions)
else: # else:
return success, reply_set, None # return success, reply_set, None
return (
success,
reply_set,
prompt if return_prompt else None,
selected_expressions if return_expressions else None,
)
except ValueError as ve: except ValueError as ve:
raise ve raise ve

View File

@@ -21,15 +21,17 @@
import traceback import traceback
import time import time
from typing import Optional, Union, Dict, Any, List from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING
from src.common.logger import get_logger
# 导入依赖 from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv from src.chat.message_receive.message import MessageSending, MessageRecv
from maim_message import Seg, UserInfo from maim_message import Seg, UserInfo
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("send_api") logger = get_logger("send_api")
@@ -46,10 +48,10 @@ async def _send_to_target(
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional[Dict[str, Any]] = None, reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
selected_expressions:List[int] = None, selected_expressions: Optional[List[int]] = None,
) -> bool: ) -> bool:
"""向指定目标发送消息的内部实现 """向指定目标发送消息的内部实现
@@ -98,7 +100,7 @@ async def _send_to_target(
message_segment = Seg(type=message_type, data=content) # type: ignore message_segment = Seg(type=message_type, data=content) # type: ignore
if reply_message: if reply_message:
anchor_message = message_dict_to_message_recv(reply_message) anchor_message = message_dict_to_message_recv(reply_message.flatten())
if anchor_message: if anchor_message:
anchor_message.update_chat_stream(target_stream) anchor_message.update_chat_stream(target_stream)
assert anchor_message.message_info.user_info, "用户信息缺失" assert anchor_message.message_info.user_info, "用户信息缺失"
@@ -197,7 +199,6 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
return message_recv return message_recv
# ============================================================================= # =============================================================================
# 公共API函数 - 预定义类型的发送函数 # 公共API函数 - 预定义类型的发送函数
# ============================================================================= # =============================================================================
@@ -208,9 +209,9 @@ async def text_to_stream(
stream_id: str, stream_id: str,
typing: bool = False, typing: bool = False,
set_reply: bool = False, set_reply: bool = False,
reply_message: Optional[Dict[str, Any]] = None, reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True, storage_message: bool = True,
selected_expressions:List[int] = None, selected_expressions: Optional[List[int]] = None,
) -> bool: ) -> bool:
"""向指定流发送文本消息 """向指定流发送文本消息
@@ -237,7 +238,13 @@ async def text_to_stream(
) )
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: async def emoji_to_stream(
emoji_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送表情包 """向指定流发送表情包
Args: Args:
@@ -248,10 +255,25 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message) return await _send_to_target(
"emoji",
emoji_base64,
stream_id,
"",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: async def image_to_stream(
image_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送图片 """向指定流发送图片
Args: Args:
@@ -262,11 +284,25 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message) return await _send_to_target(
"image",
image_base64,
stream_id,
"",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def command_to_stream( async def command_to_stream(
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None command: Union[str, dict],
stream_id: str,
storage_message: bool = True,
display_message: str = "",
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool: ) -> bool:
"""向指定流发送命令 """向指定流发送命令
@@ -279,7 +315,14 @@ async def command_to_stream(
bool: 是否发送成功 bool: 是否发送成功
""" """
return await _send_to_target( return await _send_to_target(
"command", command, stream_id, display_message, typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message "command",
command,
stream_id,
display_message,
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
) )
@@ -289,7 +332,7 @@ async def custom_to_stream(
stream_id: str, stream_id: str,
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_message: Optional[Dict[str, Any]] = None, reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False, set_reply: bool = False,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,

View File

@@ -2,13 +2,15 @@ import time
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any from typing import Tuple, Optional, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api from src.plugin_system.apis import send_api, database_api, message_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_action") logger = get_logger("base_action")
@@ -206,7 +208,11 @@ class BaseAction(ABC):
return False, f"等待新消息失败: {str(e)}" return False, f"等待新消息失败: {str(e)}"
async def send_text( async def send_text(
self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
) -> bool: ) -> bool:
"""发送文本消息 """发送文本消息
@@ -229,7 +235,9 @@ class BaseAction(ABC):
typing=typing, typing=typing,
) )
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: async def send_emoji(
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送表情包 """发送表情包
Args: Args:
@@ -242,9 +250,13 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 缺少聊天ID") logger.error(f"{self.log_prefix} 缺少聊天ID")
return False return False
return await send_api.emoji_to_stream(emoji_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message) return await send_api.emoji_to_stream(
emoji_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
)
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: async def send_image(
self, image_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送图片 """发送图片
Args: Args:
@@ -257,9 +269,18 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 缺少聊天ID") logger.error(f"{self.log_prefix} 缺少聊天ID")
return False return False
return await send_api.image_to_stream(image_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message) return await send_api.image_to_stream(
image_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
)
async def send_custom(self, message_type: str, content: str, typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: async def send_custom(
self,
message_type: str,
content: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送自定义类型消息 """发送自定义类型消息
Args: Args:
@@ -308,7 +329,13 @@ class BaseAction(ABC):
) )
async def send_command( async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool: ) -> bool:
"""发送命令消息 """发送命令消息

View File

@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, Any from typing import Dict, Tuple, Optional, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import CommandInfo, ComponentType from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_command") logger = get_logger("base_command")
@@ -84,7 +87,13 @@ class BaseCommand(ABC):
return current return current
async def send_text(self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool: async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送回复消息 """发送回复消息
Args: Args:
@@ -100,10 +109,22 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, set_reply=set_reply,reply_message=reply_message,storage_message=storage_message) return await send_api.text_to_stream(
text=content,
stream_id=chat_stream.stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_type( async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None self,
message_type: str,
content: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool: ) -> bool:
"""发送指定类型的回复消息到当前聊天环境 """发送指定类型的回复消息到当前聊天环境
@@ -134,7 +155,13 @@ class BaseCommand(ABC):
) )
async def send_command( async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool: ) -> bool:
"""发送命令消息 """发送命令消息
@@ -177,7 +204,9 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 发送命令时出错: {e}") logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False return False
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool: async def send_emoji(
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送表情包 """发送表情包
Args: Args:
@@ -191,9 +220,17 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message) return await send_api.emoji_to_stream(
emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message
)
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool: async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片 """发送图片
Args: Args:
@@ -207,7 +244,13 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message,storage_message=storage_message) return await send_api.image_to_stream(
image_base64,
chat_stream.stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
@classmethod @classmethod
def get_command_info(cls) -> "CommandInfo": def get_command_info(cls) -> "CommandInfo":