Merge branch 'MaiM-with-u:dev' into dev
This commit is contained in:
@@ -3,13 +3,13 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Dict, Any, Tuple
|
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
@@ -24,12 +24,15 @@ from src.chat.frequency_control.focus_value_control import focus_value_control
|
|||||||
from src.chat.express.expression_learner import expression_learner_manager
|
from src.chat.express.expression_learner import expression_learner_manager
|
||||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
from src.plugin_system.base.component_types import ChatMode, EventType
|
from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo
|
||||||
from src.plugin_system.core import events_manager
|
from src.plugin_system.core import events_manager
|
||||||
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
|
||||||
from src.mais4u.mai_think import mai_thinking_manager
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
from src.mais4u.s4u_config import s4u_config
|
from src.mais4u.s4u_config import s4u_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
|
||||||
ERROR_LOOP_INFO = {
|
ERROR_LOOP_INFO = {
|
||||||
"loop_plan_info": {
|
"loop_plan_info": {
|
||||||
@@ -100,7 +103,7 @@ class HeartFChatting:
|
|||||||
self.reply_timeout_count = 0
|
self.reply_timeout_count = 0
|
||||||
self.plan_timeout_count = 0
|
self.plan_timeout_count = 0
|
||||||
|
|
||||||
self.last_read_time = time.time() - 1
|
self.last_read_time = time.time() - 10
|
||||||
|
|
||||||
self.focus_energy = 1
|
self.focus_energy = 1
|
||||||
self.no_action_consecutive = 0
|
self.no_action_consecutive = 0
|
||||||
@@ -141,7 +144,7 @@ class HeartFChatting:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
|
||||||
|
|
||||||
def start_cycle(self):
|
def start_cycle(self) -> Tuple[Dict[str, float], str]:
|
||||||
self._cycle_counter += 1
|
self._cycle_counter += 1
|
||||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||||
@@ -172,7 +175,8 @@ class HeartFChatting:
|
|||||||
action_type = action_result.get("action_type", "未知动作")
|
action_type = action_result.get("action_type", "未知动作")
|
||||||
elif isinstance(action_result, list) and action_result:
|
elif isinstance(action_result, list) and action_result:
|
||||||
# 新格式:action_result是actions列表
|
# 新格式:action_result是actions列表
|
||||||
action_type = action_result[0].get("action_type", "未知动作")
|
# TODO: 把这里写明白
|
||||||
|
action_type = action_result[0].action_type or "未知动作"
|
||||||
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
elif isinstance(loop_plan_info, list) and loop_plan_info:
|
||||||
# 直接是actions列表的情况
|
# 直接是actions列表的情况
|
||||||
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
action_type = loop_plan_info[0].get("action_type", "未知动作")
|
||||||
@@ -207,7 +211,7 @@ class HeartFChatting:
|
|||||||
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
|
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
|
||||||
self.focus_energy = 1
|
self.focus_energy = 1
|
||||||
|
|
||||||
async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]:
|
async def _should_process_messages(self, new_message: List["DatabaseMessages"]) -> tuple[bool, float]:
|
||||||
"""
|
"""
|
||||||
判断是否应该处理消息
|
判断是否应该处理消息
|
||||||
|
|
||||||
@@ -265,7 +269,7 @@ class HeartFChatting:
|
|||||||
return False, 0.0
|
return False, 0.0
|
||||||
|
|
||||||
async def _loopbody(self):
|
async def _loopbody(self):
|
||||||
recent_messages_dict = message_api.get_messages_by_time_in_chat(
|
recent_messages_list = message_api.get_messages_by_time_in_chat(
|
||||||
chat_id=self.stream_id,
|
chat_id=self.stream_id,
|
||||||
start_time=self.last_read_time,
|
start_time=self.last_read_time,
|
||||||
end_time=time.time(),
|
end_time=time.time(),
|
||||||
@@ -275,7 +279,7 @@ class HeartFChatting:
|
|||||||
filter_command=True,
|
filter_command=True,
|
||||||
)
|
)
|
||||||
# 统一的消息处理逻辑
|
# 统一的消息处理逻辑
|
||||||
should_process, interest_value = await self._should_process_messages(recent_messages_dict)
|
should_process, interest_value = await self._should_process_messages(recent_messages_list)
|
||||||
|
|
||||||
if should_process:
|
if should_process:
|
||||||
self.last_read_time = time.time()
|
self.last_read_time = time.time()
|
||||||
@@ -290,11 +294,11 @@ class HeartFChatting:
|
|||||||
async def _send_and_store_reply(
|
async def _send_and_store_reply(
|
||||||
self,
|
self,
|
||||||
response_set,
|
response_set,
|
||||||
action_message,
|
action_message: "DatabaseMessages",
|
||||||
cycle_timers: Dict[str, float],
|
cycle_timers: Dict[str, float],
|
||||||
thinking_id,
|
thinking_id,
|
||||||
actions,
|
actions,
|
||||||
selected_expressions: List[int] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
|
||||||
with Timer("回复发送", cycle_timers):
|
with Timer("回复发送", cycle_timers):
|
||||||
reply_text = await self._send_response(
|
reply_text = await self._send_response(
|
||||||
@@ -304,11 +308,11 @@ class HeartFChatting:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
# 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
|
||||||
platform = action_message.get("chat_info_platform")
|
platform = action_message.chat_info.platform
|
||||||
if platform is None:
|
if platform is None:
|
||||||
platform = getattr(self.chat_stream, "platform", "unknown")
|
platform = getattr(self.chat_stream, "platform", "unknown")
|
||||||
|
|
||||||
person = Person(platform=platform, user_id=action_message.get("user_id", ""))
|
person = Person(platform=platform, user_id=action_message.user_info.user_id)
|
||||||
person_name = person.person_name
|
person_name = person.person_name
|
||||||
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
|
||||||
|
|
||||||
@@ -353,9 +357,13 @@ class HeartFChatting:
|
|||||||
k = 2.0 # 控制曲线陡峭程度
|
k = 2.0 # 控制曲线陡峭程度
|
||||||
x0 = 1.0 # 控制曲线中心点
|
x0 = 1.0 # 控制曲线中心点
|
||||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||||
|
|
||||||
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency()
|
normal_mode_probability = (
|
||||||
|
calculate_normal_mode_probability(interest_value)
|
||||||
|
* 2
|
||||||
|
* self.talk_frequency_control.get_current_talk_frequency()
|
||||||
|
)
|
||||||
|
|
||||||
# 根据概率决定使用哪种模式
|
# 根据概率决定使用哪种模式
|
||||||
if random.random() < normal_mode_probability:
|
if random.random() < normal_mode_probability:
|
||||||
mode = ChatMode.NORMAL
|
mode = ChatMode.NORMAL
|
||||||
@@ -383,17 +391,17 @@ class HeartFChatting:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
|
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
|
||||||
|
|
||||||
|
available_actions: Dict[str, ActionInfo] = {}
|
||||||
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
|
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
|
||||||
# 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
|
# 如果激活度没有激活,并且聊天活跃度低,有可能不进行plan,相当于不在电脑前,不进行认真思考
|
||||||
actions = [
|
action_to_use_info = [
|
||||||
{
|
ActionPlannerInfo(
|
||||||
"action_type": "no_action",
|
action_type="no_action",
|
||||||
"reasoning": "专注不足",
|
reasoning="专注不足",
|
||||||
"action_data": {},
|
action_data={},
|
||||||
}
|
)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
available_actions = {}
|
|
||||||
# 第一步:动作修改
|
# 第一步:动作修改
|
||||||
with Timer("动作修改", cycle_timers):
|
with Timer("动作修改", cycle_timers):
|
||||||
try:
|
try:
|
||||||
@@ -414,105 +422,19 @@ class HeartFChatting:
|
|||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
with Timer("规划器", cycle_timers):
|
with Timer("规划器", cycle_timers):
|
||||||
actions, _ = await self.action_planner.plan(
|
action_to_use_info, _ = await self.action_planner.plan(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
loop_start_time=self.last_read_time,
|
loop_start_time=self.last_read_time,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 并行执行所有动作
|
# 3. 并行执行所有动作
|
||||||
async def execute_action(action_info, actions):
|
action_tasks = [
|
||||||
"""执行单个动作的通用函数"""
|
asyncio.create_task(
|
||||||
try:
|
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
|
||||||
if action_info["action_type"] == "no_action":
|
)
|
||||||
# 直接处理no_action逻辑,不再通过动作系统
|
for action in action_to_use_info
|
||||||
reason = action_info.get("reasoning", "选择不回复")
|
]
|
||||||
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
|
||||||
|
|
||||||
# 存储no_action信息到数据库
|
|
||||||
await database_api.store_action_info(
|
|
||||||
chat_stream=self.chat_stream,
|
|
||||||
action_build_into_prompt=False,
|
|
||||||
action_prompt_display=reason,
|
|
||||||
action_done=True,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
action_data={"reason": reason},
|
|
||||||
action_name="no_action",
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
|
||||||
elif action_info["action_type"] != "reply":
|
|
||||||
# 执行普通动作
|
|
||||||
with Timer("动作执行", cycle_timers):
|
|
||||||
success, reply_text, command = await self._handle_action(
|
|
||||||
action_info["action_type"],
|
|
||||||
action_info["reasoning"],
|
|
||||||
action_info["action_data"],
|
|
||||||
cycle_timers,
|
|
||||||
thinking_id,
|
|
||||||
action_info["action_message"],
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"action_type": action_info["action_type"],
|
|
||||||
"success": success,
|
|
||||||
"reply_text": reply_text,
|
|
||||||
"command": command,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
|
|
||||||
chat_stream=self.chat_stream,
|
|
||||||
reply_message=action_info["action_message"],
|
|
||||||
available_actions=available_actions,
|
|
||||||
choosen_actions=actions,
|
|
||||||
reply_reason=action_info.get("reasoning", ""),
|
|
||||||
enable_tool=global_config.tool.enable_tool,
|
|
||||||
request_type="replyer",
|
|
||||||
from_plugin=False,
|
|
||||||
return_expressions=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
|
|
||||||
_, selected_expressions = prompt_selected_expressions
|
|
||||||
else:
|
|
||||||
selected_expressions = []
|
|
||||||
|
|
||||||
if not success or not response_set:
|
|
||||||
logger.info(
|
|
||||||
f"对 {action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
|
|
||||||
)
|
|
||||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
|
||||||
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
|
||||||
|
|
||||||
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
|
|
||||||
response_set=response_set,
|
|
||||||
action_message=action_info["action_message"],
|
|
||||||
cycle_timers=cycle_timers,
|
|
||||||
thinking_id=thinking_id,
|
|
||||||
actions=actions,
|
|
||||||
selected_expressions=selected_expressions,
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"action_type": "reply",
|
|
||||||
"success": True,
|
|
||||||
"reply_text": reply_text,
|
|
||||||
"loop_info": loop_info,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
|
||||||
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
|
||||||
return {
|
|
||||||
"action_type": action_info["action_type"],
|
|
||||||
"success": False,
|
|
||||||
"reply_text": "",
|
|
||||||
"loop_info": None,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
|
|
||||||
action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]
|
|
||||||
|
|
||||||
# 并行执行所有任务
|
# 并行执行所有任务
|
||||||
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
results = await asyncio.gather(*action_tasks, return_exceptions=True)
|
||||||
@@ -529,7 +451,7 @@ class HeartFChatting:
|
|||||||
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_cur_action = actions[i]
|
_cur_action = action_to_use_info[i]
|
||||||
if result["action_type"] != "reply":
|
if result["action_type"] != "reply":
|
||||||
action_success = result["success"]
|
action_success = result["success"]
|
||||||
action_reply_text = result["reply_text"]
|
action_reply_text = result["reply_text"]
|
||||||
@@ -558,7 +480,7 @@ class HeartFChatting:
|
|||||||
# 没有回复信息,构建纯动作的loop_info
|
# 没有回复信息,构建纯动作的loop_info
|
||||||
loop_info = {
|
loop_info = {
|
||||||
"loop_plan_info": {
|
"loop_plan_info": {
|
||||||
"action_result": actions,
|
"action_result": action_to_use_info,
|
||||||
},
|
},
|
||||||
"loop_action_info": {
|
"loop_action_info": {
|
||||||
"action_taken": action_success,
|
"action_taken": action_success,
|
||||||
@@ -578,7 +500,7 @@ class HeartFChatting:
|
|||||||
|
|
||||||
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
||||||
|
|
||||||
action_type = actions[0]["action_type"] if actions else "no_action"
|
action_type = action_to_use_info[0].action_type if action_to_use_info else "no_action"
|
||||||
|
|
||||||
# 管理no_action计数器:当执行了非no_action动作时,重置计数器
|
# 管理no_action计数器:当执行了非no_action动作时,重置计数器
|
||||||
if action_type != "no_action":
|
if action_type != "no_action":
|
||||||
@@ -620,7 +542,7 @@ class HeartFChatting:
|
|||||||
action_data: dict,
|
action_data: dict,
|
||||||
cycle_timers: Dict[str, float],
|
cycle_timers: Dict[str, float],
|
||||||
thinking_id: str,
|
thinking_id: str,
|
||||||
action_message: dict,
|
action_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> tuple[bool, str, str]:
|
) -> tuple[bool, str, str]:
|
||||||
"""
|
"""
|
||||||
处理规划动作,使用动作工厂创建相应的动作处理器
|
处理规划动作,使用动作工厂创建相应的动作处理器
|
||||||
@@ -672,8 +594,8 @@ class HeartFChatting:
|
|||||||
async def _send_response(
|
async def _send_response(
|
||||||
self,
|
self,
|
||||||
reply_set,
|
reply_set,
|
||||||
message_data,
|
message_data: "DatabaseMessages",
|
||||||
selected_expressions: List[int] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
new_message_count = message_api.count_new_messages(
|
new_message_count = message_api.count_new_messages(
|
||||||
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
|
||||||
@@ -710,3 +632,97 @@ class HeartFChatting:
|
|||||||
reply_text += data
|
reply_text += data
|
||||||
|
|
||||||
return reply_text
|
return reply_text
|
||||||
|
|
||||||
|
async def _execute_action(
|
||||||
|
self,
|
||||||
|
action_planner_info: ActionPlannerInfo,
|
||||||
|
chosen_action_plan_infos: List[ActionPlannerInfo],
|
||||||
|
thinking_id: str,
|
||||||
|
available_actions: Dict[str, ActionInfo],
|
||||||
|
cycle_timers: Dict[str, float],
|
||||||
|
):
|
||||||
|
"""执行单个动作的通用函数"""
|
||||||
|
try:
|
||||||
|
if action_planner_info.action_type == "no_action":
|
||||||
|
# 直接处理no_action逻辑,不再通过动作系统
|
||||||
|
reason = action_planner_info.reasoning or "选择不回复"
|
||||||
|
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
|
||||||
|
|
||||||
|
# 存储no_action信息到数据库
|
||||||
|
await database_api.store_action_info(
|
||||||
|
chat_stream=self.chat_stream,
|
||||||
|
action_build_into_prompt=False,
|
||||||
|
action_prompt_display=reason,
|
||||||
|
action_done=True,
|
||||||
|
thinking_id=thinking_id,
|
||||||
|
action_data={"reason": reason},
|
||||||
|
action_name="no_action",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
|
||||||
|
elif action_planner_info.action_type != "reply":
|
||||||
|
# 执行普通动作
|
||||||
|
with Timer("动作执行", cycle_timers):
|
||||||
|
success, reply_text, command = await self._handle_action(
|
||||||
|
action_planner_info.action_type,
|
||||||
|
action_planner_info.reasoning or "",
|
||||||
|
action_planner_info.action_data or {},
|
||||||
|
cycle_timers,
|
||||||
|
thinking_id,
|
||||||
|
action_planner_info.action_message,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"action_type": action_planner_info.action_type,
|
||||||
|
"success": success,
|
||||||
|
"reply_text": reply_text,
|
||||||
|
"command": command,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
success, response_set, prompt, selected_expressions = await generator_api.generate_reply(
|
||||||
|
chat_stream=self.chat_stream,
|
||||||
|
reply_message=action_planner_info.action_message,
|
||||||
|
available_actions=available_actions,
|
||||||
|
chosen_actions=chosen_action_plan_infos,
|
||||||
|
reply_reason=action_planner_info.reasoning or "",
|
||||||
|
enable_tool=global_config.tool.enable_tool,
|
||||||
|
request_type="replyer",
|
||||||
|
from_plugin=False,
|
||||||
|
return_expressions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success or not response_set:
|
||||||
|
if action_planner_info.action_message:
|
||||||
|
logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
|
||||||
|
else:
|
||||||
|
logger.info("回复生成失败")
|
||||||
|
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
|
||||||
|
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
|
||||||
|
|
||||||
|
loop_info, reply_text, _ = await self._send_and_store_reply(
|
||||||
|
response_set=response_set,
|
||||||
|
action_message=action_planner_info.action_message, # type: ignore
|
||||||
|
cycle_timers=cycle_timers,
|
||||||
|
thinking_id=thinking_id,
|
||||||
|
actions=chosen_action_plan_infos,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"action_type": "reply",
|
||||||
|
"success": True,
|
||||||
|
"reply_text": reply_text,
|
||||||
|
"loop_info": loop_info,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
|
||||||
|
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
|
||||||
|
return {
|
||||||
|
"action_type": action_planner_info.action_type,
|
||||||
|
"success": False,
|
||||||
|
"reply_text": "",
|
||||||
|
"loop_info": None,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|||||||
@@ -303,4 +303,4 @@ init_prompt()
|
|||||||
try:
|
try:
|
||||||
expression_selector = ExpressionSelector()
|
expression_selector = ExpressionSelector()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ExpressionSelector初始化失败: {e}")
|
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||||
|
|||||||
@@ -4,44 +4,43 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
|||||||
|
|
||||||
|
|
||||||
class FocusValueControl:
|
class FocusValueControl:
|
||||||
def __init__(self,chat_id:str):
|
def __init__(self, chat_id: str):
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.focus_value_adjust = 1
|
self.focus_value_adjust: float = 1
|
||||||
|
|
||||||
|
|
||||||
def get_current_focus_value(self) -> float:
|
def get_current_focus_value(self) -> float:
|
||||||
return get_current_focus_value(self.chat_id) * self.focus_value_adjust
|
return get_current_focus_value(self.chat_id) * self.focus_value_adjust
|
||||||
|
|
||||||
|
|
||||||
class FocusValueControlManager:
|
class FocusValueControlManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.focus_value_controls = {}
|
self.focus_value_controls: dict[str, FocusValueControl] = {}
|
||||||
|
|
||||||
def get_focus_value_control(self,chat_id:str) -> FocusValueControl:
|
def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
|
||||||
if chat_id not in self.focus_value_controls:
|
if chat_id not in self.focus_value_controls:
|
||||||
self.focus_value_controls[chat_id] = FocusValueControl(chat_id)
|
self.focus_value_controls[chat_id] = FocusValueControl(chat_id)
|
||||||
return self.focus_value_controls[chat_id]
|
return self.focus_value_controls[chat_id]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
|
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
|
||||||
"""
|
"""
|
||||||
根据当前时间和聊天流获取对应的 focus_value
|
根据当前时间和聊天流获取对应的 focus_value
|
||||||
"""
|
"""
|
||||||
if not global_config.chat.focus_value_adjust:
|
if not global_config.chat.focus_value_adjust:
|
||||||
return global_config.chat.focus_value
|
return global_config.chat.focus_value
|
||||||
|
|
||||||
if chat_id:
|
if chat_id:
|
||||||
stream_focus_value = get_stream_specific_focus_value(chat_id)
|
stream_focus_value = get_stream_specific_focus_value(chat_id)
|
||||||
if stream_focus_value is not None:
|
if stream_focus_value is not None:
|
||||||
return stream_focus_value
|
return stream_focus_value
|
||||||
|
|
||||||
global_focus_value = get_global_focus_value()
|
global_focus_value = get_global_focus_value()
|
||||||
if global_focus_value is not None:
|
if global_focus_value is not None:
|
||||||
return global_focus_value
|
return global_focus_value
|
||||||
|
|
||||||
return global_config.chat.focus_value
|
return global_config.chat.focus_value
|
||||||
|
|
||||||
|
|
||||||
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
|
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
获取特定聊天流在当前时间的专注度
|
获取特定聊天流在当前时间的专注度
|
||||||
@@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
focus_value_control = FocusValueControlManager()
|
|
||||||
|
focus_value_control = FocusValueControlManager()
|
||||||
|
|||||||
@@ -2,20 +2,21 @@ from typing import Optional
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||||
|
|
||||||
|
|
||||||
class TalkFrequencyControl:
|
class TalkFrequencyControl:
|
||||||
def __init__(self,chat_id:str):
|
def __init__(self, chat_id: str):
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.talk_frequency_adjust = 1
|
self.talk_frequency_adjust: float = 1
|
||||||
|
|
||||||
def get_current_talk_frequency(self) -> float:
|
def get_current_talk_frequency(self) -> float:
|
||||||
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
|
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
|
||||||
|
|
||||||
|
|
||||||
class TalkFrequencyControlManager:
|
class TalkFrequencyControlManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.talk_frequency_controls = {}
|
self.talk_frequency_controls = {}
|
||||||
|
|
||||||
def get_talk_frequency_control(self,chat_id:str) -> TalkFrequencyControl:
|
def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
|
||||||
if chat_id not in self.talk_frequency_controls:
|
if chat_id not in self.talk_frequency_controls:
|
||||||
self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id)
|
self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id)
|
||||||
return self.talk_frequency_controls[chat_id]
|
return self.talk_frequency_controls[chat_id]
|
||||||
@@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
|
|||||||
global_frequency = get_global_frequency()
|
global_frequency = get_global_frequency()
|
||||||
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
|
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
|
||||||
|
|
||||||
|
|
||||||
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
|
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
根据时间配置列表获取当前时段的频率
|
根据时间配置列表获取当前时段的频率
|
||||||
@@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_global_frequency() -> Optional[float]:
|
def get_global_frequency() -> Optional[float]:
|
||||||
"""
|
"""
|
||||||
获取全局默认频率配置
|
获取全局默认频率配置
|
||||||
@@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
talk_frequency_control = TalkFrequencyControlManager()
|
|
||||||
|
talk_frequency_control = TalkFrequencyControlManager()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from src.chat.message_receive.storage import MessageStorage
|
|||||||
from src.chat.heart_flow.heartflow import heartflow
|
from src.chat.heart_flow.heartflow import heartflow
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.utils.chat_message_builder import replace_user_references_sync
|
from src.chat.utils.chat_message_builder import replace_user_references
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
@@ -131,7 +131,7 @@ class HeartFCMessageReceiver:
|
|||||||
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
|
||||||
|
|
||||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||||
processed_plain_text = replace_user_references_sync(
|
processed_plain_text = replace_user_references(
|
||||||
processed_plain_text,
|
processed_plain_text,
|
||||||
message.message_info.platform, # type: ignore
|
message.message_info.platform, # type: ignore
|
||||||
replace_bot_name=True
|
replace_bot_name=True
|
||||||
|
|||||||
@@ -0,0 +1,82 @@
|
|||||||
|
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||||
|
from src.chat.knowledge.qa_manager import QAManager
|
||||||
|
from src.chat.knowledge.kg_manager import KGManager
|
||||||
|
from src.chat.knowledge.global_logger import logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
import os
|
||||||
|
|
||||||
|
INVALID_ENTITY = [
|
||||||
|
"",
|
||||||
|
"你",
|
||||||
|
"他",
|
||||||
|
"她",
|
||||||
|
"它",
|
||||||
|
"我们",
|
||||||
|
"你们",
|
||||||
|
"他们",
|
||||||
|
"她们",
|
||||||
|
"它们",
|
||||||
|
]
|
||||||
|
|
||||||
|
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||||
|
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||||
|
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||||
|
|
||||||
|
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||||
|
|
||||||
|
|
||||||
|
qa_manager = None
|
||||||
|
inspire_manager = None
|
||||||
|
|
||||||
|
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
|
||||||
|
# 检查LPMM知识库是否启用
|
||||||
|
if global_config.lpmm_knowledge.enable:
|
||||||
|
logger.info("正在初始化Mai-LPMM")
|
||||||
|
logger.info("创建LLM客户端")
|
||||||
|
|
||||||
|
# 初始化Embedding库
|
||||||
|
embed_manager = EmbeddingManager()
|
||||||
|
logger.info("正在从文件加载Embedding库")
|
||||||
|
try:
|
||||||
|
embed_manager.load_from_file()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
||||||
|
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||||
|
logger.info("Embedding库加载完成")
|
||||||
|
# 初始化KG
|
||||||
|
kg_manager = KGManager()
|
||||||
|
logger.info("正在从文件加载KG")
|
||||||
|
try:
|
||||||
|
kg_manager.load_from_file()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
||||||
|
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||||
|
logger.info("KG加载完成")
|
||||||
|
|
||||||
|
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||||
|
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
||||||
|
|
||||||
|
# 数据比对:Embedding库与KG的段落hash集合
|
||||||
|
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||||
|
# 使用与EmbeddingStore中一致的命名空间格式
|
||||||
|
key = f"paragraph-{pg_hash}"
|
||||||
|
if key not in embed_manager.stored_pg_hashes:
|
||||||
|
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||||
|
global qa_manager
|
||||||
|
# 问答系统(用于知识库)
|
||||||
|
qa_manager = QAManager(
|
||||||
|
embed_manager,
|
||||||
|
kg_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# # 记忆激活(用于记忆库)
|
||||||
|
# global inspire_manager
|
||||||
|
# inspire_manager = MemoryActiveManager(
|
||||||
|
# embed_manager,
|
||||||
|
# llm_client_list[global_config["embedding"]["provider"]],
|
||||||
|
# )
|
||||||
|
else:
|
||||||
|
logger.info("LPMM知识库已禁用,跳过初始化")
|
||||||
|
# 创建空的占位符对象,避免导入错误
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import List, Union
|
|||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from . import prompt_template
|
from . import prompt_template
|
||||||
from .knowledge_lib import INVALID_ENTITY
|
from . import INVALID_ENTITY
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
|||||||
@@ -1,80 +0,0 @@
|
|||||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
|
||||||
from src.chat.knowledge.qa_manager import QAManager
|
|
||||||
from src.chat.knowledge.kg_manager import KGManager
|
|
||||||
from src.chat.knowledge.global_logger import logger
|
|
||||||
from src.config.config import global_config
|
|
||||||
import os
|
|
||||||
|
|
||||||
INVALID_ENTITY = [
|
|
||||||
"",
|
|
||||||
"你",
|
|
||||||
"他",
|
|
||||||
"她",
|
|
||||||
"它",
|
|
||||||
"我们",
|
|
||||||
"你们",
|
|
||||||
"他们",
|
|
||||||
"她们",
|
|
||||||
"它们",
|
|
||||||
]
|
|
||||||
|
|
||||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
|
||||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
|
||||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
|
||||||
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
|
||||||
|
|
||||||
|
|
||||||
qa_manager = None
|
|
||||||
inspire_manager = None
|
|
||||||
|
|
||||||
# 检查LPMM知识库是否启用
|
|
||||||
if global_config.lpmm_knowledge.enable:
|
|
||||||
logger.info("正在初始化Mai-LPMM")
|
|
||||||
logger.info("创建LLM客户端")
|
|
||||||
|
|
||||||
# 初始化Embedding库
|
|
||||||
embed_manager = EmbeddingManager()
|
|
||||||
logger.info("正在从文件加载Embedding库")
|
|
||||||
try:
|
|
||||||
embed_manager.load_from_file()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}")
|
|
||||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
|
||||||
logger.info("Embedding库加载完成")
|
|
||||||
# 初始化KG
|
|
||||||
kg_manager = KGManager()
|
|
||||||
logger.info("正在从文件加载KG")
|
|
||||||
try:
|
|
||||||
kg_manager.load_from_file()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}")
|
|
||||||
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
|
||||||
logger.info("KG加载完成")
|
|
||||||
|
|
||||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
|
||||||
logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")
|
|
||||||
|
|
||||||
# 数据比对:Embedding库与KG的段落hash集合
|
|
||||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
|
||||||
# 使用与EmbeddingStore中一致的命名空间格式
|
|
||||||
key = f"paragraph-{pg_hash}"
|
|
||||||
if key not in embed_manager.stored_pg_hashes:
|
|
||||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
|
||||||
|
|
||||||
# 问答系统(用于知识库)
|
|
||||||
qa_manager = QAManager(
|
|
||||||
embed_manager,
|
|
||||||
kg_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# # 记忆激活(用于记忆库)
|
|
||||||
# inspire_manager = MemoryActiveManager(
|
|
||||||
# embed_manager,
|
|
||||||
# llm_client_list[global_config["embedding"]["provider"]],
|
|
||||||
# )
|
|
||||||
else:
|
|
||||||
logger.info("LPMM知识库已禁用,跳过初始化")
|
|
||||||
# 创建空的占位符对象,避免导入错误
|
|
||||||
@@ -4,7 +4,7 @@ import glob
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
from . import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||||
# from src.manager.local_store_manager import local_storage
|
# from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ class QAManager:
|
|||||||
for res in relation_search_res:
|
for res in relation_search_res:
|
||||||
if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]):
|
if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]):
|
||||||
rel_str = store_item.str
|
rel_str = store_item.str
|
||||||
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
logger.info(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
|
||||||
|
|
||||||
# TODO: 使用LLM过滤三元组结果
|
# TODO: 使用LLM过滤三元组结果
|
||||||
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
# logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
|
||||||
@@ -94,7 +94,7 @@ class QAManager:
|
|||||||
|
|
||||||
for res in result:
|
for res in result:
|
||||||
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
|
||||||
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
|
||||||
|
|
||||||
return result, ppr_node_weights
|
return result, ppr_node_weights
|
||||||
|
|
||||||
|
|||||||
@@ -30,9 +30,7 @@ def cosine_similarity(v1, v2):
|
|||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
norm1 = np.linalg.norm(v1)
|
norm1 = np.linalg.norm(v1)
|
||||||
norm2 = np.linalg.norm(v2)
|
norm2 = np.linalg.norm(v2)
|
||||||
if norm1 == 0 or norm2 == 0:
|
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||||
return 0
|
|
||||||
return dot_product / (norm1 * norm2)
|
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -142,11 +140,10 @@ class MemoryGraph:
|
|||||||
# 获取当前节点的记忆项
|
# 获取当前节点的记忆项
|
||||||
node_data = self.get_dot(topic)
|
node_data = self.get_dot(topic)
|
||||||
if node_data:
|
if node_data:
|
||||||
concept, data = node_data
|
_, data = node_data
|
||||||
if "memory_items" in data:
|
if "memory_items" in data:
|
||||||
memory_items = data["memory_items"]
|
|
||||||
# 直接使用完整的记忆内容
|
# 直接使用完整的记忆内容
|
||||||
if memory_items:
|
if memory_items := data["memory_items"]:
|
||||||
first_layer_items.append(memory_items)
|
first_layer_items.append(memory_items)
|
||||||
|
|
||||||
# 只在depth=2时获取第二层记忆
|
# 只在depth=2时获取第二层记忆
|
||||||
@@ -154,11 +151,10 @@ class MemoryGraph:
|
|||||||
# 获取相邻节点的记忆项
|
# 获取相邻节点的记忆项
|
||||||
for neighbor in neighbors:
|
for neighbor in neighbors:
|
||||||
if node_data := self.get_dot(neighbor):
|
if node_data := self.get_dot(neighbor):
|
||||||
concept, data = node_data
|
_, data = node_data
|
||||||
if "memory_items" in data:
|
if "memory_items" in data:
|
||||||
memory_items = data["memory_items"]
|
|
||||||
# 直接使用完整的记忆内容
|
# 直接使用完整的记忆内容
|
||||||
if memory_items:
|
if memory_items := data["memory_items"]:
|
||||||
second_layer_items.append(memory_items)
|
second_layer_items.append(memory_items)
|
||||||
|
|
||||||
return first_layer_items, second_layer_items
|
return first_layer_items, second_layer_items
|
||||||
@@ -224,27 +220,17 @@ class MemoryGraph:
|
|||||||
# 获取话题节点数据
|
# 获取话题节点数据
|
||||||
node_data = self.G.nodes[topic]
|
node_data = self.G.nodes[topic]
|
||||||
|
|
||||||
|
# 删除整个节点
|
||||||
|
self.G.remove_node(topic)
|
||||||
# 如果节点存在memory_items
|
# 如果节点存在memory_items
|
||||||
if "memory_items" in node_data:
|
if "memory_items" in node_data:
|
||||||
memory_items = node_data["memory_items"]
|
if memory_items := node_data["memory_items"]:
|
||||||
|
|
||||||
# 既然每个节点现在是一个完整的记忆内容,直接删除整个节点
|
|
||||||
if memory_items:
|
|
||||||
# 删除整个节点
|
|
||||||
self.G.remove_node(topic)
|
|
||||||
return (
|
return (
|
||||||
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
|
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
|
||||||
if len(memory_items) > 50
|
if len(memory_items) > 50
|
||||||
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
|
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
|
||||||
)
|
)
|
||||||
else:
|
return None
|
||||||
# 如果没有记忆项,删除该节点
|
|
||||||
self.G.remove_node(topic)
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
# 如果没有memory_items字段,删除该节点
|
|
||||||
self.G.remove_node(topic)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# 海马体
|
# 海马体
|
||||||
@@ -392,9 +378,8 @@ class Hippocampus:
|
|||||||
# 如果相似度超过阈值,获取该节点的记忆
|
# 如果相似度超过阈值,获取该节点的记忆
|
||||||
if similarity >= 0.3: # 可以调整这个阈值
|
if similarity >= 0.3: # 可以调整这个阈值
|
||||||
node_data = self.memory_graph.G.nodes[node]
|
node_data = self.memory_graph.G.nodes[node]
|
||||||
memory_items = node_data.get("memory_items", "")
|
|
||||||
# 直接使用完整的记忆内容
|
# 直接使用完整的记忆内容
|
||||||
if memory_items:
|
if memory_items := node_data.get("memory_items", ""):
|
||||||
memories.append((node, memory_items, similarity))
|
memories.append((node, memory_items, similarity))
|
||||||
|
|
||||||
# 按相似度降序排序
|
# 按相似度降序排序
|
||||||
@@ -411,7 +396,7 @@ class Hippocampus:
|
|||||||
如果为False,使用LLM提取关键词,速度较慢但更准确。
|
如果为False,使用LLM提取关键词,速度较慢但更准确。
|
||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
|
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
|
||||||
text_length = len(text)
|
text_length = len(text)
|
||||||
@@ -587,7 +572,7 @@ class Hippocampus:
|
|||||||
unique_memories = []
|
unique_memories = []
|
||||||
for topic, memory_items, activation_value in all_memories:
|
for topic, memory_items, activation_value in all_memories:
|
||||||
# memory_items现在是完整的字符串格式
|
# memory_items现在是完整的字符串格式
|
||||||
memory = memory_items if memory_items else ""
|
memory = memory_items or ""
|
||||||
if memory not in seen_memories:
|
if memory not in seen_memories:
|
||||||
seen_memories.add(memory)
|
seen_memories.add(memory)
|
||||||
unique_memories.append((topic, memory_items, activation_value))
|
unique_memories.append((topic, memory_items, activation_value))
|
||||||
@@ -599,7 +584,7 @@ class Hippocampus:
|
|||||||
result = []
|
result = []
|
||||||
for topic, memory_items, _ in unique_memories:
|
for topic, memory_items, _ in unique_memories:
|
||||||
# memory_items现在是完整的字符串格式
|
# memory_items现在是完整的字符串格式
|
||||||
memory = memory_items if memory_items else ""
|
memory = memory_items or ""
|
||||||
result.append((topic, memory))
|
result.append((topic, memory))
|
||||||
logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
|
logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
|
||||||
|
|
||||||
@@ -1435,13 +1420,11 @@ class HippocampusManager:
|
|||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
try:
|
try:
|
||||||
response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
|
return await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
|
||||||
text, max_depth, fast_retrieval
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"文本产生激活值失败: {e}")
|
logger.error(f"文本产生激活值失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return 0.0, [], []
|
return 0.0, [], []
|
||||||
|
|
||||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||||
"""从关键词获取相关记忆的公共接口"""
|
"""从关键词获取相关记忆的公共接口"""
|
||||||
@@ -1473,6 +1456,7 @@ class MemoryBuilder:
|
|||||||
self.last_processed_time: float = 0.0
|
self.last_processed_time: float = 0.0
|
||||||
|
|
||||||
def should_trigger_memory_build(self) -> bool:
|
def should_trigger_memory_build(self) -> bool:
|
||||||
|
# sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else
|
||||||
"""检查是否应该触发记忆构建"""
|
"""检查是否应该触发记忆构建"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from datetime import datetime, timedelta
|
|||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Memory # Peewee Models导入
|
from src.common.database.database_model import Memory # Peewee Models导入
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config, global_config
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -42,7 +42,7 @@ class InstantMemory:
|
|||||||
request_type="memory.summary",
|
request_type="memory.summary",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def if_need_build(self, text):
|
async def if_need_build(self, text: str):
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
请判断以下内容中是否有值得记忆的信息,如果有,请输出1,否则输出0
|
||||||
{text}
|
{text}
|
||||||
@@ -51,8 +51,9 @@ class InstantMemory:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||||
print(prompt)
|
if global_config.debug.show_prompt:
|
||||||
print(response)
|
print(prompt)
|
||||||
|
print(response)
|
||||||
|
|
||||||
return "1" in response
|
return "1" in response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -94,7 +95,7 @@ class InstantMemory:
|
|||||||
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def create_and_store_memory(self, text):
|
async def create_and_store_memory(self, text: str):
|
||||||
if_need = await self.if_need_build(text)
|
if_need = await self.if_need_build(text)
|
||||||
if if_need:
|
if if_need:
|
||||||
logger.info(f"需要记忆:{text}")
|
logger.info(f"需要记忆:{text}")
|
||||||
@@ -126,24 +127,25 @@ class InstantMemory:
|
|||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请根据以下发言内容,判断是否需要提取记忆
|
请根据以下发言内容,判断是否需要提取记忆
|
||||||
{target}
|
{target}
|
||||||
请用json格式输出,包含以下字段:
|
请用json格式输出,包含以下字段:
|
||||||
其中,time的要求是:
|
其中,time的要求是:
|
||||||
可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD
|
可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD
|
||||||
可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前
|
可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前
|
||||||
可以选择留空进行模糊搜索
|
可以选择留空进行模糊搜索
|
||||||
{{
|
{{
|
||||||
"need_memory": 1,
|
"need_memory": 1,
|
||||||
"keywords": "希望获取的记忆关键词,用/划分",
|
"keywords": "希望获取的记忆关键词,用/划分",
|
||||||
"time": "希望获取的记忆大致时间"
|
"time": "希望获取的记忆大致时间"
|
||||||
}}
|
}}
|
||||||
请只输出json格式,不要输出其他多余内容
|
请只输出json格式,不要输出其他多余内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||||
print(prompt)
|
if global_config.debug.show_prompt:
|
||||||
print(response)
|
print(prompt)
|
||||||
|
print(response)
|
||||||
if not response:
|
if not response:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -145,7 +145,7 @@ class ChatBot:
|
|||||||
logger.error(f"处理命令时出错: {e}")
|
logger.error(f"处理命令时出错: {e}")
|
||||||
return False, None, True # 出错时继续处理消息
|
return False, None, True # 出错时继续处理消息
|
||||||
|
|
||||||
async def hanle_notice_message(self, message: MessageRecv):
|
async def handle_notice_message(self, message: MessageRecv):
|
||||||
if message.message_info.message_id == "notice":
|
if message.message_info.message_id == "notice":
|
||||||
message.is_notify = True
|
message.is_notify = True
|
||||||
logger.info("notice消息")
|
logger.info("notice消息")
|
||||||
@@ -212,7 +212,7 @@ class ChatBot:
|
|||||||
# logger.debug(str(message_data))
|
# logger.debug(str(message_data))
|
||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
|
|
||||||
if await self.hanle_notice_message(message):
|
if await self.handle_notice_message(message):
|
||||||
# return
|
# return
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class MessageRecv(Message):
|
|||||||
self.priority_mode = "interest"
|
self.priority_mode = "interest"
|
||||||
self.priority_info = None
|
self.priority_info = None
|
||||||
self.interest_value: float = None # type: ignore
|
self.interest_value: float = None # type: ignore
|
||||||
|
|
||||||
self.key_words = []
|
self.key_words = []
|
||||||
self.key_words_lite = []
|
self.key_words_lite = []
|
||||||
|
|
||||||
@@ -213,9 +213,9 @@ class MessageRecvS4U(MessageRecv):
|
|||||||
self.is_screen = False
|
self.is_screen = False
|
||||||
self.is_internal = False
|
self.is_internal = False
|
||||||
self.voice_done = None
|
self.voice_done = None
|
||||||
|
|
||||||
self.chat_info = None
|
self.chat_info = None
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||||
|
|
||||||
@@ -420,7 +420,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
apply_set_reply_logic: bool = False,
|
apply_set_reply_logic: bool = False,
|
||||||
reply_to: Optional[str] = None,
|
reply_to: Optional[str] = None,
|
||||||
selected_expressions:List[int] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -445,7 +445,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.display_message = display_message
|
self.display_message = display_message
|
||||||
|
|
||||||
self.interest_value = 0.0
|
self.interest_value = 0.0
|
||||||
|
|
||||||
self.selected_expressions = selected_expressions
|
self.selected_expressions = selected_expressions
|
||||||
|
|
||||||
def build_reply(self):
|
def build_reply(self):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from typing import Dict, Optional, Type
|
|||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
from src.plugin_system.base.component_types import ComponentType, ActionInfo
|
||||||
from src.plugin_system.base.base_action import BaseAction
|
from src.plugin_system.base.base_action import BaseAction
|
||||||
@@ -37,7 +38,7 @@ class ActionManager:
|
|||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
log_prefix: str,
|
log_prefix: str,
|
||||||
shutting_down: bool = False,
|
shutting_down: bool = False,
|
||||||
action_message: Optional[dict] = None,
|
action_message: Optional[DatabaseMessages] = None,
|
||||||
) -> Optional[BaseAction]:
|
) -> Optional[BaseAction]:
|
||||||
"""
|
"""
|
||||||
创建动作处理器实例
|
创建动作处理器实例
|
||||||
@@ -83,7 +84,7 @@ class ActionManager:
|
|||||||
log_prefix=log_prefix,
|
log_prefix=log_prefix,
|
||||||
shutting_down=shutting_down,
|
shutting_down=shutting_down,
|
||||||
plugin_config=plugin_config,
|
plugin_config=plugin_config,
|
||||||
action_message=action_message,
|
action_message=action_message.flatten() if action_message else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建Action实例成功: {action_name}")
|
logger.debug(f"创建Action实例成功: {action_name}")
|
||||||
@@ -123,4 +124,4 @@ class ActionManager:
|
|||||||
"""恢复到默认动作集"""
|
"""恢复到默认动作集"""
|
||||||
actions_to_restore = list(self._using_actions.keys())
|
actions_to_restore = list(self._using_actions.keys())
|
||||||
self._using_actions = component_registry.get_default_actions()
|
self._using_actions = component_registry.get_default_actions()
|
||||||
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Dict, Any, Optional, Tuple, List
|
from typing import Dict, Optional, Tuple, List
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
@@ -9,6 +9,8 @@ from json_repair import repair_json
|
|||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_actions,
|
build_readable_actions,
|
||||||
@@ -97,7 +99,9 @@ class ActionPlanner:
|
|||||||
self.plan_retry_count = 0
|
self.plan_retry_count = 0
|
||||||
self.max_plan_retries = 3
|
self.max_plan_retries = 3
|
||||||
|
|
||||||
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
def find_message_by_id(
|
||||||
|
self, message_id: str, message_id_list: List[DatabaseMessages]
|
||||||
|
) -> Optional[DatabaseMessages]:
|
||||||
# sourcery skip: use-next
|
# sourcery skip: use-next
|
||||||
"""
|
"""
|
||||||
根据message_id从message_id_list中查找对应的原始消息
|
根据message_id从message_id_list中查找对应的原始消息
|
||||||
@@ -110,37 +114,37 @@ class ActionPlanner:
|
|||||||
找到的原始消息字典,如果未找到则返回None
|
找到的原始消息字典,如果未找到则返回None
|
||||||
"""
|
"""
|
||||||
for item in message_id_list:
|
for item in message_id_list:
|
||||||
if item.get("id") == message_id:
|
if item.message_id == message_id:
|
||||||
return item.get("message")
|
return item
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
|
def get_latest_message(self, message_id_list: List[DatabaseMessages]) -> Optional[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
获取消息列表中的最新消息
|
获取消息列表中的最新消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
message_id_list: 消息ID列表,格式为[{'id': str, 'message': dict}, ...]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
最新的消息字典,如果列表为空则返回None
|
最新的消息字典,如果列表为空则返回None
|
||||||
"""
|
"""
|
||||||
return message_id_list[-1].get("message") if message_id_list else None
|
return message_id_list[-1] if message_id_list else None
|
||||||
|
|
||||||
async def plan(
|
async def plan(
|
||||||
self,
|
self,
|
||||||
mode: ChatMode = ChatMode.FOCUS,
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
loop_start_time:float = 0.0,
|
loop_start_time: float = 0.0,
|
||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
|
) -> Tuple[List[ActionPlannerInfo], Optional[DatabaseMessages]]:
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
action = "no_action" # 默认动作
|
action: str = "no_action" # 默认动作
|
||||||
reasoning = "规划器初始化默认"
|
reasoning: str = "规划器初始化默认"
|
||||||
action_data = {}
|
action_data = {}
|
||||||
current_available_actions: Dict[str, ActionInfo] = {}
|
current_available_actions: Dict[str, ActionInfo] = {}
|
||||||
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
|
target_message: Optional[DatabaseMessages] = None # 初始化target_message变量
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
message_id_list: list = []
|
message_id_list: list = []
|
||||||
|
|
||||||
@@ -208,19 +212,21 @@ class ActionPlanner:
|
|||||||
# 如果获取的target_message为None,输出warning并重新plan
|
# 如果获取的target_message为None,输出warning并重新plan
|
||||||
if target_message is None:
|
if target_message is None:
|
||||||
self.plan_retry_count += 1
|
self.plan_retry_count += 1
|
||||||
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
|
logger.warning(
|
||||||
|
f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}"
|
||||||
|
)
|
||||||
# 仍有重试次数
|
# 仍有重试次数
|
||||||
if self.plan_retry_count < self.max_plan_retries:
|
if self.plan_retry_count < self.max_plan_retries:
|
||||||
# 递归重新plan
|
# 递归重新plan
|
||||||
return await self.plan(mode, loop_start_time, available_actions)
|
return await self.plan(mode, loop_start_time, available_actions)
|
||||||
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message")
|
logger.error(
|
||||||
|
f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败,选择最新消息作为target_message"
|
||||||
|
)
|
||||||
target_message = self.get_latest_message(message_id_list)
|
target_message = self.get_latest_message(message_id_list)
|
||||||
self.plan_retry_count = 0 # 重置计数器
|
self.plan_retry_count = 0 # 重置计数器
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if action != "no_action" and action != "reply" and action not in current_available_actions:
|
if action != "no_action" and action != "reply" and action not in current_available_actions:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'"
|
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'"
|
||||||
@@ -244,38 +250,37 @@ class ActionPlanner:
|
|||||||
if mode == ChatMode.NORMAL and action in current_available_actions:
|
if mode == ChatMode.NORMAL and action in current_available_actions:
|
||||||
is_parallel = current_available_actions[action].parallel_action
|
is_parallel = current_available_actions[action].parallel_action
|
||||||
|
|
||||||
|
|
||||||
action_data["loop_start_time"] = loop_start_time
|
action_data["loop_start_time"] = loop_start_time
|
||||||
|
|
||||||
actions = [
|
actions = [
|
||||||
{
|
ActionPlannerInfo(
|
||||||
"action_type": action,
|
action_type=action,
|
||||||
"reasoning": reasoning,
|
reasoning=reasoning,
|
||||||
"action_data": action_data,
|
action_data=action_data,
|
||||||
"action_message": target_message,
|
action_message=target_message,
|
||||||
"available_actions": available_actions,
|
available_actions=available_actions,
|
||||||
}
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if action != "reply" and is_parallel:
|
if action != "reply" and is_parallel:
|
||||||
actions.append({
|
actions.append(
|
||||||
"action_type": "reply",
|
ActionPlannerInfo(
|
||||||
"action_message": target_message,
|
action_type="reply",
|
||||||
"available_actions": available_actions
|
action_message=target_message,
|
||||||
})
|
available_actions=available_actions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return actions,target_message
|
return actions, target_message
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def build_planner_prompt(
|
async def build_planner_prompt(
|
||||||
self,
|
self,
|
||||||
is_group_chat: bool, # Now passed as argument
|
is_group_chat: bool, # Now passed as argument
|
||||||
chat_target_info: Optional[dict], # Now passed as argument
|
chat_target_info: Optional[dict], # Now passed as argument
|
||||||
current_available_actions: Dict[str, ActionInfo],
|
current_available_actions: Dict[str, ActionInfo],
|
||||||
refresh_time :bool = False,
|
refresh_time: bool = False,
|
||||||
mode: ChatMode = ChatMode.FOCUS,
|
mode: ChatMode = ChatMode.FOCUS,
|
||||||
) -> tuple[str, list]: # sourcery skip: use-join
|
) -> tuple[str, List[DatabaseMessages]]: # sourcery skip: use-join
|
||||||
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
|
||||||
try:
|
try:
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||||
@@ -305,13 +310,12 @@ class ActionPlanner:
|
|||||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
||||||
if refresh_time:
|
if refresh_time:
|
||||||
self.last_obs_time_mark = time.time()
|
self.last_obs_time_mark = time.time()
|
||||||
|
|
||||||
mentioned_bonus = ""
|
mentioned_bonus = ""
|
||||||
if global_config.chat.mentioned_bot_inevitable_reply:
|
if global_config.chat.mentioned_bot_inevitable_reply:
|
||||||
mentioned_bonus = "\n- 有人提到你"
|
mentioned_bonus = "\n- 有人提到你"
|
||||||
if global_config.chat.at_bot_inevitable_reply:
|
if global_config.chat.at_bot_inevitable_reply:
|
||||||
mentioned_bonus = "\n- 有人提到你,或者at你"
|
mentioned_bonus = "\n- 有人提到你,或者at你"
|
||||||
|
|
||||||
|
|
||||||
if mode == ChatMode.FOCUS:
|
if mode == ChatMode.FOCUS:
|
||||||
no_action_block = """
|
no_action_block = """
|
||||||
@@ -332,7 +336,7 @@ class ActionPlanner:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
chat_context_description = "你现在正在一个群聊中"
|
chat_context_description = "你现在正在一个群聊中"
|
||||||
chat_target_name = None
|
chat_target_name = None
|
||||||
if not is_group_chat and chat_target_info:
|
if not is_group_chat and chat_target_info:
|
||||||
chat_target_name = (
|
chat_target_name = (
|
||||||
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
|
||||||
@@ -388,7 +392,7 @@ class ActionPlanner:
|
|||||||
action_options_text=action_options_block,
|
action_options_text=action_options_block,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
identity_block=identity_block,
|
identity_block=identity_block,
|
||||||
plan_style = global_config.personality.plan_style
|
plan_style=global_config.personality.plan_style,
|
||||||
)
|
)
|
||||||
return prompt, message_id_list
|
return prompt, message_id_list
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from datetime import datetime
|
|||||||
from src.mais4u.mai_think import mai_thinking_manager
|
from src.mais4u.mai_think import mai_thinking_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.individuality.individuality import get_individuality
|
from src.individuality.individuality import get_individuality
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
@@ -21,7 +22,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
|||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
replace_user_references_sync,
|
replace_user_references,
|
||||||
)
|
)
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from src.chat.memory_system.memory_activator import MemoryActivator
|
from src.chat.memory_system.memory_activator import MemoryActivator
|
||||||
@@ -157,12 +158,12 @@ class DefaultReplyer:
|
|||||||
extra_info: str = "",
|
extra_info: str = "",
|
||||||
reply_reason: str = "",
|
reply_reason: str = "",
|
||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
chosen_actions: Optional[List[Dict[str, Any]]] = None,
|
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
from_plugin: bool = True,
|
from_plugin: bool = True,
|
||||||
stream_id: Optional[str] = None,
|
stream_id: Optional[str] = None,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[DatabaseMessages] = None,
|
||||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]:
|
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[List[int]]]:
|
||||||
# sourcery skip: merge-nested-ifs
|
# sourcery skip: merge-nested-ifs
|
||||||
"""
|
"""
|
||||||
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
回复器 (Replier): 负责生成回复文本的核心逻辑。
|
||||||
@@ -181,7 +182,7 @@ class DefaultReplyer:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
prompt = None
|
prompt = None
|
||||||
selected_expressions = None
|
selected_expressions: Optional[List[int]] = None
|
||||||
if available_actions is None:
|
if available_actions is None:
|
||||||
available_actions = {}
|
available_actions = {}
|
||||||
try:
|
try:
|
||||||
@@ -374,7 +375,12 @@ class DefaultReplyer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if global_config.memory.enable_instant_memory:
|
if global_config.memory.enable_instant_memory:
|
||||||
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
|
chat_history_str = build_readable_messages(
|
||||||
|
messages=chat_history,
|
||||||
|
replace_bot_name=True,
|
||||||
|
timestamp_mode="normal"
|
||||||
|
)
|
||||||
|
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history_str))
|
||||||
|
|
||||||
instant_memory = await self.instant_memory.get_memory(target)
|
instant_memory = await self.instant_memory.get_memory(target)
|
||||||
logger.info(f"即时记忆:{instant_memory}")
|
logger.info(f"即时记忆:{instant_memory}")
|
||||||
@@ -527,7 +533,7 @@ class DefaultReplyer:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
|
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
|
||||||
"""
|
"""
|
||||||
core_dialogue_list = []
|
core_dialogue_list: List[DatabaseMessages] = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
||||||
@@ -559,7 +565,7 @@ class DefaultReplyer:
|
|||||||
if core_dialogue_list:
|
if core_dialogue_list:
|
||||||
# 检查最新五条消息中是否包含bot自己说的消息
|
# 检查最新五条消息中是否包含bot自己说的消息
|
||||||
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
|
||||||
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
|
has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages)
|
||||||
|
|
||||||
# logger.info(f"最新五条消息:{latest_5_messages}")
|
# logger.info(f"最新五条消息:{latest_5_messages}")
|
||||||
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
# logger.info(f"最新五条消息中是否包含bot自己说的消息:{has_bot_message}")
|
||||||
@@ -634,7 +640,7 @@ class DefaultReplyer:
|
|||||||
return mai_think
|
return mai_think
|
||||||
|
|
||||||
async def build_actions_prompt(
|
async def build_actions_prompt(
|
||||||
self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None
|
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建动作提示"""
|
"""构建动作提示"""
|
||||||
|
|
||||||
@@ -646,20 +652,21 @@ class DefaultReplyer:
|
|||||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||||
action_descriptions += "\n"
|
action_descriptions += "\n"
|
||||||
|
|
||||||
choosen_action_descriptions = ""
|
chosen_action_descriptions = ""
|
||||||
if choosen_actions:
|
if chosen_actions_info:
|
||||||
for action in choosen_actions:
|
for action_plan_info in chosen_actions_info:
|
||||||
action_name = action.get("action_type", "unknown_action")
|
action_name = action_plan_info.action_type
|
||||||
if action_name == "reply":
|
if action_name == "reply":
|
||||||
continue
|
continue
|
||||||
action_description = action.get("reason", "无描述")
|
if action := available_actions.get(action_name):
|
||||||
reasoning = action.get("reasoning", "无原因")
|
action_description = action.description or "无描述"
|
||||||
|
reasoning = action_plan_info.reasoning or "无原因"
|
||||||
|
|
||||||
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
chosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
|
||||||
|
|
||||||
if choosen_action_descriptions:
|
if chosen_action_descriptions:
|
||||||
action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n"
|
action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n"
|
||||||
action_descriptions += choosen_action_descriptions
|
action_descriptions += chosen_action_descriptions
|
||||||
|
|
||||||
return action_descriptions
|
return action_descriptions
|
||||||
|
|
||||||
@@ -668,9 +675,9 @@ class DefaultReplyer:
|
|||||||
extra_info: str = "",
|
extra_info: str = "",
|
||||||
reply_reason: str = "",
|
reply_reason: str = "",
|
||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
chosen_actions: Optional[List[Dict[str, Any]]] = None,
|
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
|
||||||
enable_tool: bool = True,
|
enable_tool: bool = True,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional[DatabaseMessages] = None,
|
||||||
) -> Tuple[str, List[int]]:
|
) -> Tuple[str, List[int]]:
|
||||||
"""
|
"""
|
||||||
构建回复器上下文
|
构建回复器上下文
|
||||||
@@ -694,11 +701,11 @@ class DefaultReplyer:
|
|||||||
platform = chat_stream.platform
|
platform = chat_stream.platform
|
||||||
|
|
||||||
if reply_message:
|
if reply_message:
|
||||||
user_id = reply_message.get("user_id", "")
|
user_id = reply_message.user_info.user_id
|
||||||
person = Person(platform=platform, user_id=user_id)
|
person = Person(platform=platform, user_id=user_id)
|
||||||
person_name = person.person_name or user_id
|
person_name = person.person_name or user_id
|
||||||
sender = person_name
|
sender = person_name
|
||||||
target = reply_message.get("processed_plain_text")
|
target = reply_message.processed_plain_text
|
||||||
else:
|
else:
|
||||||
person_name = "用户"
|
person_name = "用户"
|
||||||
sender = "用户"
|
sender = "用户"
|
||||||
@@ -710,7 +717,7 @@ class DefaultReplyer:
|
|||||||
else:
|
else:
|
||||||
mood_prompt = ""
|
mood_prompt = ""
|
||||||
|
|
||||||
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
|
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||||
|
|
||||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
@@ -774,11 +781,13 @@ class DefaultReplyer:
|
|||||||
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s")
|
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s")
|
||||||
|
|
||||||
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
expression_habits_block, selected_expressions = results_dict["expression_habits"]
|
||||||
relation_info = results_dict["relation_info"]
|
expression_habits_block: str
|
||||||
memory_block = results_dict["memory_block"]
|
selected_expressions: List[int]
|
||||||
tool_info = results_dict["tool_info"]
|
relation_info: str = results_dict["relation_info"]
|
||||||
prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果
|
memory_block: str = results_dict["memory_block"]
|
||||||
actions_info = results_dict["actions_info"]
|
tool_info: str = results_dict["tool_info"]
|
||||||
|
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
|
||||||
|
actions_info: str = results_dict["actions_info"]
|
||||||
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
|
||||||
|
|
||||||
if extra_info:
|
if extra_info:
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ install(extra_lines=3)
|
|||||||
logger = get_logger("chat_message_builder")
|
logger = get_logger("chat_message_builder")
|
||||||
|
|
||||||
|
|
||||||
def replace_user_references_sync(
|
def replace_user_references(
|
||||||
content: str,
|
content: Optional[str],
|
||||||
platform: str,
|
platform: str,
|
||||||
name_resolver: Optional[Callable[[str, str], str]] = None,
|
name_resolver: Optional[Callable[[str, str], str]] = None,
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
@@ -38,6 +38,8 @@ def replace_user_references_sync(
|
|||||||
Returns:
|
Returns:
|
||||||
str: 处理后的内容字符串
|
str: 处理后的内容字符串
|
||||||
"""
|
"""
|
||||||
|
if not content:
|
||||||
|
return ""
|
||||||
if name_resolver is None:
|
if name_resolver is None:
|
||||||
|
|
||||||
def default_resolver(platform: str, user_id: str) -> str:
|
def default_resolver(platform: str, user_id: str) -> str:
|
||||||
@@ -93,80 +95,6 @@ def replace_user_references_sync(
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def replace_user_references_async(
|
|
||||||
content: str,
|
|
||||||
platform: str,
|
|
||||||
name_resolver: Optional[Callable[[str, str], Any]] = None,
|
|
||||||
replace_bot_name: bool = True,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
替换内容中的用户引用格式,包括回复<aaa:bbb>和@<aaa:bbb>格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 要处理的内容字符串
|
|
||||||
platform: 平台标识
|
|
||||||
name_resolver: 名称解析函数,接收(platform, user_id)参数,返回用户名称
|
|
||||||
如果为None,则使用默认的person_info_manager
|
|
||||||
replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 处理后的内容字符串
|
|
||||||
"""
|
|
||||||
if name_resolver is None:
|
|
||||||
|
|
||||||
async def default_resolver(platform: str, user_id: str) -> str:
|
|
||||||
# 检查是否是机器人自己
|
|
||||||
if replace_bot_name and user_id == global_config.bot.qq_account:
|
|
||||||
return f"{global_config.bot.nickname}(你)"
|
|
||||||
person = Person(platform=platform, user_id=user_id)
|
|
||||||
return person.person_name or user_id # type: ignore
|
|
||||||
|
|
||||||
name_resolver = default_resolver
|
|
||||||
|
|
||||||
# 处理回复<aaa:bbb>格式
|
|
||||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
|
||||||
match = re.search(reply_pattern, content)
|
|
||||||
if match:
|
|
||||||
aaa = match.group(1)
|
|
||||||
bbb = match.group(2)
|
|
||||||
try:
|
|
||||||
# 检查是否是机器人自己
|
|
||||||
if replace_bot_name and bbb == global_config.bot.qq_account:
|
|
||||||
reply_person_name = f"{global_config.bot.nickname}(你)"
|
|
||||||
else:
|
|
||||||
reply_person_name = await name_resolver(platform, bbb) or aaa
|
|
||||||
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
|
|
||||||
except Exception:
|
|
||||||
# 如果解析失败,使用原始昵称
|
|
||||||
content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1)
|
|
||||||
|
|
||||||
# 处理@<aaa:bbb>格式
|
|
||||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
|
||||||
at_matches = list(re.finditer(at_pattern, content))
|
|
||||||
if at_matches:
|
|
||||||
new_content = ""
|
|
||||||
last_end = 0
|
|
||||||
for m in at_matches:
|
|
||||||
new_content += content[last_end : m.start()]
|
|
||||||
aaa = m.group(1)
|
|
||||||
bbb = m.group(2)
|
|
||||||
try:
|
|
||||||
# 检查是否是机器人自己
|
|
||||||
if replace_bot_name and bbb == global_config.bot.qq_account:
|
|
||||||
at_person_name = f"{global_config.bot.nickname}(你)"
|
|
||||||
else:
|
|
||||||
at_person_name = await name_resolver(platform, bbb) or aaa
|
|
||||||
new_content += f"@{at_person_name}"
|
|
||||||
except Exception:
|
|
||||||
# 如果解析失败,使用原始昵称
|
|
||||||
new_content += f"@{aaa}"
|
|
||||||
last_end = m.end()
|
|
||||||
new_content += content[last_end:]
|
|
||||||
content = new_content
|
|
||||||
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp(timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"):
|
def get_raw_msg_by_timestamp(timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"):
|
||||||
"""
|
"""
|
||||||
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
获取从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
@@ -498,7 +426,7 @@ def _build_readable_messages_internal(
|
|||||||
person_name = f"{global_config.bot.nickname}(你)"
|
person_name = f"{global_config.bot.nickname}(你)"
|
||||||
|
|
||||||
# 使用独立函数处理用户引用格式
|
# 使用独立函数处理用户引用格式
|
||||||
if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name):
|
if content := replace_user_references(content, platform, replace_bot_name=replace_bot_name):
|
||||||
detailed_messages_raw.append((timestamp, person_name, content, False))
|
detailed_messages_raw.append((timestamp, person_name, content, False))
|
||||||
|
|
||||||
if not detailed_messages_raw:
|
if not detailed_messages_raw:
|
||||||
@@ -658,7 +586,10 @@ async def build_readable_messages_with_list(
|
|||||||
允许通过参数控制格式化行为。
|
允许通过参数控制格式化行为。
|
||||||
"""
|
"""
|
||||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate
|
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
|
||||||
|
replace_bot_name,
|
||||||
|
timestamp_mode,
|
||||||
|
truncate,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||||
@@ -725,19 +656,7 @@ def build_readable_messages(
|
|||||||
if not messages:
|
if not messages:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
copy_messages: List[MessageAndActionModel] = [
|
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
|
||||||
MessageAndActionModel(
|
|
||||||
msg.time,
|
|
||||||
msg.user_info.user_id,
|
|
||||||
msg.user_info.platform,
|
|
||||||
msg.user_info.user_nickname,
|
|
||||||
msg.user_info.user_cardname,
|
|
||||||
msg.processed_plain_text,
|
|
||||||
msg.display_message,
|
|
||||||
msg.chat_info.platform,
|
|
||||||
)
|
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
if show_actions and copy_messages:
|
if show_actions and copy_messages:
|
||||||
# 获取所有消息的时间范围
|
# 获取所有消息的时间范围
|
||||||
@@ -942,7 +861,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return "?"
|
return "?"
|
||||||
|
|
||||||
content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False)
|
content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
|
||||||
|
|
||||||
header = f"{anon_name}说 "
|
header = f"{anon_name}说 "
|
||||||
output_lines.append(header)
|
output_lines.append(header)
|
||||||
@@ -996,22 +915,3 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
return list(person_ids_set) # 将集合转换为列表返回
|
return list(person_ids_set) # 将集合转换为列表返回
|
||||||
|
|
||||||
|
|
||||||
def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]:
|
|
||||||
"""
|
|
||||||
将 DatabaseMessages 列表转换为 MessageAndActionModel 列表。
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
MessageAndActionModel(
|
|
||||||
time=msg.time,
|
|
||||||
user_id=msg.user_info.user_id,
|
|
||||||
user_platform=msg.user_info.platform,
|
|
||||||
user_nickname=msg.user_info.user_nickname,
|
|
||||||
user_cardname=msg.user_info.user_cardname,
|
|
||||||
processed_plain_text=msg.processed_plain_text,
|
|
||||||
display_message=msg.display_message,
|
|
||||||
chat_info_platform=msg.chat_info.platform,
|
|
||||||
)
|
|
||||||
for msg in message
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from collections import Counter
|
|||||||
from typing import Optional, Tuple, Dict, List, Any
|
from typing import Optional, Tuple, Dict, List, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.info_data_model import TargetPersonInfo
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
@@ -641,6 +640,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
|||||||
platform: str = chat_stream.platform
|
platform: str = chat_stream.platform
|
||||||
user_id: str = user_info.user_id # type: ignore
|
user_id: str = user_info.user_id # type: ignore
|
||||||
|
|
||||||
|
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
|
||||||
|
|
||||||
# Initialize target_info with basic info
|
# Initialize target_info with basic info
|
||||||
target_info = TargetPersonInfo(
|
target_info = TargetPersonInfo(
|
||||||
platform=platform,
|
platform=platform,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Any
|
from typing import Optional, Any, Dict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from . import BaseDataModel
|
from . import BaseDataModel
|
||||||
@@ -157,3 +157,42 @@ class DatabaseMessages(BaseDataModel):
|
|||||||
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
|
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
|
||||||
# "interest_value must be a float or None"
|
# "interest_value must be a float or None"
|
||||||
# )
|
# )
|
||||||
|
def flatten(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
将消息数据模型转换为字典格式,便于存储或传输
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"message_id": self.message_id,
|
||||||
|
"time": self.time,
|
||||||
|
"chat_id": self.chat_id,
|
||||||
|
"reply_to": self.reply_to,
|
||||||
|
"interest_value": self.interest_value,
|
||||||
|
"key_words": self.key_words,
|
||||||
|
"key_words_lite": self.key_words_lite,
|
||||||
|
"is_mentioned": self.is_mentioned,
|
||||||
|
"processed_plain_text": self.processed_plain_text,
|
||||||
|
"display_message": self.display_message,
|
||||||
|
"priority_mode": self.priority_mode,
|
||||||
|
"priority_info": self.priority_info,
|
||||||
|
"additional_config": self.additional_config,
|
||||||
|
"is_emoji": self.is_emoji,
|
||||||
|
"is_picid": self.is_picid,
|
||||||
|
"is_command": self.is_command,
|
||||||
|
"is_notify": self.is_notify,
|
||||||
|
"selected_expressions": self.selected_expressions,
|
||||||
|
"user_id": self.user_info.user_id,
|
||||||
|
"user_nickname": self.user_info.user_nickname,
|
||||||
|
"user_cardname": self.user_info.user_cardname,
|
||||||
|
"user_platform": self.user_info.platform,
|
||||||
|
"chat_info_group_id": self.group_info.group_id if self.group_info else None,
|
||||||
|
"chat_info_group_name": self.group_info.group_name if self.group_info else None,
|
||||||
|
"chat_info_group_platform": self.group_info.group_platform if self.group_info else None,
|
||||||
|
"chat_info_stream_id": self.chat_info.stream_id,
|
||||||
|
"chat_info_platform": self.chat_info.platform,
|
||||||
|
"chat_info_create_time": self.chat_info.create_time,
|
||||||
|
"chat_info_last_active_time": self.chat_info.last_active_time,
|
||||||
|
"chat_info_user_platform": self.chat_info.user_info.platform,
|
||||||
|
"chat_info_user_id": self.chat_info.user_info.user_id,
|
||||||
|
"chat_info_user_nickname": self.chat_info.user_info.user_nickname,
|
||||||
|
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,12 +1,25 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, Dict, TYPE_CHECKING
|
||||||
|
|
||||||
from . import BaseDataModel
|
from . import BaseDataModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .database_data_model import DatabaseMessages
|
||||||
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TargetPersonInfo(BaseDataModel):
|
class TargetPersonInfo(BaseDataModel):
|
||||||
platform: str = field(default_factory=str)
|
platform: str = field(default_factory=str)
|
||||||
user_id: str = field(default_factory=str)
|
user_id: str = field(default_factory=str)
|
||||||
user_nickname: str = field(default_factory=str)
|
user_nickname: str = field(default_factory=str)
|
||||||
person_id: Optional[str] = None
|
person_id: Optional[str] = None
|
||||||
person_name: Optional[str] = None
|
person_name: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ActionPlannerInfo(BaseDataModel):
|
||||||
|
action_type: str = field(default_factory=str)
|
||||||
|
reasoning: Optional[str] = None
|
||||||
|
action_data: Optional[Dict] = None
|
||||||
|
action_message: Optional["DatabaseMessages"] = None
|
||||||
|
available_actions: Optional[Dict[str, "ActionInfo"]] = None
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from . import BaseDataModel
|
from . import BaseDataModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageAndActionModel(BaseDataModel):
|
class MessageAndActionModel(BaseDataModel):
|
||||||
|
chat_id: str = field(default_factory=str)
|
||||||
time: float = field(default_factory=float)
|
time: float = field(default_factory=float)
|
||||||
user_id: str = field(default_factory=str)
|
user_id: str = field(default_factory=str)
|
||||||
user_platform: str = field(default_factory=str)
|
user_platform: str = field(default_factory=str)
|
||||||
@@ -15,3 +20,17 @@ class MessageAndActionModel(BaseDataModel):
|
|||||||
chat_info_platform: str = field(default_factory=str)
|
chat_info_platform: str = field(default_factory=str)
|
||||||
is_action_record: bool = field(default=False)
|
is_action_record: bool = field(default=False)
|
||||||
action_name: Optional[str] = None
|
action_name: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||||
|
return cls(
|
||||||
|
chat_id=message.chat_id,
|
||||||
|
time=message.time,
|
||||||
|
user_id=message.user_info.user_id,
|
||||||
|
user_platform=message.user_info.platform,
|
||||||
|
user_nickname=message.user_info.user_nickname,
|
||||||
|
user_cardname=message.user_info.user_cardname,
|
||||||
|
processed_plain_text=message.processed_plain_text,
|
||||||
|
display_message=message.display_message,
|
||||||
|
chat_info_platform=message.chat_info.platform,
|
||||||
|
)
|
||||||
|
|||||||
@@ -47,10 +47,13 @@ logger = get_logger("Gemini客户端")
|
|||||||
# gemini_thinking参数(默认范围)
|
# gemini_thinking参数(默认范围)
|
||||||
# 不同模型的思考预算范围配置
|
# 不同模型的思考预算范围配置
|
||||||
THINKING_BUDGET_LIMITS = {
|
THINKING_BUDGET_LIMITS = {
|
||||||
"gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
|
"gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
|
||||||
"gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
|
"gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
|
||||||
"gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
|
"gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
|
||||||
}
|
}
|
||||||
|
# 思维预算特殊值
|
||||||
|
THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定
|
||||||
|
THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用)
|
||||||
|
|
||||||
gemini_safe_settings = [
|
gemini_safe_settings = [
|
||||||
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
|
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
|
||||||
@@ -91,9 +94,7 @@ def _convert_messages(
|
|||||||
for item in message.content:
|
for item in message.content:
|
||||||
if isinstance(item, tuple):
|
if isinstance(item, tuple):
|
||||||
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
|
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
|
||||||
content.append(
|
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
|
||||||
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")
|
|
||||||
)
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
content.append(Part.from_text(text=item))
|
content.append(Part.from_text(text=item))
|
||||||
else:
|
else:
|
||||||
@@ -336,47 +337,40 @@ class GeminiClient(BaseClient):
|
|||||||
api_key=api_provider.api_key,
|
api_key=api_provider.api_key,
|
||||||
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
) # 这里和openai不一样,gemini会自己决定自己是否需要retry
|
||||||
|
|
||||||
# 思维预算特殊值
|
|
||||||
THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定
|
|
||||||
THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def clamp_thinking_budget(tb: int, model_id: str):
|
def clamp_thinking_budget(tb: int, model_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
|
按模型限制思考预算范围,仅支持指定的模型(支持带数字后缀的新版本)
|
||||||
"""
|
"""
|
||||||
limits = None
|
limits = None
|
||||||
matched_key = None
|
|
||||||
|
|
||||||
# 优先尝试精确匹配
|
# 优先尝试精确匹配
|
||||||
if model_id in THINKING_BUDGET_LIMITS:
|
if model_id in THINKING_BUDGET_LIMITS:
|
||||||
limits = THINKING_BUDGET_LIMITS[model_id]
|
limits = THINKING_BUDGET_LIMITS[model_id]
|
||||||
matched_key = model_id
|
|
||||||
else:
|
else:
|
||||||
# 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先
|
# 按 key 长度倒序,保证更长的(更具体的,如 -lite)优先
|
||||||
sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True)
|
sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True)
|
||||||
for key in sorted_keys:
|
for key in sorted_keys:
|
||||||
# 必须满足:完全等于 或者 前缀匹配(带 "-" 边界)
|
# 必须满足:完全等于 或者 前缀匹配(带 "-" 边界)
|
||||||
if model_id == key or model_id.startswith(key + "-"):
|
if model_id == key or model_id.startswith(f"{key}-"):
|
||||||
limits = THINKING_BUDGET_LIMITS[key]
|
limits = THINKING_BUDGET_LIMITS[key]
|
||||||
matched_key = key
|
break
|
||||||
break
|
|
||||||
|
|
||||||
# 特殊值处理
|
# 特殊值处理
|
||||||
if tb == GeminiClient.THINKING_BUDGET_AUTO:
|
if tb == THINKING_BUDGET_AUTO:
|
||||||
return GeminiClient.THINKING_BUDGET_AUTO
|
return THINKING_BUDGET_AUTO
|
||||||
if tb == GeminiClient.THINKING_BUDGET_DISABLED:
|
if tb == THINKING_BUDGET_DISABLED:
|
||||||
if limits and limits.get("can_disable", False):
|
if limits and limits.get("can_disable", False):
|
||||||
return GeminiClient.THINKING_BUDGET_DISABLED
|
return THINKING_BUDGET_DISABLED
|
||||||
return limits["min"] if limits else GeminiClient.THINKING_BUDGET_AUTO
|
return limits["min"] if limits else THINKING_BUDGET_AUTO
|
||||||
|
|
||||||
# 已知模型裁剪到范围
|
# 已知模型裁剪到范围
|
||||||
if limits:
|
if limits:
|
||||||
return max(limits["min"], min(tb, limits["max"]))
|
return max(limits["min"], min(tb, limits["max"]))
|
||||||
|
|
||||||
# 未知模型,返回动态模式
|
# 未知模型,返回动态模式
|
||||||
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
|
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
|
||||||
return GeminiClient.THINKING_BUDGET_AUTO
|
return THINKING_BUDGET_AUTO
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
self,
|
self,
|
||||||
@@ -424,15 +418,13 @@ class GeminiClient(BaseClient):
|
|||||||
# 将tool_options转换为Gemini API所需的格式
|
# 将tool_options转换为Gemini API所需的格式
|
||||||
tools = _convert_tool_options(tool_options) if tool_options else None
|
tools = _convert_tool_options(tool_options) if tool_options else None
|
||||||
|
|
||||||
tb = GeminiClient.THINKING_BUDGET_AUTO
|
tb = THINKING_BUDGET_AUTO
|
||||||
#空处理
|
# 空处理
|
||||||
if extra_params and "thinking_budget" in extra_params:
|
if extra_params and "thinking_budget" in extra_params:
|
||||||
try:
|
try:
|
||||||
tb = int(extra_params["thinking_budget"])
|
tb = int(extra_params["thinking_budget"])
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
logger.warning(
|
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
|
||||||
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}"
|
|
||||||
)
|
|
||||||
# 裁剪到模型支持的范围
|
# 裁剪到模型支持的范围
|
||||||
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from src.common.logger import get_logger
|
|||||||
from src.individuality.individuality import get_individuality, Individuality
|
from src.individuality.individuality import get_individuality, Individuality
|
||||||
from src.common.server import get_global_server, Server
|
from src.common.server import get_global_server, Server
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
from src.chat.knowledge import lpmm_start_up
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.migrate_helper.migrate import check_and_run_migrations
|
from src.migrate_helper.migrate import check_and_run_migrations
|
||||||
# from src.api.main import start_api_server
|
# from src.api.main import start_api_server
|
||||||
@@ -83,6 +84,9 @@ class MainSystem:
|
|||||||
# 启动API服务器
|
# 启动API服务器
|
||||||
# start_api_server()
|
# start_api_server()
|
||||||
# logger.info("API服务器启动成功")
|
# logger.info("API服务器启动成功")
|
||||||
|
|
||||||
|
# 启动LPMM
|
||||||
|
lpmm_start_up()
|
||||||
|
|
||||||
# 加载所有actions,包括默认的和插件的
|
# 加载所有actions,包括默认的和插件的
|
||||||
plugin_manager.load_all_plugins()
|
plugin_manager.load_all_plugins()
|
||||||
@@ -104,7 +108,6 @@ class MainSystem:
|
|||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
|
|
||||||
# 初始化聊天管理器
|
# 初始化聊天管理器
|
||||||
|
|
||||||
await get_chat_manager()._initialize()
|
await get_chat_manager()._initialize()
|
||||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
import math
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
@@ -16,6 +17,7 @@ from src.config.config import global_config, model_config
|
|||||||
|
|
||||||
logger = get_logger("person_info")
|
logger = get_logger("person_info")
|
||||||
|
|
||||||
|
|
||||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||||
"""获取唯一id"""
|
"""获取唯一id"""
|
||||||
if "-" in platform:
|
if "-" in platform:
|
||||||
@@ -24,6 +26,7 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
|||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def get_person_id_by_person_name(person_name: str) -> str:
|
def get_person_id_by_person_name(person_name: str) -> str:
|
||||||
"""根据用户名获取用户ID"""
|
"""根据用户名获取用户ID"""
|
||||||
try:
|
try:
|
||||||
@@ -33,7 +36,8 @@ def get_person_id_by_person_name(person_name: str) -> str:
|
|||||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool:
|
|
||||||
|
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
|
||||||
if person_id:
|
if person_id:
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
@@ -47,89 +51,84 @@ def is_person_known(person_id: str = None,user_id: str = None,platform: str = No
|
|||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_catagory_from_memory(memory_point:str) -> str:
|
def get_category_from_memory(memory_point: str) -> Optional[str]:
|
||||||
"""从记忆点中获取分类"""
|
"""从记忆点中获取分类"""
|
||||||
# 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类
|
# 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类
|
||||||
if not isinstance(memory_point, str):
|
if not isinstance(memory_point, str):
|
||||||
return None
|
return None
|
||||||
parts = memory_point.split(":", 1)
|
parts = memory_point.split(":", 1)
|
||||||
if len(parts) > 1:
|
return parts[0].strip() if len(parts) > 1 else None
|
||||||
return parts[0].strip()
|
|
||||||
else:
|
|
||||||
return None
|
def get_weight_from_memory(memory_point: str) -> float:
|
||||||
|
|
||||||
def get_weight_from_memory(memory_point:str) -> float:
|
|
||||||
"""从记忆点中获取权重"""
|
"""从记忆点中获取权重"""
|
||||||
# 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重
|
# 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重
|
||||||
if not isinstance(memory_point, str):
|
if not isinstance(memory_point, str):
|
||||||
return None
|
return -math.inf
|
||||||
parts = memory_point.rsplit(":", 1)
|
parts = memory_point.rsplit(":", 1)
|
||||||
if len(parts) > 1:
|
if len(parts) <= 1:
|
||||||
try:
|
return -math.inf
|
||||||
return float(parts[-1].strip())
|
try:
|
||||||
except Exception:
|
return float(parts[-1].strip())
|
||||||
return None
|
except Exception:
|
||||||
else:
|
return -math.inf
|
||||||
return None
|
|
||||||
|
|
||||||
def get_memory_content_from_memory(memory_point:str) -> str:
|
def get_memory_content_from_memory(memory_point: str) -> str:
|
||||||
"""从记忆点中获取记忆内容"""
|
"""从记忆点中获取记忆内容"""
|
||||||
# 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容
|
# 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容
|
||||||
if not isinstance(memory_point, str):
|
if not isinstance(memory_point, str):
|
||||||
return None
|
return ""
|
||||||
parts = memory_point.split(":")
|
parts = memory_point.split(":")
|
||||||
if len(parts) > 2:
|
return ":".join(parts[1:-1]).strip() if len(parts) > 2 else ""
|
||||||
return ":".join(parts[1:-1]).strip()
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_string_similarity(s1: str, s2: str) -> float:
|
def calculate_string_similarity(s1: str, s2: str) -> float:
|
||||||
"""
|
"""
|
||||||
计算两个字符串的相似度
|
计算两个字符串的相似度
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
s1: 第一个字符串
|
s1: 第一个字符串
|
||||||
s2: 第二个字符串
|
s2: 第二个字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
float: 相似度,范围0-1,1表示完全相同
|
float: 相似度,范围0-1,1表示完全相同
|
||||||
"""
|
"""
|
||||||
if s1 == s2:
|
if s1 == s2:
|
||||||
return 1.0
|
return 1.0
|
||||||
|
|
||||||
if not s1 or not s2:
|
if not s1 or not s2:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# 计算Levenshtein距离
|
# 计算Levenshtein距离
|
||||||
|
|
||||||
|
|
||||||
distance = levenshtein_distance(s1, s2)
|
distance = levenshtein_distance(s1, s2)
|
||||||
max_len = max(len(s1), len(s2))
|
max_len = max(len(s1), len(s2))
|
||||||
|
|
||||||
# 计算相似度:1 - (编辑距离 / 最大长度)
|
# 计算相似度:1 - (编辑距离 / 最大长度)
|
||||||
similarity = 1 - (distance / max_len if max_len > 0 else 0)
|
similarity = 1 - (distance / max_len if max_len > 0 else 0)
|
||||||
return similarity
|
return similarity
|
||||||
|
|
||||||
|
|
||||||
def levenshtein_distance(s1: str, s2: str) -> int:
|
def levenshtein_distance(s1: str, s2: str) -> int:
|
||||||
"""
|
"""
|
||||||
计算两个字符串的编辑距离
|
计算两个字符串的编辑距离
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
s1: 第一个字符串
|
s1: 第一个字符串
|
||||||
s2: 第二个字符串
|
s2: 第二个字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: 编辑距离
|
int: 编辑距离
|
||||||
"""
|
"""
|
||||||
if len(s1) < len(s2):
|
if len(s1) < len(s2):
|
||||||
return levenshtein_distance(s2, s1)
|
return levenshtein_distance(s2, s1)
|
||||||
|
|
||||||
if len(s2) == 0:
|
if len(s2) == 0:
|
||||||
return len(s1)
|
return len(s1)
|
||||||
|
|
||||||
previous_row = range(len(s2) + 1)
|
previous_row = range(len(s2) + 1)
|
||||||
for i, c1 in enumerate(s1):
|
for i, c1 in enumerate(s1):
|
||||||
current_row = [i + 1]
|
current_row = [i + 1]
|
||||||
@@ -139,44 +138,45 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
|||||||
substitutions = previous_row[j] + (c1 != c2)
|
substitutions = previous_row[j] + (c1 != c2)
|
||||||
current_row.append(min(insertions, deletions, substitutions))
|
current_row.append(min(insertions, deletions, substitutions))
|
||||||
previous_row = current_row
|
previous_row = current_row
|
||||||
|
|
||||||
return previous_row[-1]
|
return previous_row[-1]
|
||||||
|
|
||||||
|
|
||||||
class Person:
|
class Person:
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_person(cls, platform: str, user_id: str, nickname: str):
|
def register_person(cls, platform: str, user_id: str, nickname: str):
|
||||||
"""
|
"""
|
||||||
注册新用户的类方法
|
注册新用户的类方法
|
||||||
必须输入 platform、user_id 和 nickname 参数
|
必须输入 platform、user_id 和 nickname 参数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
platform: 平台名称
|
platform: 平台名称
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
nickname: 用户昵称
|
nickname: 用户昵称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Person: 新注册的Person实例
|
Person: 新注册的Person实例
|
||||||
"""
|
"""
|
||||||
if not platform or not user_id or not nickname:
|
if not platform or not user_id or not nickname:
|
||||||
logger.error("注册用户失败:platform、user_id 和 nickname 都是必需参数")
|
logger.error("注册用户失败:platform、user_id 和 nickname 都是必需参数")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 生成唯一的person_id
|
# 生成唯一的person_id
|
||||||
person_id = get_person_id(platform, user_id)
|
person_id = get_person_id(platform, user_id)
|
||||||
|
|
||||||
if is_person_known(person_id=person_id):
|
if is_person_known(person_id=person_id):
|
||||||
logger.debug(f"用户 {nickname} 已存在")
|
logger.debug(f"用户 {nickname} 已存在")
|
||||||
return Person(person_id=person_id)
|
return Person(person_id=person_id)
|
||||||
|
|
||||||
# 创建Person实例
|
# 创建Person实例
|
||||||
person = cls.__new__(cls)
|
person = cls.__new__(cls)
|
||||||
|
|
||||||
# 设置基本属性
|
# 设置基本属性
|
||||||
person.person_id = person_id
|
person.person_id = person_id
|
||||||
person.platform = platform
|
person.platform = platform
|
||||||
person.user_id = user_id
|
person.user_id = user_id
|
||||||
person.nickname = nickname
|
person.nickname = nickname
|
||||||
|
|
||||||
# 初始化默认值
|
# 初始化默认值
|
||||||
person.is_known = True # 注册后立即标记为已认识
|
person.is_known = True # 注册后立即标记为已认识
|
||||||
person.person_name = nickname # 使用nickname作为初始person_name
|
person.person_name = nickname # 使用nickname作为初始person_name
|
||||||
@@ -185,34 +185,34 @@ class Person:
|
|||||||
person.know_since = time.time()
|
person.know_since = time.time()
|
||||||
person.last_know = time.time()
|
person.last_know = time.time()
|
||||||
person.memory_points = []
|
person.memory_points = []
|
||||||
|
|
||||||
# 初始化性格特征相关字段
|
# 初始化性格特征相关字段
|
||||||
person.attitude_to_me = 0
|
person.attitude_to_me = 0
|
||||||
person.attitude_to_me_confidence = 1
|
person.attitude_to_me_confidence = 1
|
||||||
|
|
||||||
person.neuroticism = 5
|
person.neuroticism = 5
|
||||||
person.neuroticism_confidence = 1
|
person.neuroticism_confidence = 1
|
||||||
|
|
||||||
person.friendly_value = 50
|
person.friendly_value = 50
|
||||||
person.friendly_value_confidence = 1
|
person.friendly_value_confidence = 1
|
||||||
|
|
||||||
person.rudeness = 50
|
person.rudeness = 50
|
||||||
person.rudeness_confidence = 1
|
person.rudeness_confidence = 1
|
||||||
|
|
||||||
person.conscientiousness = 50
|
person.conscientiousness = 50
|
||||||
person.conscientiousness_confidence = 1
|
person.conscientiousness_confidence = 1
|
||||||
|
|
||||||
person.likeness = 50
|
person.likeness = 50
|
||||||
person.likeness_confidence = 1
|
person.likeness_confidence = 1
|
||||||
|
|
||||||
# 同步到数据库
|
# 同步到数据库
|
||||||
person.sync_to_database()
|
person.sync_to_database()
|
||||||
|
|
||||||
logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}")
|
logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}")
|
||||||
|
|
||||||
return person
|
return person
|
||||||
|
|
||||||
def __init__(self, platform: str = "", user_id: str = "",person_id: str = "",person_name: str = ""):
|
def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
|
||||||
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
|
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
|
||||||
self.is_known = True
|
self.is_known = True
|
||||||
self.person_id = get_person_id(platform, user_id)
|
self.person_id = get_person_id(platform, user_id)
|
||||||
@@ -221,10 +221,10 @@ class Person:
|
|||||||
self.nickname = global_config.bot.nickname
|
self.nickname = global_config.bot.nickname
|
||||||
self.person_name = global_config.bot.nickname
|
self.person_name = global_config.bot.nickname
|
||||||
return
|
return
|
||||||
|
|
||||||
self.user_id = ""
|
self.user_id = ""
|
||||||
self.platform = ""
|
self.platform = ""
|
||||||
|
|
||||||
if person_id:
|
if person_id:
|
||||||
self.person_id = person_id
|
self.person_id = person_id
|
||||||
elif person_name:
|
elif person_name:
|
||||||
@@ -232,7 +232,7 @@ class Person:
|
|||||||
if not self.person_id:
|
if not self.person_id:
|
||||||
self.is_known = False
|
self.is_known = False
|
||||||
logger.warning(f"根据用户名 {person_name} 获取用户ID时,不存在用户{person_name}")
|
logger.warning(f"根据用户名 {person_name} 获取用户ID时,不存在用户{person_name}")
|
||||||
return
|
return
|
||||||
elif platform and user_id:
|
elif platform and user_id:
|
||||||
self.person_id = get_person_id(platform, user_id)
|
self.person_id = get_person_id(platform, user_id)
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
@@ -240,17 +240,16 @@ class Person:
|
|||||||
else:
|
else:
|
||||||
logger.error("Person 初始化失败,缺少必要参数")
|
logger.error("Person 初始化失败,缺少必要参数")
|
||||||
raise ValueError("Person 初始化失败,缺少必要参数")
|
raise ValueError("Person 初始化失败,缺少必要参数")
|
||||||
|
|
||||||
if not is_person_known(person_id=self.person_id):
|
if not is_person_known(person_id=self.person_id):
|
||||||
self.is_known = False
|
self.is_known = False
|
||||||
logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
|
logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
|
||||||
self.person_name = f"未知用户{self.person_id[:4]}"
|
self.person_name = f"未知用户{self.person_id[:4]}"
|
||||||
return
|
return
|
||||||
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
|
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
|
||||||
|
|
||||||
|
|
||||||
self.is_known = False
|
self.is_known = False
|
||||||
|
|
||||||
# 初始化默认值
|
# 初始化默认值
|
||||||
self.nickname = ""
|
self.nickname = ""
|
||||||
self.person_name: Optional[str] = None
|
self.person_name: Optional[str] = None
|
||||||
@@ -259,47 +258,47 @@ class Person:
|
|||||||
self.know_since = None
|
self.know_since = None
|
||||||
self.last_know = None
|
self.last_know = None
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
|
|
||||||
# 初始化性格特征相关字段
|
# 初始化性格特征相关字段
|
||||||
self.attitude_to_me:float = 0
|
self.attitude_to_me: float = 0
|
||||||
self.attitude_to_me_confidence:float = 1
|
self.attitude_to_me_confidence: float = 1
|
||||||
|
|
||||||
self.neuroticism:float = 5
|
self.neuroticism: float = 5
|
||||||
self.neuroticism_confidence:float = 1
|
self.neuroticism_confidence: float = 1
|
||||||
|
|
||||||
self.friendly_value:float = 50
|
self.friendly_value: float = 50
|
||||||
self.friendly_value_confidence:float = 1
|
self.friendly_value_confidence: float = 1
|
||||||
|
|
||||||
self.rudeness:float = 50
|
self.rudeness: float = 50
|
||||||
self.rudeness_confidence:float = 1
|
self.rudeness_confidence: float = 1
|
||||||
|
|
||||||
self.conscientiousness:float = 50
|
self.conscientiousness: float = 50
|
||||||
self.conscientiousness_confidence:float = 1
|
self.conscientiousness_confidence: float = 1
|
||||||
|
|
||||||
self.likeness:float = 50
|
self.likeness: float = 50
|
||||||
self.likeness_confidence:float = 1
|
self.likeness_confidence: float = 1
|
||||||
|
|
||||||
# 从数据库加载数据
|
# 从数据库加载数据
|
||||||
self.load_from_database()
|
self.load_from_database()
|
||||||
|
|
||||||
def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95):
|
def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95):
|
||||||
"""
|
"""
|
||||||
删除指定分类和记忆内容的记忆点
|
删除指定分类和记忆内容的记忆点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
category: 记忆分类
|
category: 记忆分类
|
||||||
memory_content: 要删除的记忆内容
|
memory_content: 要删除的记忆内容
|
||||||
similarity_threshold: 相似度阈值,默认0.95(95%)
|
similarity_threshold: 相似度阈值,默认0.95(95%)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: 删除的记忆点数量
|
int: 删除的记忆点数量
|
||||||
"""
|
"""
|
||||||
if not self.memory_points:
|
if not self.memory_points:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
deleted_count = 0
|
deleted_count = 0
|
||||||
memory_points_to_keep = []
|
memory_points_to_keep = []
|
||||||
|
|
||||||
for memory_point in self.memory_points:
|
for memory_point in self.memory_points:
|
||||||
# 跳过None值
|
# 跳过None值
|
||||||
if memory_point is None:
|
if memory_point is None:
|
||||||
@@ -310,80 +309,76 @@ class Person:
|
|||||||
# 格式不正确,保留原样
|
# 格式不正确,保留原样
|
||||||
memory_points_to_keep.append(memory_point)
|
memory_points_to_keep.append(memory_point)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
memory_category = parts[0].strip()
|
memory_category = parts[0].strip()
|
||||||
memory_text = parts[1].strip()
|
memory_text = parts[1].strip()
|
||||||
memory_weight = parts[2].strip()
|
memory_weight = parts[2].strip()
|
||||||
|
|
||||||
# 检查分类是否匹配
|
# 检查分类是否匹配
|
||||||
if memory_category != category:
|
if memory_category != category:
|
||||||
memory_points_to_keep.append(memory_point)
|
memory_points_to_keep.append(memory_point)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算记忆内容的相似度
|
# 计算记忆内容的相似度
|
||||||
similarity = calculate_string_similarity(memory_content, memory_text)
|
similarity = calculate_string_similarity(memory_content, memory_text)
|
||||||
|
|
||||||
# 如果相似度达到阈值,则删除(不添加到保留列表)
|
# 如果相似度达到阈值,则删除(不添加到保留列表)
|
||||||
if similarity >= similarity_threshold:
|
if similarity >= similarity_threshold:
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})")
|
logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})")
|
||||||
else:
|
else:
|
||||||
memory_points_to_keep.append(memory_point)
|
memory_points_to_keep.append(memory_point)
|
||||||
|
|
||||||
# 更新memory_points
|
# 更新memory_points
|
||||||
self.memory_points = memory_points_to_keep
|
self.memory_points = memory_points_to_keep
|
||||||
|
|
||||||
# 同步到数据库
|
# 同步到数据库
|
||||||
if deleted_count > 0:
|
if deleted_count > 0:
|
||||||
self.sync_to_database()
|
self.sync_to_database()
|
||||||
logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}")
|
logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}")
|
||||||
|
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_all_category(self):
|
def get_all_category(self):
|
||||||
category_list = []
|
category_list = []
|
||||||
for memory in self.memory_points:
|
for memory in self.memory_points:
|
||||||
if memory is None:
|
if memory is None:
|
||||||
continue
|
continue
|
||||||
category = get_catagory_from_memory(memory)
|
category = get_category_from_memory(memory)
|
||||||
if category and category not in category_list:
|
if category and category not in category_list:
|
||||||
category_list.append(category)
|
category_list.append(category)
|
||||||
return category_list
|
return category_list
|
||||||
|
|
||||||
|
def get_memory_list_by_category(self, category: str):
|
||||||
def get_memory_list_by_category(self,category:str):
|
|
||||||
memory_list = []
|
memory_list = []
|
||||||
for memory in self.memory_points:
|
for memory in self.memory_points:
|
||||||
if memory is None:
|
if memory is None:
|
||||||
continue
|
continue
|
||||||
if get_catagory_from_memory(memory) == category:
|
if get_category_from_memory(memory) == category:
|
||||||
memory_list.append(memory)
|
memory_list.append(memory)
|
||||||
return memory_list
|
return memory_list
|
||||||
|
|
||||||
def get_random_memory_by_category(self,category:str,num:int=1):
|
def get_random_memory_by_category(self, category: str, num: int = 1):
|
||||||
memory_list = self.get_memory_list_by_category(category)
|
memory_list = self.get_memory_list_by_category(category)
|
||||||
if len(memory_list) < num:
|
if len(memory_list) < num:
|
||||||
return memory_list
|
return memory_list
|
||||||
return random.sample(memory_list, num)
|
return random.sample(memory_list, num)
|
||||||
|
|
||||||
def load_from_database(self):
|
def load_from_database(self):
|
||||||
"""从数据库加载个人信息数据"""
|
"""从数据库加载个人信息数据"""
|
||||||
try:
|
try:
|
||||||
# 查询数据库中的记录
|
# 查询数据库中的记录
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
self.user_id = record.user_id if record.user_id else ""
|
self.user_id = record.user_id or ""
|
||||||
self.platform = record.platform if record.platform else ""
|
self.platform = record.platform or ""
|
||||||
self.is_known = record.is_known if record.is_known else False
|
self.is_known = record.is_known or False
|
||||||
self.nickname = record.nickname if record.nickname else ""
|
self.nickname = record.nickname or ""
|
||||||
self.person_name = record.person_name if record.person_name else self.nickname
|
self.person_name = record.person_name or self.nickname
|
||||||
self.name_reason = record.name_reason if record.name_reason else None
|
self.name_reason = record.name_reason or None
|
||||||
self.know_times = record.know_times if record.know_times else 0
|
self.know_times = record.know_times or 0
|
||||||
|
|
||||||
# 处理points字段(JSON格式的列表)
|
# 处理points字段(JSON格式的列表)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
try:
|
try:
|
||||||
@@ -398,53 +393,53 @@ class Person:
|
|||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
else:
|
else:
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
|
|
||||||
# 加载性格特征相关字段
|
# 加载性格特征相关字段
|
||||||
if record.attitude_to_me and not isinstance(record.attitude_to_me, str):
|
if record.attitude_to_me and not isinstance(record.attitude_to_me, str):
|
||||||
self.attitude_to_me = record.attitude_to_me
|
self.attitude_to_me = record.attitude_to_me
|
||||||
|
|
||||||
if record.attitude_to_me_confidence is not None:
|
if record.attitude_to_me_confidence is not None:
|
||||||
self.attitude_to_me_confidence = float(record.attitude_to_me_confidence)
|
self.attitude_to_me_confidence = float(record.attitude_to_me_confidence)
|
||||||
|
|
||||||
if record.friendly_value is not None:
|
if record.friendly_value is not None:
|
||||||
self.friendly_value = float(record.friendly_value)
|
self.friendly_value = float(record.friendly_value)
|
||||||
|
|
||||||
if record.friendly_value_confidence is not None:
|
if record.friendly_value_confidence is not None:
|
||||||
self.friendly_value_confidence = float(record.friendly_value_confidence)
|
self.friendly_value_confidence = float(record.friendly_value_confidence)
|
||||||
|
|
||||||
if record.rudeness is not None:
|
if record.rudeness is not None:
|
||||||
self.rudeness = float(record.rudeness)
|
self.rudeness = float(record.rudeness)
|
||||||
|
|
||||||
if record.rudeness_confidence is not None:
|
if record.rudeness_confidence is not None:
|
||||||
self.rudeness_confidence = float(record.rudeness_confidence)
|
self.rudeness_confidence = float(record.rudeness_confidence)
|
||||||
|
|
||||||
if record.neuroticism and not isinstance(record.neuroticism, str):
|
if record.neuroticism and not isinstance(record.neuroticism, str):
|
||||||
self.neuroticism = float(record.neuroticism)
|
self.neuroticism = float(record.neuroticism)
|
||||||
|
|
||||||
if record.neuroticism_confidence is not None:
|
if record.neuroticism_confidence is not None:
|
||||||
self.neuroticism_confidence = float(record.neuroticism_confidence)
|
self.neuroticism_confidence = float(record.neuroticism_confidence)
|
||||||
|
|
||||||
if record.conscientiousness is not None:
|
if record.conscientiousness is not None:
|
||||||
self.conscientiousness = float(record.conscientiousness)
|
self.conscientiousness = float(record.conscientiousness)
|
||||||
|
|
||||||
if record.conscientiousness_confidence is not None:
|
if record.conscientiousness_confidence is not None:
|
||||||
self.conscientiousness_confidence = float(record.conscientiousness_confidence)
|
self.conscientiousness_confidence = float(record.conscientiousness_confidence)
|
||||||
|
|
||||||
if record.likeness is not None:
|
if record.likeness is not None:
|
||||||
self.likeness = float(record.likeness)
|
self.likeness = float(record.likeness)
|
||||||
|
|
||||||
if record.likeness_confidence is not None:
|
if record.likeness_confidence is not None:
|
||||||
self.likeness_confidence = float(record.likeness_confidence)
|
self.likeness_confidence = float(record.likeness_confidence)
|
||||||
|
|
||||||
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
|
||||||
else:
|
else:
|
||||||
self.sync_to_database()
|
self.sync_to_database()
|
||||||
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
|
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
|
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
|
||||||
# 出错时保持默认值
|
# 出错时保持默认值
|
||||||
|
|
||||||
def sync_to_database(self):
|
def sync_to_database(self):
|
||||||
"""将所有属性同步回数据库"""
|
"""将所有属性同步回数据库"""
|
||||||
if not self.is_known:
|
if not self.is_known:
|
||||||
@@ -452,34 +447,38 @@ class Person:
|
|||||||
try:
|
try:
|
||||||
# 准备数据
|
# 准备数据
|
||||||
data = {
|
data = {
|
||||||
'person_id': self.person_id,
|
"person_id": self.person_id,
|
||||||
'is_known': self.is_known,
|
"is_known": self.is_known,
|
||||||
'platform': self.platform,
|
"platform": self.platform,
|
||||||
'user_id': self.user_id,
|
"user_id": self.user_id,
|
||||||
'nickname': self.nickname,
|
"nickname": self.nickname,
|
||||||
'person_name': self.person_name,
|
"person_name": self.person_name,
|
||||||
'name_reason': self.name_reason,
|
"name_reason": self.name_reason,
|
||||||
'know_times': self.know_times,
|
"know_times": self.know_times,
|
||||||
'know_since': self.know_since,
|
"know_since": self.know_since,
|
||||||
'last_know': self.last_know,
|
"last_know": self.last_know,
|
||||||
'memory_points': json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points else json.dumps([], ensure_ascii=False),
|
"memory_points": json.dumps(
|
||||||
'attitude_to_me': self.attitude_to_me,
|
[point for point in self.memory_points if point is not None], ensure_ascii=False
|
||||||
'attitude_to_me_confidence': self.attitude_to_me_confidence,
|
)
|
||||||
'friendly_value': self.friendly_value,
|
if self.memory_points
|
||||||
'friendly_value_confidence': self.friendly_value_confidence,
|
else json.dumps([], ensure_ascii=False),
|
||||||
'rudeness': self.rudeness,
|
"attitude_to_me": self.attitude_to_me,
|
||||||
'rudeness_confidence': self.rudeness_confidence,
|
"attitude_to_me_confidence": self.attitude_to_me_confidence,
|
||||||
'neuroticism': self.neuroticism,
|
"friendly_value": self.friendly_value,
|
||||||
'neuroticism_confidence': self.neuroticism_confidence,
|
"friendly_value_confidence": self.friendly_value_confidence,
|
||||||
'conscientiousness': self.conscientiousness,
|
"rudeness": self.rudeness,
|
||||||
'conscientiousness_confidence': self.conscientiousness_confidence,
|
"rudeness_confidence": self.rudeness_confidence,
|
||||||
'likeness': self.likeness,
|
"neuroticism": self.neuroticism,
|
||||||
'likeness_confidence': self.likeness_confidence,
|
"neuroticism_confidence": self.neuroticism_confidence,
|
||||||
|
"conscientiousness": self.conscientiousness,
|
||||||
|
"conscientiousness_confidence": self.conscientiousness_confidence,
|
||||||
|
"likeness": self.likeness,
|
||||||
|
"likeness_confidence": self.likeness_confidence,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查记录是否存在
|
# 检查记录是否存在
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
# 更新现有记录
|
# 更新现有记录
|
||||||
for field, value in data.items():
|
for field, value in data.items():
|
||||||
@@ -491,10 +490,10 @@ class Person:
|
|||||||
# 创建新记录
|
# 创建新记录
|
||||||
PersonInfo.create(**data)
|
PersonInfo.create(**data)
|
||||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
|
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
|
||||||
|
|
||||||
def build_relationship(self):
|
def build_relationship(self):
|
||||||
if not self.is_known:
|
if not self.is_known:
|
||||||
return ""
|
return ""
|
||||||
@@ -505,22 +504,21 @@ class Person:
|
|||||||
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
|
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
|
||||||
|
|
||||||
relation_info = ""
|
relation_info = ""
|
||||||
|
|
||||||
attitude_info = ""
|
attitude_info = ""
|
||||||
if self.attitude_to_me:
|
if self.attitude_to_me:
|
||||||
if self.attitude_to_me > 8:
|
if self.attitude_to_me > 8:
|
||||||
attitude_info = f"{self.person_name}对你的态度十分好,"
|
attitude_info = f"{self.person_name}对你的态度十分好,"
|
||||||
elif self.attitude_to_me > 5:
|
elif self.attitude_to_me > 5:
|
||||||
attitude_info = f"{self.person_name}对你的态度较好,"
|
attitude_info = f"{self.person_name}对你的态度较好,"
|
||||||
|
|
||||||
|
|
||||||
if self.attitude_to_me < -8:
|
if self.attitude_to_me < -8:
|
||||||
attitude_info = f"{self.person_name}对你的态度十分恶劣,"
|
attitude_info = f"{self.person_name}对你的态度十分恶劣,"
|
||||||
elif self.attitude_to_me < -4:
|
elif self.attitude_to_me < -4:
|
||||||
attitude_info = f"{self.person_name}对你的态度不好,"
|
attitude_info = f"{self.person_name}对你的态度不好,"
|
||||||
elif self.attitude_to_me < 0:
|
elif self.attitude_to_me < 0:
|
||||||
attitude_info = f"{self.person_name}对你的态度一般,"
|
attitude_info = f"{self.person_name}对你的态度一般,"
|
||||||
|
|
||||||
neuroticism_info = ""
|
neuroticism_info = ""
|
||||||
if self.neuroticism:
|
if self.neuroticism:
|
||||||
if self.neuroticism > 8:
|
if self.neuroticism > 8:
|
||||||
@@ -533,29 +531,28 @@ class Person:
|
|||||||
neuroticism_info = f"{self.person_name}的情绪比较稳定,"
|
neuroticism_info = f"{self.person_name}的情绪比较稳定,"
|
||||||
else:
|
else:
|
||||||
neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动"
|
neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动"
|
||||||
|
|
||||||
points_text = ""
|
points_text = ""
|
||||||
category_list = self.get_all_category()
|
category_list = self.get_all_category()
|
||||||
for category in category_list:
|
for category in category_list:
|
||||||
random_memory = self.get_random_memory_by_category(category,1)[0]
|
random_memory = self.get_random_memory_by_category(category, 1)[0]
|
||||||
if random_memory:
|
if random_memory:
|
||||||
points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}"
|
points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}"
|
||||||
break
|
break
|
||||||
|
|
||||||
points_info = ""
|
points_info = ""
|
||||||
if points_text:
|
if points_text:
|
||||||
points_info = f"你还记得有关{self.person_name}的最近记忆:{points_text}"
|
points_info = f"你还记得有关{self.person_name}的最近记忆:{points_text}"
|
||||||
|
|
||||||
if not (nickname_str or attitude_info or neuroticism_info or points_info):
|
if not (nickname_str or attitude_info or neuroticism_info or points_info):
|
||||||
return ""
|
return ""
|
||||||
relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}"
|
relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}"
|
||||||
|
|
||||||
return relation_info
|
return relation_info
|
||||||
|
|
||||||
|
|
||||||
class PersonInfoManager:
|
class PersonInfoManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
||||||
self.person_name_list = {}
|
self.person_name_list = {}
|
||||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||||
try:
|
try:
|
||||||
@@ -580,8 +577,6 @@ class PersonInfoManager:
|
|||||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_json_from_text(text: str) -> dict:
|
def _extract_json_from_text(text: str) -> dict:
|
||||||
@@ -717,6 +712,6 @@ class PersonInfoManager:
|
|||||||
person.sync_to_database()
|
person.sync_to_database()
|
||||||
self.person_name_list[person_id] = unique_nickname
|
self.person_name_list[person_id] = unique_nickname
|
||||||
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
|
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
|
||||||
|
|
||||||
|
|
||||||
person_info_manager = PersonInfoManager()
|
person_info_manager = PersonInfoManager()
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import traceback
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
from typing import List, Dict, Any
|
import asyncio
|
||||||
|
from typing import List, Dict, Any, TYPE_CHECKING
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
@@ -15,7 +16,9 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
num_new_messages_since,
|
num_new_messages_since,
|
||||||
)
|
)
|
||||||
import asyncio
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("relationship_builder")
|
logger = get_logger("relationship_builder")
|
||||||
|
|
||||||
@@ -429,7 +432,7 @@ class RelationshipBuilder:
|
|||||||
if dropped_count > 0:
|
if dropped_count > 0:
|
||||||
logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
|
logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
|
||||||
|
|
||||||
processed_messages = []
|
processed_messages: List["DatabaseMessages"] = []
|
||||||
|
|
||||||
# 对筛选后的消息段进行排序,确保时间顺序
|
# 对筛选后的消息段进行排序,确保时间顺序
|
||||||
segments_to_process.sort(key=lambda x: x["start_time"])
|
segments_to_process.sort(key=lambda x: x["start_time"])
|
||||||
@@ -449,17 +452,18 @@ class RelationshipBuilder:
|
|||||||
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
|
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
|
||||||
if processed_messages:
|
if processed_messages:
|
||||||
# 创建一个特殊的间隔消息
|
# 创建一个特殊的间隔消息
|
||||||
gap_message = {
|
gap_message = DatabaseMessages(
|
||||||
"time": start_time - 0.1, # 稍微早于段开始时间
|
time=start_time - 0.1,
|
||||||
"user_id": "system",
|
user_id="system",
|
||||||
"user_platform": "system",
|
user_platform="system",
|
||||||
"user_nickname": "系统",
|
user_nickname="系统",
|
||||||
"user_cardname": "",
|
user_cardname="",
|
||||||
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
display_message=f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
||||||
"is_action_record": True,
|
is_action_record=True,
|
||||||
"chat_info_platform": segment_messages[0].chat_info.platform or "",
|
chat_info_platform=segment_messages[0].chat_info.platform or "",
|
||||||
"chat_id": chat_id,
|
chat_id=chat_id,
|
||||||
}
|
)
|
||||||
|
|
||||||
processed_messages.append(gap_message)
|
processed_messages.append(gap_message)
|
||||||
|
|
||||||
# 添加该段的所有消息
|
# 添加该段的所有消息
|
||||||
@@ -467,11 +471,11 @@ class RelationshipBuilder:
|
|||||||
|
|
||||||
if processed_messages:
|
if processed_messages:
|
||||||
# 按时间排序所有消息(包括间隔标识)
|
# 按时间排序所有消息(包括间隔标识)
|
||||||
processed_messages.sort(key=lambda x: x["time"])
|
processed_messages.sort(key=lambda x: x.time)
|
||||||
|
|
||||||
logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新")
|
logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新")
|
||||||
relationship_manager = get_relationship_manager()
|
relationship_manager = get_relationship_manager()
|
||||||
|
|
||||||
build_frequency = 0.3 * global_config.relationship.relation_frequency
|
build_frequency = 0.3 * global_config.relationship.relation_frequency
|
||||||
if random.random() < build_frequency:
|
if random.random() < build_frequency:
|
||||||
# 调用原有的更新方法
|
# 调用原有的更新方法
|
||||||
|
|||||||
@@ -3,16 +3,18 @@ import traceback
|
|||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List, TYPE_CHECKING
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from .person_info import Person
|
from .person_info import Person
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("relation")
|
logger = get_logger("relation")
|
||||||
|
|
||||||
|
|
||||||
@@ -177,7 +179,7 @@ class RelationshipManager:
|
|||||||
|
|
||||||
return person
|
return person
|
||||||
|
|
||||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
|
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List["DatabaseMessages"]):
|
||||||
"""更新用户印象
|
"""更新用户印象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -192,8 +194,6 @@ class RelationshipManager:
|
|||||||
# nickname = person.nickname
|
# nickname = person.nickname
|
||||||
know_times: float = person.know_times
|
know_times: float = person.know_times
|
||||||
|
|
||||||
user_messages = bot_engaged_messages
|
|
||||||
|
|
||||||
# 匿名化消息
|
# 匿名化消息
|
||||||
# 创建用户名称映射
|
# 创建用户名称映射
|
||||||
name_mapping = {}
|
name_mapping = {}
|
||||||
@@ -201,13 +201,14 @@ class RelationshipManager:
|
|||||||
user_count = 1
|
user_count = 1
|
||||||
|
|
||||||
# 遍历消息,构建映射
|
# 遍历消息,构建映射
|
||||||
for msg in user_messages:
|
for msg in bot_engaged_messages:
|
||||||
if msg.user_info.user_id == "system":
|
if msg.user_info.user_id == "system":
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
user_id = msg.user_info.user_id
|
user_id = msg.user_info.user_id
|
||||||
platform = msg.chat_info.platform
|
platform = msg.chat_info.platform
|
||||||
assert isinstance(user_id, str) and isinstance(platform, str)
|
assert user_id, "用户ID不能为空"
|
||||||
|
assert platform, "平台不能为空"
|
||||||
msg_person = Person(user_id=user_id, platform=platform)
|
msg_person = Person(user_id=user_id, platform=platform)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -233,7 +234,7 @@ class RelationshipManager:
|
|||||||
current_user = chr(ord(current_user) + 1)
|
current_user = chr(ord(current_user) + 1)
|
||||||
|
|
||||||
readable_messages = build_readable_messages(
|
readable_messages = build_readable_messages(
|
||||||
messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
|
messages=bot_engaged_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for original_name, mapped_name in name_mapping.items():
|
for original_name, mapped_name in name_mapping.items():
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Tuple, Any, Dict, List, Optional
|
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.replyer.default_generator import DefaultReplyer
|
from src.chat.replyer.default_generator import DefaultReplyer
|
||||||
@@ -18,6 +18,10 @@ from src.chat.utils.utils import process_llm_response
|
|||||||
from src.chat.replyer.replyer_manager import replyer_manager
|
from src.chat.replyer.replyer_manager import replyer_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("generator_api")
|
logger = get_logger("generator_api")
|
||||||
@@ -73,11 +77,11 @@ async def generate_reply(
|
|||||||
chat_stream: Optional[ChatStream] = None,
|
chat_stream: Optional[ChatStream] = None,
|
||||||
chat_id: Optional[str] = None,
|
chat_id: Optional[str] = None,
|
||||||
action_data: Optional[Dict[str, Any]] = None,
|
action_data: Optional[Dict[str, Any]] = None,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
extra_info: str = "",
|
extra_info: str = "",
|
||||||
reply_reason: str = "",
|
reply_reason: str = "",
|
||||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||||
choosen_actions: Optional[List[Dict[str, Any]]] = None,
|
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
|
||||||
enable_tool: bool = False,
|
enable_tool: bool = False,
|
||||||
enable_splitter: bool = True,
|
enable_splitter: bool = True,
|
||||||
enable_chinese_typo: bool = True,
|
enable_chinese_typo: bool = True,
|
||||||
@@ -85,7 +89,7 @@ async def generate_reply(
|
|||||||
request_type: str = "generator_api",
|
request_type: str = "generator_api",
|
||||||
from_plugin: bool = True,
|
from_plugin: bool = True,
|
||||||
return_expressions: bool = False,
|
return_expressions: bool = False,
|
||||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]:
|
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
|
||||||
"""生成回复
|
"""生成回复
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -96,7 +100,7 @@ async def generate_reply(
|
|||||||
extra_info: 额外信息,用于补充上下文
|
extra_info: 额外信息,用于补充上下文
|
||||||
reply_reason: 回复原因
|
reply_reason: 回复原因
|
||||||
available_actions: 可用动作
|
available_actions: 可用动作
|
||||||
choosen_actions: 已选动作
|
chosen_actions: 已选动作
|
||||||
enable_tool: 是否启用工具调用
|
enable_tool: 是否启用工具调用
|
||||||
enable_splitter: 是否启用消息分割器
|
enable_splitter: 是否启用消息分割器
|
||||||
enable_chinese_typo: 是否启用错字生成器
|
enable_chinese_typo: 是否启用错字生成器
|
||||||
@@ -110,16 +114,14 @@ async def generate_reply(
|
|||||||
try:
|
try:
|
||||||
# 获取回复器
|
# 获取回复器
|
||||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||||
replyer = get_replyer(
|
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
|
||||||
chat_stream, chat_id, request_type=request_type
|
|
||||||
)
|
|
||||||
if not replyer:
|
if not replyer:
|
||||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||||
return False, [], None
|
return False, [], None, None
|
||||||
|
|
||||||
if not extra_info and action_data:
|
if not extra_info and action_data:
|
||||||
extra_info = action_data.get("extra_info", "")
|
extra_info = action_data.get("extra_info", "")
|
||||||
|
|
||||||
if not reply_reason and action_data:
|
if not reply_reason and action_data:
|
||||||
reply_reason = action_data.get("reason", "")
|
reply_reason = action_data.get("reason", "")
|
||||||
|
|
||||||
@@ -127,7 +129,7 @@ async def generate_reply(
|
|||||||
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
|
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
available_actions=available_actions,
|
available_actions=available_actions,
|
||||||
chosen_actions=choosen_actions,
|
chosen_actions=chosen_actions,
|
||||||
enable_tool=enable_tool,
|
enable_tool=enable_tool,
|
||||||
reply_message=reply_message,
|
reply_message=reply_message,
|
||||||
reply_reason=reply_reason,
|
reply_reason=reply_reason,
|
||||||
@@ -136,7 +138,7 @@ async def generate_reply(
|
|||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||||
return False, [], None
|
return False, [], None, None
|
||||||
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
|
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
|
||||||
if content := llm_response_dict.get("content", ""):
|
if content := llm_response_dict.get("content", ""):
|
||||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||||
@@ -144,28 +146,34 @@ async def generate_reply(
|
|||||||
reply_set = []
|
reply_set = []
|
||||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||||
|
|
||||||
if return_prompt:
|
# if return_prompt:
|
||||||
if return_expressions:
|
# if return_expressions:
|
||||||
return success, reply_set, (prompt, selected_expressions)
|
# return success, reply_set, prompt, selected_expressions
|
||||||
else:
|
# else:
|
||||||
return success, reply_set, prompt
|
# return success, reply_set, prompt, None
|
||||||
else:
|
# else:
|
||||||
if return_expressions:
|
# if return_expressions:
|
||||||
return success, reply_set, (None, selected_expressions)
|
# return success, reply_set, (None, selected_expressions)
|
||||||
else:
|
# else:
|
||||||
return success, reply_set, None
|
# return success, reply_set, None
|
||||||
|
return (
|
||||||
|
success,
|
||||||
|
reply_set,
|
||||||
|
prompt if return_prompt else None,
|
||||||
|
selected_expressions if return_expressions else None,
|
||||||
|
)
|
||||||
|
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
raise ve
|
raise ve
|
||||||
|
|
||||||
except UserWarning as uw:
|
except UserWarning as uw:
|
||||||
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
|
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
|
||||||
return False, [], None
|
return False, [], None, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False, [], None
|
return False, [], None, None
|
||||||
|
|
||||||
|
|
||||||
async def rewrite_reply(
|
async def rewrite_reply(
|
||||||
|
|||||||
@@ -21,15 +21,17 @@
|
|||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Union, Dict, Any, List
|
from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
# 导入依赖
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||||
from maim_message import Seg, UserInfo
|
from maim_message import Seg, UserInfo
|
||||||
from src.config.config import global_config
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("send_api")
|
logger = get_logger("send_api")
|
||||||
|
|
||||||
@@ -46,10 +48,10 @@ async def _send_to_target(
|
|||||||
display_message: str = "",
|
display_message: str = "",
|
||||||
typing: bool = False,
|
typing: bool = False,
|
||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
selected_expressions:List[int] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定目标发送消息的内部实现
|
"""向指定目标发送消息的内部实现
|
||||||
|
|
||||||
@@ -70,7 +72,7 @@ async def _send_to_target(
|
|||||||
if set_reply and not reply_message:
|
if set_reply and not reply_message:
|
||||||
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息")
|
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if show_log:
|
if show_log:
|
||||||
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
|
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
|
||||||
|
|
||||||
@@ -98,13 +100,13 @@ async def _send_to_target(
|
|||||||
message_segment = Seg(type=message_type, data=content) # type: ignore
|
message_segment = Seg(type=message_type, data=content) # type: ignore
|
||||||
|
|
||||||
if reply_message:
|
if reply_message:
|
||||||
anchor_message = message_dict_to_message_recv(reply_message)
|
anchor_message = message_dict_to_message_recv(reply_message.flatten())
|
||||||
if anchor_message:
|
if anchor_message:
|
||||||
anchor_message.update_chat_stream(target_stream)
|
anchor_message.update_chat_stream(target_stream)
|
||||||
assert anchor_message.message_info.user_info, "用户信息缺失"
|
assert anchor_message.message_info.user_info, "用户信息缺失"
|
||||||
reply_to_platform_id = (
|
reply_to_platform_id = (
|
||||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
reply_to_platform_id = ""
|
reply_to_platform_id = ""
|
||||||
anchor_message = None
|
anchor_message = None
|
||||||
@@ -192,12 +194,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
|
|||||||
}
|
}
|
||||||
|
|
||||||
message_recv = MessageRecv(message_dict_recv)
|
message_recv = MessageRecv(message_dict_recv)
|
||||||
|
|
||||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
|
||||||
return message_recv
|
return message_recv
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 公共API函数 - 预定义类型的发送函数
|
# 公共API函数 - 预定义类型的发送函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -208,9 +209,9 @@ async def text_to_stream(
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
typing: bool = False,
|
typing: bool = False,
|
||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
selected_expressions:List[int] = None,
|
selected_expressions: Optional[List[int]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送文本消息
|
"""向指定流发送文本消息
|
||||||
|
|
||||||
@@ -237,7 +238,13 @@ async def text_to_stream(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
|
async def emoji_to_stream(
|
||||||
|
emoji_base64: str,
|
||||||
|
stream_id: str,
|
||||||
|
storage_message: bool = True,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
|
) -> bool:
|
||||||
"""向指定流发送表情包
|
"""向指定流发送表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -248,10 +255,25 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message)
|
return await _send_to_target(
|
||||||
|
"emoji",
|
||||||
|
emoji_base64,
|
||||||
|
stream_id,
|
||||||
|
"",
|
||||||
|
typing=False,
|
||||||
|
storage_message=storage_message,
|
||||||
|
set_reply=set_reply,
|
||||||
|
reply_message=reply_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
|
async def image_to_stream(
|
||||||
|
image_base64: str,
|
||||||
|
stream_id: str,
|
||||||
|
storage_message: bool = True,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
|
) -> bool:
|
||||||
"""向指定流发送图片
|
"""向指定流发送图片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -262,11 +284,25 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message)
|
return await _send_to_target(
|
||||||
|
"image",
|
||||||
|
image_base64,
|
||||||
|
stream_id,
|
||||||
|
"",
|
||||||
|
typing=False,
|
||||||
|
storage_message=storage_message,
|
||||||
|
set_reply=set_reply,
|
||||||
|
reply_message=reply_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def command_to_stream(
|
async def command_to_stream(
|
||||||
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
|
command: Union[str, dict],
|
||||||
|
stream_id: str,
|
||||||
|
storage_message: bool = True,
|
||||||
|
display_message: str = "",
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""向指定流发送命令
|
"""向指定流发送命令
|
||||||
|
|
||||||
@@ -279,7 +315,14 @@ async def command_to_stream(
|
|||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
return await _send_to_target(
|
return await _send_to_target(
|
||||||
"command", command, stream_id, display_message, typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message
|
"command",
|
||||||
|
command,
|
||||||
|
stream_id,
|
||||||
|
display_message,
|
||||||
|
typing=False,
|
||||||
|
storage_message=storage_message,
|
||||||
|
set_reply=set_reply,
|
||||||
|
reply_message=reply_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -289,7 +332,7 @@ async def custom_to_stream(
|
|||||||
stream_id: str,
|
stream_id: str,
|
||||||
display_message: str = "",
|
display_message: str = "",
|
||||||
typing: bool = False,
|
typing: bool = False,
|
||||||
reply_message: Optional[Dict[str, Any]] = None,
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
set_reply: bool = False,
|
set_reply: bool = False,
|
||||||
storage_message: bool = True,
|
storage_message: bool = True,
|
||||||
show_log: bool = True,
|
show_log: bool = True,
|
||||||
|
|||||||
@@ -2,13 +2,15 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Tuple, Optional, Dict, Any
|
from typing import Tuple, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
|
||||||
from src.plugin_system.apis import send_api, database_api, message_api
|
from src.plugin_system.apis import send_api, database_api, message_api
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("base_action")
|
logger = get_logger("base_action")
|
||||||
|
|
||||||
@@ -206,7 +208,11 @@ class BaseAction(ABC):
|
|||||||
return False, f"等待新消息失败: {str(e)}"
|
return False, f"等待新消息失败: {str(e)}"
|
||||||
|
|
||||||
async def send_text(
|
async def send_text(
|
||||||
self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False
|
self,
|
||||||
|
content: str,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
|
typing: bool = False,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送文本消息
|
"""发送文本消息
|
||||||
|
|
||||||
@@ -229,7 +235,9 @@ class BaseAction(ABC):
|
|||||||
typing=typing,
|
typing=typing,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
|
async def send_emoji(
|
||||||
|
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
|
||||||
|
) -> bool:
|
||||||
"""发送表情包
|
"""发送表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -242,9 +250,13 @@ class BaseAction(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.emoji_to_stream(emoji_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message)
|
return await send_api.emoji_to_stream(
|
||||||
|
emoji_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
|
||||||
|
)
|
||||||
|
|
||||||
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
|
async def send_image(
|
||||||
|
self, image_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
|
||||||
|
) -> bool:
|
||||||
"""发送图片
|
"""发送图片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -257,9 +269,18 @@ class BaseAction(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.image_to_stream(image_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message)
|
return await send_api.image_to_stream(
|
||||||
|
image_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
|
||||||
|
)
|
||||||
|
|
||||||
async def send_custom(self, message_type: str, content: str, typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
|
async def send_custom(
|
||||||
|
self,
|
||||||
|
message_type: str,
|
||||||
|
content: str,
|
||||||
|
typing: bool = False,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
|
) -> bool:
|
||||||
"""发送自定义类型消息
|
"""发送自定义类型消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -308,7 +329,13 @@ class BaseAction(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_command(
|
async def send_command(
|
||||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
|
self,
|
||||||
|
command_name: str,
|
||||||
|
args: Optional[dict] = None,
|
||||||
|
display_message: str = "",
|
||||||
|
storage_message: bool = True,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送命令消息
|
"""发送命令消息
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, Tuple, Optional, Any
|
from typing import Dict, Tuple, Optional, TYPE_CHECKING
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("base_command")
|
logger = get_logger("base_command")
|
||||||
|
|
||||||
|
|
||||||
@@ -84,7 +87,13 @@ class BaseCommand(ABC):
|
|||||||
|
|
||||||
return current
|
return current
|
||||||
|
|
||||||
async def send_text(self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool:
|
async def send_text(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
|
storage_message: bool = True,
|
||||||
|
) -> bool:
|
||||||
"""发送回复消息
|
"""发送回复消息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -100,10 +109,22 @@ class BaseCommand(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, set_reply=set_reply,reply_message=reply_message,storage_message=storage_message)
|
return await send_api.text_to_stream(
|
||||||
|
text=content,
|
||||||
|
stream_id=chat_stream.stream_id,
|
||||||
|
set_reply=set_reply,
|
||||||
|
reply_message=reply_message,
|
||||||
|
storage_message=storage_message,
|
||||||
|
)
|
||||||
|
|
||||||
async def send_type(
|
async def send_type(
|
||||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
|
self,
|
||||||
|
message_type: str,
|
||||||
|
content: str,
|
||||||
|
display_message: str = "",
|
||||||
|
typing: bool = False,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送指定类型的回复消息到当前聊天环境
|
"""发送指定类型的回复消息到当前聊天环境
|
||||||
|
|
||||||
@@ -134,7 +155,13 @@ class BaseCommand(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def send_command(
|
async def send_command(
|
||||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
|
self,
|
||||||
|
command_name: str,
|
||||||
|
args: Optional[dict] = None,
|
||||||
|
display_message: str = "",
|
||||||
|
storage_message: bool = True,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""发送命令消息
|
"""发送命令消息
|
||||||
|
|
||||||
@@ -177,7 +204,9 @@ class BaseCommand(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
|
async def send_emoji(
|
||||||
|
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
|
||||||
|
) -> bool:
|
||||||
"""发送表情包
|
"""发送表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -191,9 +220,17 @@ class BaseCommand(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message)
|
return await send_api.emoji_to_stream(
|
||||||
|
emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message
|
||||||
|
)
|
||||||
|
|
||||||
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool:
|
async def send_image(
|
||||||
|
self,
|
||||||
|
image_base64: str,
|
||||||
|
set_reply: bool = False,
|
||||||
|
reply_message: Optional["DatabaseMessages"] = None,
|
||||||
|
storage_message: bool = True,
|
||||||
|
) -> bool:
|
||||||
"""发送图片
|
"""发送图片
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -207,7 +244,13 @@ class BaseCommand(ABC):
|
|||||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message,storage_message=storage_message)
|
return await send_api.image_to_stream(
|
||||||
|
image_base64,
|
||||||
|
chat_stream.stream_id,
|
||||||
|
set_reply=set_reply,
|
||||||
|
reply_message=reply_message,
|
||||||
|
storage_message=storage_message,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_command_info(cls) -> "CommandInfo":
|
def get_command_info(cls) -> "CommandInfo":
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Dict, Any
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
from src.chat.knowledge import qa_manager
|
||||||
from src.plugin_system import BaseTool, ToolParamType
|
from src.plugin_system import BaseTool, ToolParamType
|
||||||
|
|
||||||
logger = get_logger("lpmm_get_knowledge_tool")
|
logger = get_logger("lpmm_get_knowledge_tool")
|
||||||
|
|||||||
@@ -1,20 +1,13 @@
|
|||||||
import random
|
import json
|
||||||
|
from json_repair import repair_json
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
# 导入新插件系统
|
|
||||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
|
||||||
|
|
||||||
# 导入依赖的系统组件
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
# 导入API模块 - 标准Python包方式
|
|
||||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
|
||||||
# NoReplyAction已集成到heartFC_chat.py中,不再需要导入
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
|
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
import json
|
from src.plugin_system import BaseAction, ActionActivationType
|
||||||
from json_repair import repair_json
|
from src.plugin_system.apis import llm_api
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("relation")
|
logger = get_logger("relation")
|
||||||
@@ -39,10 +32,9 @@ def init_prompt():
|
|||||||
{{
|
{{
|
||||||
"category": "分类名称"
|
"category": "分类名称"
|
||||||
}} """,
|
}} """,
|
||||||
"relation_category"
|
"relation_category",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
以下是有关{category}的现有记忆:
|
以下是有关{category}的现有记忆:
|
||||||
@@ -73,7 +65,7 @@ def init_prompt():
|
|||||||
|
|
||||||
现在,请你根据情况选出合适的修改方式,并输出json,不要输出其他内容:
|
现在,请你根据情况选出合适的修改方式,并输出json,不要输出其他内容:
|
||||||
""",
|
""",
|
||||||
"relation_category_update"
|
"relation_category_update",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -98,17 +90,14 @@ class BuildRelationAction(BaseAction):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# 动作参数定义
|
# 动作参数定义
|
||||||
action_parameters = {
|
action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"}
|
||||||
"person_name":"需要了解或记忆的人的名称",
|
|
||||||
"impression":"需要了解的对某人的记忆或印象"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 动作使用场景
|
# 动作使用场景
|
||||||
action_require = [
|
action_require = [
|
||||||
"了解对于某人的记忆,并添加到你对对方的印象中",
|
"了解对于某人的记忆,并添加到你对对方的印象中",
|
||||||
"对方与有明确提到有关其自身的事件",
|
"对方与有明确提到有关其自身的事件",
|
||||||
"对方有提到其个人信息,包括喜好,身份,等等",
|
"对方有提到其个人信息,包括喜好,身份,等等",
|
||||||
"对方希望你记住对方的信息"
|
"对方希望你记住对方的信息",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 关联类型
|
# 关联类型
|
||||||
@@ -129,9 +118,7 @@ class BuildRelationAction(BaseAction):
|
|||||||
if not person.is_known:
|
if not person.is_known:
|
||||||
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
|
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
|
||||||
return False, f"用户 {person_name} 不存在,跳过添加记忆"
|
return False, f"用户 {person_name} 不存在,跳过添加记忆"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
category_list = person.get_all_category()
|
category_list = person.get_all_category()
|
||||||
if not category_list:
|
if not category_list:
|
||||||
category_list_str = "无分类"
|
category_list_str = "无分类"
|
||||||
@@ -142,9 +129,8 @@ class BuildRelationAction(BaseAction):
|
|||||||
"relation_category",
|
"relation_category",
|
||||||
category_list=category_list_str,
|
category_list=category_list_str,
|
||||||
memory_point=impression,
|
memory_point=impression,
|
||||||
person_name=person.person_name
|
person_name=person.person_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||||
@@ -161,84 +147,76 @@ class BuildRelationAction(BaseAction):
|
|||||||
success, category, _, _ = await llm_api.generate_with_model(
|
success, category, _, _ = await llm_api.generate_with_model(
|
||||||
prompt, model_config=chat_model_config, request_type="relation.category"
|
prompt, model_config=chat_model_config, request_type="relation.category"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
category_data = json.loads(repair_json(category))
|
category_data = json.loads(repair_json(category))
|
||||||
category = category_data.get("category", "")
|
category = category_data.get("category", "")
|
||||||
if not category:
|
if not category:
|
||||||
logger.warning(f"{self.log_prefix} LLM未给出分类,跳过添加记忆")
|
logger.warning(f"{self.log_prefix} LLM未给出分类,跳过添加记忆")
|
||||||
return False, "LLM未给出分类,跳过添加记忆"
|
return False, "LLM未给出分类,跳过添加记忆"
|
||||||
|
|
||||||
|
|
||||||
# 第二部分:更新记忆
|
# 第二部分:更新记忆
|
||||||
|
|
||||||
memory_list = person.get_memory_list_by_category(category)
|
memory_list = person.get_memory_list_by_category(category)
|
||||||
if not memory_list:
|
if not memory_list:
|
||||||
logger.info(f"{self.log_prefix} {person.person_name} 的 {category} 的记忆为空,进行创建")
|
logger.info(f"{self.log_prefix} {person.person_name} 的 {category} 的记忆为空,进行创建")
|
||||||
person.memory_points.append(f"{category}:{impression}:1.0")
|
person.memory_points.append(f"{category}:{impression}:1.0")
|
||||||
person.sync_to_database()
|
person.sync_to_database()
|
||||||
|
|
||||||
return True, f"未找到分类为{category}的记忆点,进行添加"
|
return True, f"未找到分类为{category}的记忆点,进行添加"
|
||||||
|
|
||||||
memory_list_str = ""
|
memory_list_str = ""
|
||||||
memory_list_id = {}
|
memory_list_id = {}
|
||||||
id = 1
|
for id, memory in enumerate(memory_list, start=1):
|
||||||
for memory in memory_list:
|
|
||||||
memory_content = get_memory_content_from_memory(memory)
|
memory_content = get_memory_content_from_memory(memory)
|
||||||
memory_list_str += f"{id}. {memory_content}\n"
|
memory_list_str += f"{id}. {memory_content}\n"
|
||||||
memory_list_id[id] = memory
|
memory_list_id[id] = memory
|
||||||
id += 1
|
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
"relation_category_update",
|
"relation_category_update",
|
||||||
category=category,
|
category=category,
|
||||||
memory_list=memory_list_str,
|
memory_list=memory_list_str,
|
||||||
memory_point=impression,
|
memory_point=impression,
|
||||||
person_name=person.person_name
|
person_name=person.person_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||||
|
|
||||||
chat_model_config = models.get("utils")
|
chat_model_config = models.get("utils")
|
||||||
success, update_memory, _, _ = await llm_api.generate_with_model(
|
success, update_memory, _, _ = await llm_api.generate_with_model(
|
||||||
prompt, model_config=chat_model_config, request_type="relation.category.update"
|
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
|
||||||
)
|
)
|
||||||
|
|
||||||
update_memory_data = json.loads(repair_json(update_memory))
|
update_memory_data = json.loads(repair_json(update_memory))
|
||||||
new_memory = update_memory_data.get("new_memory", "")
|
new_memory = update_memory_data.get("new_memory", "")
|
||||||
memory_id = update_memory_data.get("memory_id", "")
|
memory_id = update_memory_data.get("memory_id", "")
|
||||||
integrate_memory = update_memory_data.get("integrate_memory", "")
|
integrate_memory = update_memory_data.get("integrate_memory", "")
|
||||||
|
|
||||||
if new_memory:
|
if new_memory:
|
||||||
# 新记忆
|
# 新记忆
|
||||||
person.memory_points.append(f"{category}:{new_memory}:1.0")
|
person.memory_points.append(f"{category}:{new_memory}:1.0")
|
||||||
person.sync_to_database()
|
person.sync_to_database()
|
||||||
|
|
||||||
return True, f"为{person.person_name}新增记忆点: {new_memory}"
|
return True, f"为{person.person_name}新增记忆点: {new_memory}"
|
||||||
elif memory_id and integrate_memory:
|
elif memory_id and integrate_memory:
|
||||||
# 现存或冲突记忆
|
# 现存或冲突记忆
|
||||||
memory = memory_list_id[memory_id]
|
memory = memory_list_id[memory_id]
|
||||||
memory_content = get_memory_content_from_memory(memory)
|
memory_content = get_memory_content_from_memory(memory)
|
||||||
del_count = person.del_memory(category,memory_content)
|
del_count = person.del_memory(category, memory_content)
|
||||||
|
|
||||||
if del_count > 0:
|
if del_count > 0:
|
||||||
logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}")
|
logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}")
|
||||||
|
|
||||||
memory_weight = get_weight_from_memory(memory)
|
memory_weight = get_weight_from_memory(memory)
|
||||||
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
|
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
|
||||||
person.sync_to_database()
|
person.sync_to_database()
|
||||||
|
|
||||||
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
|
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
|
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
|
||||||
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
|
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return True, "关系动作执行成功"
|
return True, "关系动作执行成功"
|
||||||
|
|
||||||
@@ -248,4 +226,4 @@ class BuildRelationAction(BaseAction):
|
|||||||
|
|
||||||
|
|
||||||
# 还缺一个关系的太多遗忘和对应的提取
|
# 还缺一个关系的太多遗忘和对应的提取
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from src.plugin_system.apis.plugin_register_api import register_plugin
|
|||||||
from src.plugin_system.base.base_plugin import BasePlugin
|
from src.plugin_system.base.base_plugin import BasePlugin
|
||||||
from src.plugin_system.base.component_types import ComponentInfo
|
from src.plugin_system.base.component_types import ComponentInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
from src.plugin_system.base.base_action import BaseAction, ActionActivationType
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
from typing import Tuple, List, Type
|
from typing import Tuple, List, Type
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "1.3.0"
|
version = "1.3.1"
|
||||||
|
|
||||||
# 配置文件版本号迭代规则同bot_config.toml
|
# 配置文件版本号迭代规则同bot_config.toml
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user