This commit is contained in:
SengokuCola
2025-08-21 12:40:00 +08:00
33 changed files with 878 additions and 906 deletions

View File

@@ -1,6 +1,7 @@
import asyncio
import time
import traceback
import math
import random
from typing import List, Optional, Dict, Any, Tuple
from rich.traceback import install
@@ -8,6 +9,7 @@ from collections import deque
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
@@ -15,21 +17,19 @@ from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.chat_loop.hfc_utils import CycleDetail
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
from src.chat.frequency_control.focus_value_control import focus_value_control
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager
import math
from src.mais4u.s4u_config import s4u_config
# no_action逻辑已集成到heartFC_chat.py中不再需要导入
from src.chat.chat_loop.hfc_utils import send_typing, stop_typing
# 导入记忆系统
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.frequency_control.talk_frequency_control import talk_frequency_control
from src.chat.frequency_control.focus_value_control import focus_value_control
ERROR_LOOP_INFO = {
"loop_plan_info": {
@@ -62,10 +62,7 @@ class HeartFChatting:
其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。
"""
def __init__(
self,
chat_id: str,
):
def __init__(self, chat_id: str):
"""
HeartFChatting 初始化函数
@@ -83,7 +80,7 @@ class HeartFChatting:
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
self.talk_frequency_control = talk_frequency_control.get_talk_frequency_control(self.stream_id)
self.focus_value_control = focus_value_control.get_focus_value_control(self.stream_id)
@@ -104,7 +101,7 @@ class HeartFChatting:
self.plan_timeout_count = 0
self.last_read_time = time.time() - 1
self.focus_energy = 1
self.no_action_consecutive = 0
# 最近三次no_action的新消息兴趣度记录
@@ -166,27 +163,26 @@ class HeartFChatting:
# 获取动作类型,兼容新旧格式
action_type = "未知动作"
if hasattr(self, '_current_cycle_detail') and self._current_cycle_detail:
if hasattr(self, "_current_cycle_detail") and self._current_cycle_detail:
loop_plan_info = self._current_cycle_detail.loop_plan_info
if isinstance(loop_plan_info, dict):
action_result = loop_plan_info.get('action_result', {})
action_result = loop_plan_info.get("action_result", {})
if isinstance(action_result, dict):
# 旧格式action_result是字典
action_type = action_result.get('action_type', '未知动作')
action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表
action_type = action_result[0].get('action_type', '未知动作')
action_type = action_result[0].get("action_type", "未知动作")
elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况
action_type = loop_plan_info[0].get('action_type', '未知动作')
action_type = loop_plan_info[0].get("action_type", "未知动作")
logger.info(
f"{self.log_prefix}{self._current_cycle_detail.cycle_id}次思考,"
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒, " # type: ignore
f"选择动作: {action_type}"
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
f"选择动作: {action_type}" + (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
def _determine_form_type(self) -> None:
"""判断使用哪种形式的no_action"""
# 如果连续no_action次数少于3次使用waiting形式
@@ -195,42 +191,44 @@ class HeartFChatting:
else:
# 计算最近三次记录的兴趣度总和
total_recent_interest = sum(self.recent_interest_records)
# 计算调整后的阈值
adjusted_threshold = 1 / self.talk_frequency_control.get_current_talk_frequency()
logger.info(f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}")
logger.info(
f"{self.log_prefix} 最近三次兴趣度总和: {total_recent_interest:.2f}, 调整后阈值: {adjusted_threshold:.2f}"
)
# 如果兴趣度总和小于阈值进入breaking形式
if total_recent_interest < adjusted_threshold:
logger.info(f"{self.log_prefix} 兴趣度不足,进入休息")
self.focus_energy = random.randint(3, 6)
else:
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
self.focus_energy = 1
async def _should_process_messages(self, new_message: List[Dict[str, Any]]) -> tuple[bool,float]:
self.focus_energy = 1
async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]:
"""
判断是否应该处理消息
Args:
new_message: 新消息列表
mode: 当前聊天模式
Returns:
bool: 是否应该处理消息
"""
new_message_count = len(new_message)
talk_frequency = self.talk_frequency_control.get_current_talk_frequency()
modified_exit_count_threshold = self.focus_energy * 0.5 / talk_frequency
modified_exit_interest_threshold = 1.5 / talk_frequency
total_interest = 0.0
for msg_dict in new_message:
interest_value = msg_dict.get("interest_value")
if interest_value is not None and msg_dict.get("processed_plain_text", ""):
for msg in new_message:
interest_value = msg.interest_value
if interest_value is not None and msg.processed_plain_text:
total_interest += float(interest_value)
if new_message_count >= modified_exit_count_threshold:
self.recent_interest_records.append(total_interest)
logger.info(
@@ -244,9 +242,11 @@ class HeartFChatting:
if new_message_count > 0:
# 只在兴趣值变化时输出log
if not hasattr(self, "_last_accumulated_interest") or total_interest != self._last_accumulated_interest:
logger.info(f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}")
logger.info(
f"{self.log_prefix} 休息中,新消息:{new_message_count}条,累计兴趣值: {total_interest:.2f}, 活跃度: {talk_frequency:.1f}"
)
self._last_accumulated_interest = total_interest
if total_interest >= modified_exit_interest_threshold:
# 记录兴趣度到列表
self.recent_interest_records.append(total_interest)
@@ -261,29 +261,25 @@ class HeartFChatting:
f"{self.log_prefix} 已等待{time.time() - self.last_read_time:.0f}秒,累计{new_message_count}条消息,累计兴趣{total_interest:.1f},继续等待..."
)
await asyncio.sleep(0.5)
return False,0.0
return False, 0.0
async def _loopbody(self):
recent_messages_dict = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
end_time=time.time(),
limit = 10,
limit=10,
limit_mode="latest",
filter_mai=True,
filter_command=True,
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_recent_messages_dict = [temporarily_transform_class_to_dict(msg) for msg in recent_messages_dict]
)
# 统一的消息处理逻辑
should_process,interest_value = await self._should_process_messages(temp_recent_messages_dict)
should_process, interest_value = await self._should_process_messages(recent_messages_dict)
if should_process:
self.last_read_time = time.time()
await self._observe(interest_value = interest_value)
await self._observe(interest_value=interest_value)
else:
# Normal模式消息数量不足等待
@@ -298,22 +294,21 @@ class HeartFChatting:
cycle_timers: Dict[str, float],
thinking_id,
actions,
selected_expressions:List[int] = None,
selected_expressions: List[int] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.get("chat_info_platform")
if platform is None:
platform = getattr(self.chat_stream, "platform", "unknown")
person = Person(platform = platform ,user_id = action_message.get("user_id", ""))
person = Person(platform=platform, user_id=action_message.get("user_id", ""))
person_name = person.person_name
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@@ -342,12 +337,10 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _observe(self,interest_value:float = 0.0) -> bool:
async def _observe(self, interest_value: float = 0.0) -> bool:
action_type = "no_action"
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率
# 当interest_value为0时概率接近0使用Focus模式
# 当interest_value很高时概率接近1使用Normal模式
@@ -366,7 +359,9 @@ class HeartFChatting:
# 根据概率决定使用哪种模式
if random.random() < normal_mode_probability:
mode = ChatMode.NORMAL
logger.info(f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability*100:.0f}%概率下选择回复")
logger.info(
f"{self.log_prefix} 有兴趣({interest_value:.2f}),在{normal_mode_probability * 100:.0f}%概率下选择回复"
)
else:
mode = ChatMode.FOCUS
@@ -387,10 +382,9 @@ class HeartFChatting:
await hippocampus_manager.build_memory_for_chat(self.stream_id)
except Exception as e:
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
#如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考
# 如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考
actions = [
{
"action_type": "no_action",
@@ -420,23 +414,21 @@ class HeartFChatting:
):
return False
with Timer("规划器", cycle_timers):
actions, _= await self.action_planner.plan(
actions, _ = await self.action_planner.plan(
mode=mode,
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
# 3. 并行执行所有动作
async def execute_action(action_info,actions):
async def execute_action(action_info, actions):
"""执行单个动作的通用函数"""
try:
if action_info["action_type"] == "no_action":
# 直接处理no_action逻辑不再通过动作系统
reason = action_info.get("reasoning", "选择不回复")
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
@@ -447,13 +439,8 @@ class HeartFChatting:
action_data={"reason": reason},
action_name="no_action",
)
return {
"action_type": "no_action",
"success": True,
"reply_text": "",
"command": ""
}
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_info["action_type"] != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
@@ -463,20 +450,19 @@ class HeartFChatting:
action_info["action_data"],
cycle_timers,
thinking_id,
action_info["action_message"]
action_info["action_message"],
)
return {
"action_type": action_info["action_type"],
"success": success,
"reply_text": reply_text,
"command": command
"command": command,
}
else:
try:
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message = action_info["action_message"],
reply_message=action_info["action_message"],
available_actions=available_actions,
choosen_actions=actions,
reply_reason=action_info.get("reasoning", ""),
@@ -485,29 +471,21 @@ class HeartFChatting:
from_plugin=False,
return_expressions=True,
)
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
_,selected_expressions = prompt_selected_expressions
_, selected_expressions = prompt_selected_expressions
else:
selected_expressions = []
if not success or not response_set:
logger.info(f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败")
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None
}
logger.info(
f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
)
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None
}
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
response_set=response_set,
@@ -521,7 +499,7 @@ class HeartFChatting:
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
@@ -531,26 +509,26 @@ class HeartFChatting:
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e)
"error": str(e),
}
action_tasks = [asyncio.create_task(execute_action(action,actions)) for action in actions]
action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
action_success = False
action_reply_text = ""
action_command = ""
for i, result in enumerate(results):
if isinstance(result, BaseException):
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
_cur_action = actions[i]
if result["action_type"] != "reply":
action_success = result["success"]
@@ -590,7 +568,6 @@ class HeartFChatting:
},
}
reply_text = action_reply_text
if s4u_config.enable_s4u:
await stop_typing()
@@ -602,7 +579,7 @@ class HeartFChatting:
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
action_type = actions[0]["action_type"] if actions else "no_action"
# 管理no_action计数器当执行了非no_action动作时重置计数器
if action_type != "no_action":
# no_action逻辑已集成到heartFC_chat.py中直接重置计数器
@@ -610,7 +587,7 @@ class HeartFChatting:
self.no_action_consecutive = 0
logger.debug(f"{self.log_prefix} 执行了{action_type}动作重置no_action计数器")
return True
if action_type == "no_action":
self.no_action_consecutive += 1
self._determine_form_type()
@@ -692,11 +669,12 @@ class HeartFChatting:
traceback.print_exc()
return False, "", ""
async def _send_response(self,
reply_set,
message_data,
selected_expressions:List[int] = None,
) -> str:
async def _send_response(
self,
reply_set,
message_data,
selected_expressions: List[int] = None,
) -> str:
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
@@ -714,7 +692,7 @@ class HeartFChatting:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message = message_data,
reply_message=message_data,
set_reply=need_reply,
typing=False,
selected_expressions=selected_expressions,
@@ -724,7 +702,7 @@ class HeartFChatting:
await send_api.text_to_stream(
text=data,
stream_id=self.chat_stream.stream_id,
reply_message = message_data,
reply_message=message_data,
set_reply=False,
typing=True,
selected_expressions=selected_expressions,

View File

@@ -709,36 +709,36 @@ class EmojiManager:
return emoji
return None # 如果循环结束还没找到,则返回 None
async def get_emoji_tag_by_hash(self, emoji_hash: str) -> Optional[List[str]]:
    """Return the emotion tags of a registered emoji, looked up by hash.

    The in-memory emoji objects are consulted first; when nothing usable is
    found there, the ``Emoji`` database table is queried instead.

    Args:
        emoji_hash: Hash identifying the emoji.

    Returns:
        Optional[List[str]]: The emotion tags, or None when the emoji is
        unknown, has no tags, or the lookup failed.
    """
    try:
        # Look in memory first.
        emoji = await self.get_emoji_from_manager(emoji_hash)
        if emoji and emoji.emotion:
            logger.info(f"[缓存命中] 从内存获取表情包情感标签: {emoji.emotion}...")
            # Return a copy so callers cannot mutate the manager's own list.
            return list(emoji.emotion)

        # Not in memory: fall back to the database.
        self._ensure_db()
        try:
            emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
            if emoji_record and emoji_record.emotion:
                logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
                # The record holds the tags as one comma-separated string.
                # Trim each piece and drop empty ones so that stray spaces or
                # a trailing comma do not surface as bogus tags.
                tags = [tag.strip() for tag in emoji_record.emotion.split(",") if tag.strip()]
                if tags:
                    return tags
        except Exception as e:
            # Best-effort lookup: a database failure is reported as "not found".
            logger.error(f"从数据库查询表情包情感标签时出错: {e}")

        return None

    except Exception as e:
        logger.error(f"获取表情包情感标签失败 (Hash: {emoji_hash}): {str(e)}")
        return None
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:

View File

@@ -8,6 +8,7 @@ from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
@@ -346,21 +347,17 @@ class ExpressionLearner:
current_time = time.time()
# 获取上次学习时间
random_msg_temp = get_raw_msg_by_timestamp_with_chat_inclusive(
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
limit=num,
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
random_msg: Optional[List[Dict[str, Any]]] = [temporarily_transform_class_to_dict(msg) for msg in random_msg_temp] if random_msg_temp else None
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
chat_id: str = random_msg[0]["chat_id"]
chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
random_msg_str: str = await build_anonymous_messages(random_msg)
# print(f"random_msg_str:{random_msg_str}")

View File

@@ -117,30 +117,36 @@ class EmbeddingStore:
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
    """Return the embedding for ``s`` via a blocking call.

    A private event loop is created for the single request, used to drive the
    asynchronous ``LLMRequest.get_embedding`` coroutine, and torn down again
    before returning.

    Args:
        s: The string to embed.

    Returns:
        The embedding reported by the model, or an empty list when the model
        returned nothing or any exception occurred (the failure is logged).
    """
    # Dedicated loop for this call; it is closed in the ``finally`` below.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Imported lazily and inside the ``try`` so an import failure is
        # reported like any other embedding failure.
        from src.llm_models.utils_model import LLMRequest
        from src.config.config import model_config

        llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")

        # Drive the coroutine to completion on the private loop.
        embedding, _ = loop.run_until_complete(llm.get_embedding(s))

        if embedding:
            return embedding
        logger.error(f"获取嵌入失败: {s}")
        return []
    except Exception as e:
        logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
        return []
    finally:
        try:
            loop.close()
        except Exception:
            # Best-effort cleanup: a failure to close must not mask the result.
            pass
        # Uninstall the loop as well. Otherwise this thread keeps a *closed*
        # loop as its current event loop, and later asyncio calls on the same
        # thread fail with "Event loop is closed".
        asyncio.set_event_loop(None)
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
@@ -181,8 +187,14 @@ class EmbeddingStore:
for i, s in enumerate(chunk_strs):
try:
# 直接使用异步函数
embedding = asyncio.run(llm.get_embedding(s))
# 在线程中创建独立的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embedding = loop.run_until_complete(llm.get_embedding(s))
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:

View File

@@ -9,7 +9,6 @@ import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any, Dict
from collections import Counter
from itertools import combinations
import traceback
from rich.traceback import install
@@ -23,6 +22,8 @@ from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp_with_chat_inclusive,
) # 导入 build_readable_messages
# 添加cosine_similarity函数
def cosine_similarity(v1, v2):
"""计算余弦相似度"""
@@ -51,18 +52,9 @@ def calculate_information_content(text):
return entropy
logger = get_logger("memory")
class MemoryGraph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@@ -96,7 +88,7 @@ class MemoryGraph:
if "memory_items" in self.G.nodes[concept]:
# 获取现有的记忆项已经是str格式
existing_memory = self.G.nodes[concept]["memory_items"]
# 如果现有记忆不为空则使用LLM整合新旧记忆
if existing_memory and hippocampus_instance and hippocampus_instance.model_small:
try:
@@ -170,16 +162,16 @@ class MemoryGraph:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
async def _integrate_memories_with_llm(self, existing_memory: str, new_memory: str, llm_model: LLMRequest) -> str:
"""
使用LLM整合新旧记忆内容
Args:
existing_memory: 现有的记忆内容(字符串格式,可能包含多条记忆)
new_memory: 新的记忆内容
llm_model: LLM模型实例
Returns:
str: 整合后的记忆内容
"""
@@ -203,8 +195,10 @@ class MemoryGraph:
整合后的记忆:"""
# 调用LLM进行整合
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(integration_prompt)
content, (reasoning_content, model_name, tool_calls) = await llm_model.generate_response_async(
integration_prompt
)
if content and content.strip():
integrated_content = content.strip()
logger.debug(f"LLM记忆整合成功模型: {model_name}")
@@ -212,7 +206,7 @@ class MemoryGraph:
else:
logger.warning("LLM返回的整合结果为空使用默认连接方式")
return f"{existing_memory} | {new_memory}"
except Exception as e:
logger.error(f"LLM记忆整合过程中出错: {e}")
return f"{existing_memory} | {new_memory}"
@@ -238,7 +232,11 @@ class MemoryGraph:
if memory_items:
# 删除整个节点
self.G.remove_node(topic)
return f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." if len(memory_items) > 50 else f"删除了节点 {topic} 的完整记忆: {memory_items}"
return (
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
if len(memory_items) > 50
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
)
else:
# 如果没有记忆项,删除该节点
self.G.remove_node(topic)
@@ -263,38 +261,40 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
self.model_small = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="memory.modify")
self.model_small = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="memory.modify"
)
def get_all_node_names(self) -> list:
    """Return the names of all nodes currently in the memory graph as a list."""
    graph = self.memory_graph.G
    return [*graph.nodes()]
def calculate_weighted_activation(self, current_activation: float, edge_strength: int, target_node: str) -> float:
    """Compute the activation passed on to ``target_node``, scaled by node weight.

    The activation first decays by ``1 / edge_strength``. If anything is left,
    it is multiplied by a bonus derived from the target node's ``weight``
    attribute.

    Args:
        current_activation: Activation value before crossing the edge.
        edge_strength: Strength of the edge being crossed.
        target_node: Name of the node the activation is spreading to.

    Returns:
        float: The resulting activation; 0.0 when nothing survives the decay
        or the edge has a non-positive strength.
    """
    # A non-positive strength cannot carry activation. Guarding here also
    # avoids ZeroDivisionError for edges stored with strength 0.
    if edge_strength <= 0:
        return 0.0

    # Base decay: stronger edges lose less activation.
    base_activation = current_activation - (1 / edge_strength)
    if base_activation <= 0:
        return 0.0

    graph = self.memory_graph.G
    if target_node not in graph:
        return base_activation

    node_weight = graph.nodes[target_node].get("weight", 1.0)
    # Weight bonus: +10% per unit of weight above 1.0, capped at +200%.
    weight_multiplier = 1.0 + min((node_weight - 1.0) * 0.1, 2.0)
    return base_activation * weight_multiplier
@@ -332,9 +332,7 @@ class Hippocampus:
f"将主题用逗号隔开,并加上<>,例如<主题1>,<主题2>......尽可能精简。只需要列举最多{topic_num}个话题就好,不要有序号,不要告诉我其他内容。"
f"如果确定找不出主题或者没有明显主题,返回<none>。"
)
return prompt
@staticmethod
@@ -403,7 +401,7 @@ class Hippocampus:
memories.sort(key=lambda x: x[2], reverse=True)
return memories
async def get_keywords_from_text(self, text: str) -> list:
async def get_keywords_from_text(self, text: str) -> Tuple[List[str], List]:
"""从文本中提取关键词。
Args:
@@ -418,16 +416,13 @@ class Hippocampus:
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
text_length = len(text)
topic_num: int | list[int] = 0
words = jieba.cut(text)
keywords_lite = [word for word in words if len(word) > 1]
keywords_lite = list(set(keywords_lite))
if keywords_lite:
logger.debug(f"提取关键词极简版: {keywords_lite}")
if text_length <= 12:
topic_num = [1, 3] # 6-10字符: 1个关键词 (27.18%的文本)
elif text_length <= 20:
@@ -455,7 +450,7 @@ class Hippocampus:
if keywords:
logger.debug(f"提取关键词: {keywords}")
return keywords,keywords_lite
return keywords, keywords_lite
async def get_memory_from_topic(
self,
@@ -570,20 +565,17 @@ class Hippocampus:
for node, activation in remember_map.items():
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
if memory_items := node_data.get("memory_items", ""):
logger.debug("节点包含完整记忆")
# 计算记忆与关键词的相似度
memory_words = set(jieba.cut(memory_items))
text_words = set(keywords)
all_words = memory_words | text_words
if all_words:
if all_words := memory_words | text_words:
# 计算相似度(虽然这里没有使用,但保持逻辑一致性)
v1 = [1 if word in memory_words else 0 for word in all_words]
v2 = [1 if word in text_words else 0 for word in all_words]
_ = cosine_similarity(v1, v2) # 计算但不使用用_表示
# 添加完整记忆到结果中
all_memories.append((node, memory_items, activation))
else:
@@ -613,7 +605,9 @@ class Hippocampus:
return result
async def get_activate_from_text(self, text: str, max_depth: int = 3, fast_retrieval: bool = False) -> tuple[float, list[str],list[str]]:
async def get_activate_from_text(
self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str], list[str]]:
"""从文本中提取关键词并获取相关记忆。
Args:
@@ -627,13 +621,13 @@ class Hippocampus:
float: 激活节点数与总节点数的比值
list[str]: 有效的关键词
"""
keywords,keywords_lite = await self.get_keywords_from_text(text)
keywords, keywords_lite = await self.get_keywords_from_text(text)
# 过滤掉不存在于记忆图中的关键词
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
if not valid_keywords:
# logger.info("没有找到有效的关键词节点")
return 0, keywords,keywords_lite
return 0, keywords, keywords_lite
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
@@ -700,7 +694,7 @@ class Hippocampus:
activation_ratio = activation_ratio * 50
logger.debug(f"总激活值: {total_activation:.2f}, 总节点数: {total_nodes}, 激活: {activation_ratio}")
return activation_ratio, keywords,keywords_lite
return activation_ratio, keywords, keywords_lite
# 负责海马体与其他部分的交互
@@ -730,7 +724,7 @@ class EntorhinalCortex:
continue
memory_items = data.get("memory_items", "")
# 直接检查字符串是否为空,不需要分割成列表
if not memory_items or memory_items.strip() == "":
self.memory_graph.G.remove_node(concept)
@@ -865,7 +859,9 @@ class EntorhinalCortex:
end_time = time.time()
logger.info(f"[数据库] 同步完成,总耗时: {end_time - start_time:.2f}")
logger.info(f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边")
logger.info(
f"[数据库] 同步了 {len(nodes_to_create) + len(nodes_to_update)} 个节点和 {len(edges_to_create) + len(edges_to_update)} 条边"
)
async def resync_memory_to_db(self):
"""清空数据库并重新同步所有记忆数据"""
@@ -888,7 +884,7 @@ class EntorhinalCortex:
nodes_data = []
for concept, data in memory_nodes:
memory_items = data.get("memory_items", "")
# 直接检查字符串是否为空,不需要分割成列表
if not memory_items or memory_items.strip() == "":
self.memory_graph.G.remove_node(concept)
@@ -960,7 +956,7 @@ class EntorhinalCortex:
# 清空当前图
self.memory_graph.G.clear()
# 统计加载情况
total_nodes = 0
loaded_nodes = 0
@@ -969,7 +965,7 @@ class EntorhinalCortex:
# 从数据库加载所有节点
nodes = list(GraphNodes.select())
total_nodes = len(nodes)
for node in nodes:
concept = node.concept
try:
@@ -978,7 +974,7 @@ class EntorhinalCortex:
logger.warning(f"节点 {concept} 的memory_items为空跳过")
skipped_nodes += 1
continue
# 直接使用memory_items
memory_items = node.memory_items.strip()
@@ -999,11 +995,15 @@ class EntorhinalCortex:
last_modified = node.last_modified or current_time
# 获取权重属性
weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0
# 添加节点到图中
self.memory_graph.G.add_node(
concept, memory_items=memory_items, weight=weight, created_time=created_time, last_modified=last_modified
concept,
memory_items=memory_items,
weight=weight,
created_time=created_time,
last_modified=last_modified,
)
loaded_nodes += 1
except Exception as e:
@@ -1044,9 +1044,11 @@ class EntorhinalCortex:
if need_update:
logger.info("[数据库] 已为缺失的时间字段进行补充")
# 输出加载统计信息
logger.info(f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes}")
logger.info(
f"[数据库] 记忆加载完成: 总计 {total_nodes} 个节点, 成功加载 {loaded_nodes} 个, 跳过 {skipped_nodes}"
)
# 负责整合,遗忘,合并记忆
@@ -1054,10 +1056,12 @@ class ParahippocampalGyrus:
def __init__(self, hippocampus: Hippocampus):
self.hippocampus = hippocampus
self.memory_graph = hippocampus.memory_graph
self.memory_modify_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="memory.modify")
async def memory_compress(self, messages: list, compress_rate=0.1):
self.memory_modify_model = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="memory.modify"
)
async def memory_compress(self, messages: list[DatabaseMessages], compress_rate=0.1):
"""压缩和总结消息内容,生成记忆主题和摘要。
Args:
@@ -1083,7 +1087,6 @@ class ParahippocampalGyrus:
# build_readable_messages 只返回一个字符串,不需要解包
input_text = build_readable_messages(
messages,
merge_messages=True, # 合并连续消息
timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式
replace_bot_name=False, # 保留原始用户名
)
@@ -1163,7 +1166,7 @@ class ParahippocampalGyrus:
similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:3]
similar_topics_dict[topic] = similar_topics
if global_config.debug.show_prompt:
logger.info(f"prompt: {topic_what_prompt}")
logger.info(f"压缩后的记忆: {compressed_memory}")
@@ -1259,14 +1262,14 @@ class ParahippocampalGyrus:
# --- 如果节点不为空,则执行原来的不活跃检查和随机移除逻辑 ---
last_modified = node_data.get("last_modified", current_time)
node_weight = node_data.get("weight", 1.0)
# 条件1检查是否长时间未修改 (使用配置的遗忘时间)
time_threshold = 3600 * global_config.memory.memory_forget_time
# 基于权重调整遗忘阈值:权重越高,需要更长时间才能被遗忘
# 权重为1时使用默认阈值权重越高阈值越大越难遗忘
adjusted_threshold = time_threshold * node_weight
if current_time - last_modified > adjusted_threshold and memory_items:
# 既然每个节点现在是完整记忆,直接删除整个节点
try:
@@ -1315,8 +1318,6 @@ class ParahippocampalGyrus:
logger.info(f"[遗忘] 总耗时: {end_time - start_time:.2f}")
class HippocampusManager:
def __init__(self):
self._hippocampus: Hippocampus = None # type: ignore
@@ -1361,29 +1362,32 @@ class HippocampusManager:
"""为指定chat_id构建记忆在heartFC_chat.py中调用"""
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try:
# 检查是否需要构建记忆
logger.info(f"{chat_id} 构建记忆")
if memory_segment_manager.check_and_build_memory_for_chat(chat_id):
logger.info(f"{chat_id} 构建记忆,需要构建记忆")
messages = memory_segment_manager.get_messages_for_memory_build(chat_id, 50)
build_probability = 0.3 * global_config.memory.memory_build_frequency
if messages and random.random() < build_probability:
logger.info(f"{chat_id} 构建记忆,消息数量: {len(messages)}")
# 调用记忆压缩和构建
compressed_memory, similar_topics_dict = await self._hippocampus.parahippocampal_gyrus.memory_compress(
(
compressed_memory,
similar_topics_dict,
) = await self._hippocampus.parahippocampal_gyrus.memory_compress(
messages, global_config.memory.memory_compress_rate
)
# 添加记忆节点
current_time = time.time()
for topic, memory in compressed_memory:
await self._hippocampus.memory_graph.add_dot(topic, memory, self._hippocampus)
# 连接相似主题
if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic]
@@ -1391,23 +1395,23 @@ class HippocampusManager:
if topic != similar_topic:
strength = int(similarity * 10)
self._hippocampus.memory_graph.G.add_edge(
topic, similar_topic,
topic,
similar_topic,
strength=strength,
created_time=current_time,
last_modified=current_time
last_modified=current_time,
)
# 同步到数据库
await self._hippocampus.entorhinal_cortex.sync_memory_to_db()
logger.info(f"{chat_id} 构建记忆完成")
return True
except Exception as e:
logger.error(f"{chat_id} 构建记忆失败: {e}")
return False
return False
return False
async def get_memory_from_topic(
self, valid_keywords: list[str], max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3
@@ -1424,16 +1428,20 @@ class HippocampusManager:
response = []
return response
async def get_activate_from_text(
    self, text: str, max_depth: int = 3, fast_retrieval: bool = False
) -> tuple[float, list[str], list[str]]:
    """Public entry point that computes an activation value for a piece of text.

    The work is delegated to ``self._hippocampus.get_activate_from_text``;
    ``text``, ``max_depth`` and ``fast_retrieval`` are forwarded unchanged.

    Args:
        text: Text to evaluate.
        max_depth: Forwarded to the underlying implementation.
        fast_retrieval: Forwarded to the underlying implementation.

    Returns:
        The ``(response, keywords, keywords_lite)`` triple produced by the
        underlying call, or ``(0.0, [], [])`` if that call raised.

    Raises:
        RuntimeError: If the manager has not been initialized yet.
    """
    if not self._initialized:
        raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
    try:
        response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
            text, max_depth, fast_retrieval
        )
    except Exception as e:
        # Best-effort API: log the failure with its traceback and fall back
        # to a neutral result instead of propagating.
        logger.error(f"文本产生激活值失败: {e}")
        logger.error(traceback.format_exc())
        return 0.0, [], []
    # Hand the computed values back to the caller; previously the success
    # path had no return of its own, so the result was never delivered.
    return response, keywords, keywords_lite
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从关键词获取相关记忆的公共接口"""
@@ -1455,81 +1463,78 @@ hippocampus_manager = HippocampusManager()
# 在Hippocampus类中添加新的记忆构建管理器
class MemoryBuilder:
    """Per-chat trigger and message window for memory building.

    Holds, for a single ``chat_id``, the timestamp from which not-yet-consumed
    messages are counted, and decides when enough time has passed and enough
    messages have accumulated to build memory.
    """

    def __init__(self, chat_id: str):
        self.chat_id = chat_id
        # Start of the window of messages that have not been handed out yet.
        self.last_update_time: float = time.time()
        # Time of the last successful hand-out; stays 0.0 until the first one.
        self.last_processed_time: float = 0.0

    def should_trigger_memory_build(self) -> bool:
        """Check whether a memory build should be triggered for this chat.

        Both thresholds scale inversely with
        ``global_config.memory.memory_build_frequency``: at least
        ``600 / frequency`` seconds must have elapsed since
        ``last_update_time`` and at least ``30 / frequency`` messages must
        have arrived in that window.

        Returns:
            True when both thresholds are met, False otherwise (including
            when the configured frequency is zero or negative).
        """
        frequency = global_config.memory.memory_build_frequency
        # A non-positive frequency cannot produce meaningful thresholds and a
        # zero value would raise ZeroDivisionError below; treat it as "never".
        if frequency <= 0:
            return False

        current_time = time.time()

        # Time-interval check.
        time_diff = current_time - self.last_update_time
        if time_diff < 600 / frequency:
            return False

        # Message-count check over the same window.
        recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_update_time,
            timestamp_end=current_time,
        )
        logger.info(f"最近消息数量: {len(recent_messages)},间隔时间: {time_diff}")
        if not recent_messages or len(recent_messages) < 30 / frequency:
            return False
        return True

    def get_messages_for_memory_build(self, threshold: int = 25) -> List[DatabaseMessages]:
        """Fetch the messages to feed into a memory build.

        Queries the window from ``last_update_time`` to now, passing
        ``threshold`` as the query ``limit``. When anything is returned, both
        ``last_processed_time`` and ``last_update_time`` move to now, so the
        next window starts after this call.

        Args:
            threshold: Value passed as ``limit`` to the message query.

        Returns:
            The fetched messages, or an empty list when there are none.
        """
        current_time = time.time()
        messages = get_raw_msg_by_timestamp_with_chat_inclusive(
            chat_id=self.chat_id,
            timestamp_start=self.last_update_time,
            timestamp_end=current_time,
            limit=threshold,
        )
        if messages:
            # Record the hand-out and advance the window.
            self.last_processed_time = current_time
            self.last_update_time = current_time
        return messages or []
class MemorySegmentManager:
"""记忆段管理器
管理所有chat_id的MemoryBuilder实例自动检查和触发记忆构建
"""
def __init__(self):
    # Registry of MemoryBuilder instances keyed by chat_id; starts empty.
    self.builders: Dict[str, MemoryBuilder] = {}
def get_or_create_builder(self, chat_id: str) -> MemoryBuilder:
    """Return the MemoryBuilder registered for ``chat_id``, creating and storing it on first use."""
    try:
        return self.builders[chat_id]
    except KeyError:
        created = MemoryBuilder(chat_id)
        self.builders[chat_id] = created
        return created
def check_and_build_memory_for_chat(self, chat_id: str) -> bool:
    """Report whether the given chat currently needs a memory build.

    Looks up (or creates) the chat's builder and returns the result of its
    ``should_trigger_memory_build`` check.
    """
    return self.get_or_create_builder(chat_id).should_trigger_memory_build()
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[Dict[str, Any]]:
def get_messages_for_memory_build(self, chat_id: str, threshold: int = 25) -> List[DatabaseMessages]:
"""获取指定chat_id用于记忆构建的消息"""
if chat_id not in self.builders:
return []
@@ -1538,4 +1543,3 @@ class MemorySegmentManager:
# 创建全局实例
memory_segment_manager = MemorySegmentManager()

View File

@@ -1,17 +1,17 @@
import json
import random
from json_repair import repair_json
from typing import List, Tuple
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.utils.utils import parse_keywords_string
from src.chat.utils.chat_message_builder import build_readable_messages
import random
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.llm_models.utils_model import LLMRequest
logger = get_logger("memory_activator")
@@ -75,19 +75,20 @@ class MemoryActivator:
request_type="memory.selection",
)
async def activate_memory_with_chat_history(self, target_message, chat_history_prompt) -> List[Tuple[str, str]]:
async def activate_memory_with_chat_history(
self, target_message, chat_history: List[DatabaseMessages]
) -> List[Tuple[str, str]]:
"""
激活记忆
"""
# 如果记忆系统被禁用,直接返回空列表
if not global_config.memory.enable_memory:
return []
keywords_list = set()
for msg in chat_history_prompt:
keywords = parse_keywords_string(msg.get("key_words", ""))
for msg in chat_history:
keywords = parse_keywords_string(msg.key_words)
if keywords:
if len(keywords_list) < 30:
# 最多容纳30个关键词
@@ -95,24 +96,22 @@ class MemoryActivator:
logger.debug(f"提取关键词: {keywords_list}")
else:
break
if not keywords_list:
logger.debug("没有提取到关键词,返回空记忆列表")
return []
# 从海马体获取相关记忆
related_memory = await hippocampus_manager.get_memory_from_topic(
valid_keywords=list(keywords_list), max_memory_num=5, max_memory_length=3, max_depth=3
)
# logger.info(f"当前记忆关键词: {keywords_list}")
logger.debug(f"获取到的记忆: {related_memory}")
if not related_memory:
logger.debug("海马体没有返回相关记忆")
return []
used_ids = set()
candidate_memories = []
@@ -120,12 +119,7 @@ class MemoryActivator:
# 为每个记忆分配随机ID并过滤相关记忆
for memory in related_memory:
keyword, content = memory
found = False
for kw in keywords_list:
if kw in content:
found = True
break
found = any(kw in content for kw in keywords_list)
if found:
# 随机分配一个不重复的2位数id
while True:
@@ -138,95 +132,83 @@ class MemoryActivator:
if not candidate_memories:
logger.info("没有找到相关的候选记忆")
return []
# 如果只有少量记忆,直接返回
if len(candidate_memories) <= 2:
logger.debug(f"候选记忆较少({len(candidate_memories)}个),直接返回")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories]
# 使用 LLM 选择合适的记忆
selected_memories = await self._select_memories_with_llm(target_message, chat_history_prompt, candidate_memories)
return selected_memories
async def _select_memories_with_llm(self, target_message, chat_history_prompt, candidate_memories) -> List[Tuple[str, str]]:
return await self._select_memories_with_llm(target_message, chat_history, candidate_memories)
async def _select_memories_with_llm(
self, target_message, chat_history: List[DatabaseMessages], candidate_memories
) -> List[Tuple[str, str]]:
"""
使用 LLM 选择合适的记忆
Args:
target_message: 目标消息
chat_history_prompt: 聊天历史
candidate_memories: 候选记忆列表,每个记忆包含 memory_id、keyword、content
Returns:
List[Tuple[str, str]]: 选择的记忆列表,格式为 (keyword, content)
"""
try:
# 构建聊天历史字符串
obs_info_text = build_readable_messages(
chat_history_prompt,
chat_history,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
# 构建记忆信息字符串
memory_lines = []
for memory in candidate_memories:
memory_id = memory["memory_id"]
keyword = memory["keyword"]
content = memory["content"]
# 将 content 列表转换为字符串
if isinstance(content, list):
content_str = " | ".join(str(item) for item in content)
else:
content_str = str(content)
memory_lines.append(f"记忆编号 {memory_id}: [关键词: {keyword}] {content_str}")
memory_info = "\n".join(memory_lines)
# 获取并格式化 prompt
prompt_template = await global_prompt_manager.get_prompt_async("memory_activator_prompt")
formatted_prompt = prompt_template.format(
obs_info_text=obs_info_text,
target_message=target_message,
memory_info=memory_info
obs_info_text=obs_info_text, target_message=target_message, memory_info=memory_info
)
# 调用 LLM
response, (reasoning_content, model_name, _) = await self.memory_selection_model.generate_response_async(
formatted_prompt,
temperature=0.3,
max_tokens=150
formatted_prompt, temperature=0.3, max_tokens=150
)
if global_config.debug.show_prompt:
logger.info(f"记忆选择 prompt: {formatted_prompt}")
logger.info(f"LLM 记忆选择响应: {response}")
else:
logger.debug(f"记忆选择 prompt: {formatted_prompt}")
logger.debug(f"LLM 记忆选择响应: {response}")
# 解析响应获取选择的记忆编号
try:
fixed_json = repair_json(response)
# 解析为 Python 对象
result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
# 提取 memory_ids 字段
memory_ids_str = result.get("memory_ids", "")
# 解析逗号分隔的编号
if memory_ids_str:
# 提取 memory_ids 字段并解析逗号分隔的编号
if memory_ids_str := result.get("memory_ids", ""):
memory_ids = [mid.strip() for mid in str(memory_ids_str).split(",") if mid.strip()]
# 过滤掉空字符串和无效编号
valid_memory_ids = [mid for mid in memory_ids if mid and len(mid) <= 3]
@@ -236,26 +218,24 @@ class MemoryActivator:
except Exception as e:
logger.error(f"解析记忆选择响应失败: {e}", exc_info=True)
selected_memory_ids = []
# 根据编号筛选记忆
selected_memories = []
memory_id_to_memory = {mem["memory_id"]: mem for mem in candidate_memories}
for memory_id in selected_memory_ids:
if memory_id in memory_id_to_memory:
selected_memories.append(memory_id_to_memory[memory_id])
selected_memories = [
memory_id_to_memory[memory_id] for memory_id in selected_memory_ids if memory_id in memory_id_to_memory
]
logger.info(f"LLM 选择的记忆编号: {selected_memory_ids}")
logger.info(f"最终选择的记忆数量: {len(selected_memories)}")
# 转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in selected_memories]
except Exception as e:
logger.error(f"LLM 选择记忆时出错: {e}", exc_info=True)
# 出错时返回前3个候选记忆作为备选转换为 (keyword, content) 格式
return [(mem["keyword"], mem["content"]) for mem in candidate_memories[:3]]
init_prompt()

View File

@@ -70,13 +70,10 @@ class ActionModifier:
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
chat_content = build_readable_messages(
temp_msg_list_before_now_half,
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,

View File

@@ -283,11 +283,8 @@ class ActionPlanner:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now]
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=temp_msg_list_before_now,
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.last_obs_time_mark,
truncate=True,

View File

@@ -8,6 +8,7 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
@@ -156,7 +157,7 @@ class DefaultReplyer:
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None,
enable_tool: bool = True,
from_plugin: bool = True,
stream_id: Optional[str] = None,
@@ -171,7 +172,7 @@ class DefaultReplyer:
extra_info: 额外信息,用于补充上下文
reply_reason: 回复原因
available_actions: 可用的动作信息字典
choosen_actions: 已选动作
chosen_actions: 已选动作
enable_tool: 是否启用工具调用
from_plugin: 是否来自插件
@@ -189,7 +190,7 @@ class DefaultReplyer:
prompt, selected_expressions = await self.build_prompt_reply_context(
extra_info=extra_info,
available_actions=available_actions,
choosen_actions=choosen_actions,
chosen_actions=chosen_actions,
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
@@ -296,7 +297,7 @@ class DefaultReplyer:
if not sender:
return ""
if sender == global_config.bot.nickname:
return ""
@@ -352,7 +353,7 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_memory_block(self, chat_history: List[Dict[str, Any]], target: str) -> str:
async def build_memory_block(self, chat_history: List[DatabaseMessages], target: str) -> str:
"""构建记忆块
Args:
@@ -369,7 +370,7 @@ class DefaultReplyer:
instant_memory = None
running_memories = await self.memory_activator.activate_memory_with_chat_history(
target_message=target, chat_history_prompt=chat_history
target_message=target, chat_history=chat_history
)
if global_config.memory.enable_instant_memory:
@@ -433,7 +434,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
def _parse_reply_target(self, target_message: Optional[str]) -> Tuple[str, str]:
"""解析回复目标消息
Args:
@@ -514,7 +515,7 @@ class DefaultReplyer:
return name, result, duration
def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str, sender: str
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
"""
构建 s4u 风格的分离对话 prompt
@@ -530,16 +531,16 @@ class DefaultReplyer:
bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
for msg_dict in message_list_before_now:
for msg in message_list_before_now:
try:
msg_user_id = str(msg_dict.get("user_id"))
reply_to = msg_dict.get("reply_to", "")
msg_user_id = str(msg.user_info.user_id)
reply_to = msg.reply_to
_platform, reply_to_user_id = self._parse_reply_target(reply_to)
if (msg_user_id == bot_id and reply_to_user_id == target_user_id) or msg_user_id == target_user_id:
# bot 和目标用户的对话
core_dialogue_list.append(msg_dict)
core_dialogue_list.append(msg)
except Exception as e:
logger.error(f"处理消息记录时出错: {msg_dict}, 错误: {e}")
logger.error(f"处理消息记录时出错: {msg}, 错误: {e}")
# 构建背景对话 prompt
all_dialogue_prompt = ""
@@ -574,7 +575,6 @@ class DefaultReplyer:
core_dialogue_prompt_str = build_readable_messages(
core_dialogue_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="normal_no_YMD",
read_mark=0.0,
truncate=True,
@@ -668,7 +668,7 @@ class DefaultReplyer:
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None,
enable_tool: bool = True,
reply_message: Optional[Dict[str, Any]] = None,
) -> Tuple[str, List[int]]:
@@ -679,7 +679,7 @@ class DefaultReplyer:
extra_info: 额外信息,用于补充上下文
reply_reason: 回复原因
available_actions: 可用动作
choosen_actions: 已选动作
chosen_actions: 已选动作
enable_timeout: 是否启用超时处理
enable_tool: 是否启用工具调用
reply_message: 回复的原始消息
@@ -712,27 +712,21 @@ class DefaultReplyer:
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size * 1,
)
temp_msg_list_before_long = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_long]
# TODO: 修复!
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
)
temp_msg_list_before_short = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_short]
chat_talking_prompt_short = build_readable_messages(
temp_msg_list_before_short,
message_list_before_short,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
@@ -744,12 +738,12 @@ class DefaultReplyer:
self.build_expression_habits(chat_talking_prompt_short, target), "expression_habits"
),
self._time_and_run_task(self.build_relation_info(sender, target), "relation_info"),
self._time_and_run_task(self.build_memory_block(temp_msg_list_before_short, target), "memory_block"),
self._time_and_run_task(self.build_memory_block(message_list_before_short, target), "memory_block"),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
),
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, choosen_actions), "actions_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
)
# 任务名称中英文映射
@@ -810,25 +804,9 @@ class DefaultReplyer:
else:
reply_target_block = ""
# if is_group_chat:
# chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
# chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
# else:
# chat_target_name = "对方"
# if self.chat_target_info:
# chat_target_name = (
# self.chat_target_info.get("person_name") or self.chat_target_info.get("user_nickname") or "对方"
# )
# chat_target_1 = await global_prompt_manager.format_prompt(
# "chat_target_private1", sender_name=chat_target_name
# )
# chat_target_2 = await global_prompt_manager.format_prompt(
# "chat_target_private2", sender_name=chat_target_name
# )
# 构建分离的对话 prompt
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
temp_msg_list_before_long, user_id, sender
message_list_before_now_long, user_id, sender
)
if global_config.bot.qq_account == user_id and platform == global_config.bot.platform:
@@ -879,7 +857,7 @@ class DefaultReplyer:
reason: str,
reply_to: str,
reply_message: Optional[Dict[str, Any]] = None,
) -> Tuple[str, List[int]]: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
@@ -902,20 +880,16 @@ class DefaultReplyer:
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
)
# TODO: 修复!
from src.common.data_models import temporarily_transform_class_to_dict
temp_msg_list_before_now_half = [temporarily_transform_class_to_dict(msg) for msg in message_list_before_now_half]
chat_talking_prompt_half = build_readable_messages(
temp_msg_list_before_now_half,
message_list_before_now_half,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
show_actions=True,
)
# 并行执行2个构建任务
(expression_habits_block, selected_expressions), relation_info = await asyncio.gather(
(expression_habits_block, _), relation_info = await asyncio.gather(
self.build_expression_habits(chat_talking_prompt_half, target),
self.build_relation_info(sender, target),
)

View File

@@ -1,4 +1,4 @@
import time # 导入 time 模块以获取当前时间
import time
import random
import re
@@ -6,14 +6,17 @@ from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_data_model import MessageAndActionModel
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids
install(extra_lines=3)
logger = get_logger("chat_message_builder")
def replace_user_references_sync(
@@ -349,7 +352,9 @@ def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[DatabaseMessages]:
def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
@@ -390,16 +395,16 @@ def num_new_messages_since_with_users(
def _build_readable_messages_internal(
messages: List[Dict[str, Any]],
messages: List[MessageAndActionModel],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
pic_id_mapping: Optional[Dict[str, str]] = None,
pic_counter: int = 1,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
message_id_list: Optional[List[DatabaseMessages]] = None,
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
# sourcery skip: use-getitem-for-re-match-groups
"""
内部辅助函数,构建可读消息字符串和原始消息详情列表。
@@ -418,7 +423,7 @@ def _build_readable_messages_internal(
if not messages:
return "", [], pic_id_mapping or {}, pic_counter
message_details_raw: List[Tuple[float, str, str, bool]] = []
detailed_messages_raw: List[Tuple[float, str, str, bool]] = []
# 使用传入的映射字典,如果没有则创建新的
if pic_id_mapping is None:
@@ -426,25 +431,26 @@ def _build_readable_messages_internal(
current_pic_counter = pic_counter
# 创建时间戳到消息ID的映射用于在消息前添加[id]标识符
timestamp_to_id = {}
timestamp_to_id_mapping: Dict[float, str] = {}
if message_id_list:
for item in message_id_list:
message = item.get("message", {})
timestamp = message.get("time")
for msg in message_id_list:
timestamp = msg.time
if timestamp is not None:
timestamp_to_id[timestamp] = item.get("id", "")
timestamp_to_id_mapping[timestamp] = msg.message_id
def process_pic_ids(content: str) -> str:
def process_pic_ids(content: Optional[str]) -> str:
"""处理内容中的图片ID将其替换为[图片x]格式"""
nonlocal current_pic_counter
if content is None:
logger.warning("Content is None when processing pic IDs.")
raise ValueError("Content is None")
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match):
def replace_pic_id(match: re.Match) -> str:
nonlocal current_pic_counter
nonlocal pic_counter
pic_id = match.group(1)
if pic_id not in pic_id_mapping:
pic_id_mapping[pic_id] = f"图片{current_pic_counter}"
current_pic_counter += 1
@@ -453,42 +459,23 @@ def _build_readable_messages_internal(
return re.sub(pic_pattern, replace_pic_id, content)
# 1 & 2: 获取发送者信息并提取消息组件
for msg in messages:
# 检查是否是动作记录
if msg.get("is_action_record", False):
is_action = True
timestamp: float = msg.get("time") # type: ignore
content = msg.get("display_message", "")
# 1: 获取发送者信息并提取消息组件
for message in messages:
if message.is_action_record:
# 对于动作记录也处理图片ID
content = process_pic_ids(content)
message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
content = process_pic_ids(message.display_message)
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
continue
# 检查并修复缺少的user_info字段
if "user_info" not in msg:
# 创建user_info字段
msg["user_info"] = {
"platform": msg.get("user_platform", ""),
"user_id": msg.get("user_id", ""),
"user_nickname": msg.get("user_nickname", ""),
"user_cardname": msg.get("user_cardname", ""),
}
platform = message.user_platform
user_id = message.user_id
user_nickname = message.user_nickname
user_cardname = message.user_cardname
user_info = msg.get("user_info", {})
platform = user_info.get("platform")
user_id = user_info.get("user_id")
user_nickname = user_info.get("user_nickname")
user_cardname = user_info.get("user_cardname")
timestamp: float = msg.get("time") # type: ignore
content: str
if msg.get("display_message"):
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "") # 默认空字符串
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
# 向下兼容
if "" in content:
content = content.replace("", "")
if "" in content:
@@ -504,52 +491,32 @@ def _build_readable_messages_internal(
person = Person(platform=platform, user_id=user_id)
# 根据 replace_bot_name 参数决定是否替换机器人名称
person_name: str
person_name = (
person.person_name or f"{user_nickname}" or (f"昵称:{user_cardname}" if user_cardname else "某人")
)
if replace_bot_name and user_id == global_config.bot.qq_account:
person_name = f"{global_config.bot.nickname}(你)"
else:
person_name = person.person_name or user_id # type: ignore
# 如果 person_name 未设置,则使用消息中的 nickname 或默认名称
if not person_name:
if user_cardname:
person_name = f"昵称:{user_cardname}"
elif user_nickname:
person_name = f"{user_nickname}"
else:
person_name = "某人"
# 使用独立函数处理用户引用格式
content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)
if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name):
detailed_messages_raw.append((timestamp, person_name, content, False))
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content and random.random() < 0.6:
content = content.replace(target_str, "")
if content != "":
message_details_raw.append((timestamp, person_name, content, False))
if not message_details_raw:
if not detailed_messages_raw:
return "", [], pic_id_mapping, current_pic_counter
message_details_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
detailed_messages_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
detailed_message: List[Tuple[float, str, str, bool]] = []
# 为每条消息添加一个标记,指示它是否是动作记录
message_details_with_flags = []
for timestamp, name, content, is_action in message_details_raw:
message_details_with_flags.append((timestamp, name, content, is_action))
# 应用截断逻辑 (如果 truncate 为 True)
message_details: List[Tuple[float, str, str, bool]] = []
n_messages = len(message_details_with_flags)
if truncate and n_messages > 0:
for i, (timestamp, name, content, is_action) in enumerate(message_details_with_flags):
# 2. 应用消息截断逻辑
messages_count = len(detailed_messages_raw)
if truncate and messages_count > 0:
for i, (timestamp, name, content, is_action) in enumerate(detailed_messages_raw):
# 对于动作记录,不进行截断
if is_action:
message_details.append((timestamp, name, content, is_action))
detailed_message.append((timestamp, name, content, is_action))
continue
percentile = i / n_messages # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
percentile = i / messages_count # 计算消息在列表中的位置百分比 (0 <= percentile < 1)
original_len = len(content)
limit = -1 # 默认不截断
@@ -562,116 +529,42 @@ def _build_readable_messages_internal(
elif percentile < 0.7: # 60% 到 80% 之前的消息 (即中间的 20%)
limit = 200
replace_content = "......(内容太长了)"
elif percentile < 1.0: # 80% 到 100% 之前的消息 (即较新的 20%)
elif percentile <= 1.0: # 80% 到 100% 之前的消息 (即较新的 20%)
limit = 400
replace_content = "......(太长了)"
replace_content = "......内容太长了)"
truncated_content = content
if 0 < limit < original_len:
truncated_content = f"{content[:limit]}{replace_content}"
message_details.append((timestamp, name, truncated_content, is_action))
detailed_message.append((timestamp, name, truncated_content, is_action))
else:
# 如果不截断,直接使用原始列表
message_details = message_details_with_flags
detailed_message = detailed_messages_raw
# 3: 合并连续消息 (如果 merge_messages 为 True)
merged_messages = []
if merge_messages and message_details:
# 初始化第一个合并块
current_merge = {
"name": message_details[0][1],
"start_time": message_details[0][0],
"end_time": message_details[0][0],
"content": [message_details[0][2]],
"is_action": message_details[0][3],
}
# 3: 格式化为字符串
output_lines: List[str] = []
for i in range(1, len(message_details)):
timestamp, name, content, is_action = message_details[i]
for timestamp, name, content, is_action in detailed_message:
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
# 对于动作记录,不进行合并
if is_action or current_merge["is_action"]:
# 保存当前的合并块
merged_messages.append(current_merge)
# 创建新的块
current_merge = {
"name": name,
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action,
}
continue
# 查找消息id如果有并构建id_prefix
message_id = timestamp_to_id_mapping.get(timestamp)
id_prefix = f"[{message_id}]" if message_id else ""
# 如果是同一个人发送的连续消息且时间间隔小于等于60秒
if name == current_merge["name"] and (timestamp - current_merge["end_time"] <= 60):
current_merge["content"].append(content)
current_merge["end_time"] = timestamp # 更新最后消息时间
else:
# 保存上一个合并块
merged_messages.append(current_merge)
# 开始新的合并块
current_merge = {
"name": name,
"start_time": timestamp,
"end_time": timestamp,
"content": [content],
"is_action": is_action,
}
# 添加最后一个合并块
merged_messages.append(current_merge)
elif message_details: # 如果不合并消息,则每个消息都是一个独立的块
for timestamp, name, content, is_action in message_details:
merged_messages.append(
{
"name": name,
"start_time": timestamp, # 起始和结束时间相同
"end_time": timestamp,
"content": [content], # 内容只有一个元素
"is_action": is_action,
}
)
# 4 & 5: 格式化为字符串
output_lines = []
for _i, merged in enumerate(merged_messages):
# 使用指定的 timestamp_mode 格式化时间
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
# 查找对应的消息ID
message_id = timestamp_to_id.get(merged["start_time"], "")
id_prefix = f"[{message_id}] " if message_id else ""
# 检查是否是动作记录
if merged["is_action"]:
if is_action:
# 对于动作记录,使用特殊格式
output_lines.append(f"{id_prefix}{readable_time}, {merged['content'][0]}")
output_lines.append(f"{id_prefix}{readable_time}, {content}")
else:
header = f"{id_prefix}{readable_time}, {merged['name']} :"
output_lines.append(header)
# 将内容合并,并添加缩进
for line in merged["content"]:
stripped_line = line.strip()
if stripped_line: # 过滤空行
# 移除末尾句号,添加分号 - 这个逻辑似乎有点奇怪,暂时保留
if stripped_line.endswith(""):
stripped_line = stripped_line[:-1]
# 如果内容被截断,结尾已经是 ...(内容太长),不再添加分号
if not stripped_line.endswith("(内容太长)"):
output_lines.append(f"{stripped_line}")
else:
output_lines.append(stripped_line) # 直接添加截断后的内容
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
# 移除可能的多余换行,然后合并
formatted_string = "".join(output_lines).strip()
# 返回格式化后的字符串、消息详情列表、图片映射字典和更新后的计数器
return (
formatted_string,
[(t, n, c) for t, n, c, is_action in message_details if not is_action],
[(t, n, c) for t, n, c, is_action in detailed_message if not is_action],
pic_id_mapping,
current_pic_counter,
)
@@ -755,9 +648,8 @@ def build_readable_actions(actions: List[Dict[str, Any]]) -> str:
async def build_readable_messages_with_list(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
truncate: bool = False,
) -> Tuple[str, List[Tuple[float, str, str]]]:
@@ -766,7 +658,7 @@ async def build_readable_messages_with_list(
允许通过参数控制格式化行为。
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate
)
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
@@ -776,15 +668,14 @@ async def build_readable_messages_with_list(
def build_readable_messages_with_id(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
) -> Tuple[str, List[Dict[str, Any]]]:
) -> Tuple[str, List[DatabaseMessages]]:
"""
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
允许通过参数控制格式化行为。
@@ -794,7 +685,6 @@ def build_readable_messages_with_id(
formatted_string = build_readable_messages(
messages=messages,
replace_bot_name=replace_bot_name,
merge_messages=merge_messages,
timestamp_mode=timestamp_mode,
truncate=truncate,
show_actions=show_actions,
@@ -807,15 +697,14 @@ def build_readable_messages_with_id(
def build_readable_messages(
messages: List[Dict[str, Any]],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
show_pic: bool = True,
message_id_list: Optional[List[Dict[str, Any]]] = None,
message_id_list: Optional[List[DatabaseMessages]] = None,
) -> str: # sourcery skip: extract-method
"""
将消息列表转换为可读的文本格式。
@@ -831,19 +720,32 @@ def build_readable_messages(
truncate: 是否截断长消息
show_actions: 是否显示动作记录
"""
# WIP HERE and BELOW ----------------------------------------------
# 创建messages的深拷贝避免修改原始列表
if not messages:
return ""
copy_messages = [msg.copy() for msg in messages]
copy_messages: List[MessageAndActionModel] = [
MessageAndActionModel(
msg.time,
msg.user_info.user_id,
msg.user_info.platform,
msg.user_info.user_nickname,
msg.user_info.user_cardname,
msg.processed_plain_text,
msg.display_message,
msg.chat_info.platform,
)
for msg in messages
]
if show_actions and copy_messages:
# 获取所有消息的时间范围
min_time = min(msg.get("time", 0) for msg in copy_messages)
max_time = max(msg.get("time", 0) for msg in copy_messages)
min_time = min(msg.time or 0 for msg in copy_messages)
max_time = max(msg.time or 0 for msg in copy_messages)
# 从第一条消息中获取chat_id
chat_id = copy_messages[0].get("chat_id") if copy_messages else None
chat_id = messages[0].chat_id if messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
@@ -863,34 +765,34 @@ def build_readable_messages(
)
# 合并两部分动作记录
actions = list(actions_in_range) + list(action_after_latest)
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
# 将动作记录转换为消息格式
for action in actions:
# 只有当build_into_prompt为True时才添加动作记录
if action.action_build_into_prompt:
action_msg = {
"time": action.time,
"user_id": global_config.bot.qq_account, # 使用机器人的QQ账号
"user_nickname": global_config.bot.nickname, # 使用机器人的昵称
"user_cardname": "", # 机器人没有群名片
"processed_plain_text": f"{action.action_prompt_display}",
"display_message": f"{action.action_prompt_display}",
"chat_info_platform": action.chat_info_platform,
"is_action_record": True, # 添加标识字段
"action_name": action.action_name, # 保存动作名称
}
action_msg = MessageAndActionModel(
time=float(action.time), # type: ignore
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
user_platform=global_config.bot.platform, # 使用机器人的平台
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
user_cardname="", # 机器人没有群名片
processed_plain_text=f"{action.action_prompt_display}",
display_message=f"{action.action_prompt_display}",
chat_info_platform=str(action.chat_info_platform),
is_action_record=True, # 添加标识字段
action_name=str(action.action_name), # 保存动作名称
)
copy_messages.append(action_msg)
# 重新按时间排序
copy_messages.sort(key=lambda x: x.get("time", 0))
copy_messages.sort(key=lambda x: x.time or 0)
if read_mark <= 0:
# 没有有效的 read_mark直接格式化所有消息
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
copy_messages,
replace_bot_name,
merge_messages,
timestamp_mode,
truncate,
show_pic=show_pic,
@@ -905,8 +807,8 @@ def build_readable_messages(
return formatted_string
else:
# 按 read_mark 分割消息
messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
messages_before_mark = [msg for msg in copy_messages if (msg.time or 0) <= read_mark]
messages_after_mark = [msg for msg in copy_messages if (msg.time or 0) > read_mark]
# 共享的图片映射字典和计数器
pic_id_mapping = {}
@@ -916,7 +818,6 @@ def build_readable_messages(
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
messages_before_mark,
replace_bot_name,
merge_messages,
timestamp_mode,
truncate,
pic_id_mapping,
@@ -927,7 +828,6 @@ def build_readable_messages(
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
messages_after_mark,
replace_bot_name,
merge_messages,
timestamp_mode,
False,
pic_id_mapping,
@@ -960,13 +860,13 @@ def build_readable_messages(
return "".join(result_parts)
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
"""
构建匿名可读消息将不同人的名称转为唯一占位符A、B、C...bot自己用SELF。
处理 回复<aaa:bbb> 和 @<aaa:bbb> 字段将bbb映射为匿名占位符。
"""
if not messages:
print("111111111111没有消息,无法构建匿名消息")
logger.warning("没有消息,无法构建匿名消息")
return ""
person_map = {}
@@ -1017,14 +917,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
for msg in messages:
try:
platform: str = msg.get("chat_info_platform") # type: ignore
user_id = msg.get("user_id")
_timestamp = msg.get("time")
content: str = ""
if msg.get("display_message"):
content = msg.get("display_message", "")
else:
content = msg.get("processed_plain_text", "")
platform = msg.chat_info.platform
user_id = msg.user_info.user_id
content = msg.display_message or msg.processed_plain_text or ""
if "" in content:
content = content.replace("", "")
@@ -1101,3 +996,22 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回
def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]:
    """
    Convert a list of DatabaseMessages into a list of MessageAndActionModel.

    Args:
        message: The DatabaseMessages records to convert (a list, despite the
            singular parameter name).

    Returns:
        One MessageAndActionModel per input record, in input order, populated
        from the record's time, its sender fields (user_info), its text fields
        and the platform of its chat (chat_info.platform).
    """
    converted: List[MessageAndActionModel] = []
    for record in message:
        sender = record.user_info
        converted.append(
            MessageAndActionModel(
                time=record.time,
                user_id=sender.user_id,
                user_platform=sender.platform,
                user_nickname=sender.user_nickname,
                user_cardname=sender.user_cardname,
                processed_plain_text=record.processed_plain_text,
                display_message=record.display_message,
                chat_info_platform=record.chat_info.platform,
            )
        )
    return converted

View File

@@ -12,6 +12,7 @@ from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
@@ -113,6 +114,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量"""
# 每次都创建新的LLMRequest实例以避免事件循环冲突
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
try:
embedding, _ = await llm.get_embedding(text)
@@ -151,10 +153,13 @@ def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> li
if (
(db_msg.user_info.platform, db_msg.user_info.user_id) != sender
and db_msg.user_info.user_id != global_config.bot.qq_account
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname) not in who_chat_in_group
and (db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
not in who_chat_in_group
and len(who_chat_in_group) < 5
): # 排除重复排除消息发送者排除bot限制加载的关系数目
who_chat_in_group.append((db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname))
who_chat_in_group.append(
(db_msg.user_info.platform, db_msg.user_info.user_id, db_msg.user_info.user_nickname)
)
return who_chat_in_group
@@ -640,9 +645,9 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
target_info = TargetPersonInfo(
platform=platform,
user_id=user_id,
user_nickname=user_info.user_nickname, # type: ignore
user_nickname=user_info.user_nickname, # type: ignore
person_id=None,
person_name=None
person_name=None,
)
# Try to fetch person info
@@ -669,17 +674,17 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
return is_group_chat, chat_target_info
def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
def assign_message_ids(messages: List[DatabaseMessages]) -> List[DatabaseMessages]:
"""
为消息列表中的每个消息分配唯一的简短随机ID
Args:
messages: 消息列表
Returns:
包含 {'id': str, 'message': any} 格式的字典列表
List[DatabaseMessages]: 分配了唯一ID的消息列表(写入message_id属性)
"""
result = []
result: List[DatabaseMessages] = list(messages) # 复制原始消息列表
used_ids = set()
len_i = len(messages)
if len_i > 100:
@@ -688,95 +693,86 @@ def assign_message_ids(messages: List[Any]) -> List[Dict[str, Any]]:
else:
a = 1
b = 9
for i, message in enumerate(messages):
for i, _ in enumerate(result):
# 生成唯一的简短ID
while True:
# 使用索引+随机数生成简短ID
random_suffix = random.randint(a, b)
message_id = f"m{i+1}{random_suffix}"
message_id = f"m{i + 1}{random_suffix}"
if message_id not in used_ids:
used_ids.add(message_id)
break
result.append({
'id': message_id,
'message': message
})
result[i].message_id = message_id
return result
def assign_message_ids_flexible(
    messages: list,
    prefix: str = "msg",
    id_length: int = 6,
    use_timestamp: bool = False
) -> list:
    """
    Assign a unique short random ID to every message (extended variant).

    Args:
        messages: Messages to label. Items are stored as-is and never modified.
        prefix: Text placed at the start of every ID. Defaults to "msg".
        id_length: Target length of the ID, not counting the prefix. Defaults
            to 6. At least one random character is always appended, so very
            small values yield IDs slightly longer than requested.
        use_timestamp: When True, the ID body starts with the last three digits
            of the current millisecond timestamp instead of the 1-based index
            of the message. Defaults to False.

    Returns:
        A list of {'id': str, 'message': Any} dicts, one per input message and
        in the same order. IDs are unique within the returned list.
    """
    alphabet = string.ascii_lowercase + string.digits
    result = []
    used_ids = set()

    for position, message in enumerate(messages, start=1):
        # Retry until the candidate does not clash with an ID already handed out.
        while True:
            if use_timestamp:
                # Last three digits of the current time in milliseconds.
                lead = str(int(time.time() * 1000))[-3:]
            else:
                # 1-based index of the message.
                lead = str(position)

            # Always keep at least one random character. Without it, timestamp
            # IDs created in the same millisecond would be identical and this
            # loop would spin until the clock advanced.
            random_length = max(1, id_length - len(lead))
            random_chars = "".join(random.choices(alphabet, k=random_length))
            message_id = f"{prefix}{lead}{random_chars}"

            if message_id not in used_ids:
                used_ids.add(message_id)
                break

        result.append({
            'id': message_id,
            'message': message
        })

    return result
# def assign_message_ids_flexible(
# messages: list, prefix: str = "msg", id_length: int = 6, use_timestamp: bool = False
# ) -> list:
# """
# 为消息列表中的每个消息分配唯一的简短随机ID增强版
# Args:
# messages: 消息列表
# prefix: ID前缀默认为"msg"
# id_length: ID的总长度不包括前缀默认为6
# use_timestamp: 是否在ID中包含时间戳默认为False
# Returns:
# 包含 {'id': str, 'message': any} 格式的字典列表
# """
# result = []
# used_ids = set()
# for i, message in enumerate(messages):
# # 生成唯一的ID
# while True:
# if use_timestamp:
# # 使用时间戳的后几位 + 随机字符
# timestamp_suffix = str(int(time.time() * 1000))[-3:]
# remaining_length = id_length - 3
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{timestamp_suffix}{random_chars}"
# else:
# # 使用索引 + 随机字符
# index_str = str(i + 1)
# remaining_length = max(1, id_length - len(index_str))
# random_chars = "".join(random.choices(string.ascii_lowercase + string.digits, k=remaining_length))
# message_id = f"{prefix}{index_str}{random_chars}"
# if message_id not in used_ids:
# used_ids.add(message_id)
# break
# result.append({"id": message_id, "message": message})
# return result
# 使用示例:
# messages = ["Hello", "World", "Test message"]
#
#
# # 基础版本
# result1 = assign_message_ids(messages)
# # Result: the same message objects are returned, each with its message_id attribute set (e.g. 'm17', 'm23', 'm38' for three messages); no {'id': ..., 'message': ...} dicts are built
#
#
# # 增强版本 - 自定义前缀和长度
# result2 = assign_message_ids_flexible(messages, prefix="chat", id_length=8)
# # 结果: [{'id': 'chat1abc2', 'message': 'Hello'}, {'id': 'chat2def3', 'message': 'World'}, {'id': 'chat3ghi4', 'message': 'Test message'}]
#
#
# # 增强版本 - 使用时间戳
# result3 = assign_message_ids_flexible(messages, prefix="ts", use_timestamp=True)
# # 结果: [{'id': 'ts123a1b', 'message': 'Hello'}, {'id': 'ts123c2d', 'message': 'World'}, {'id': 'ts123e3f', 'message': 'Test message'}]
def parse_keywords_string(keywords_input) -> list[str]:
# sourcery skip: use-contextlib-suppress
"""
统一的关键词解析函数,支持多种格式的关键词字符串解析
支持的格式:
1. 字符串列表格式:'["utils.py", "修改", "代码", "动作"]'
2. 斜杠分隔格式:'utils.py/修改/代码/动作'
@@ -784,25 +780,25 @@ def parse_keywords_string(keywords_input) -> list[str]:
4. 空格分隔格式:'utils.py 修改 代码 动作'
5. 已经是列表的情况:["utils.py", "修改", "代码", "动作"]
6. JSON格式字符串'{"keywords": ["utils.py", "修改", "代码", "动作"]}'
Args:
keywords_input: 关键词输入,可以是字符串或列表
Returns:
list[str]: 解析后的关键词列表,去除空白项
"""
if not keywords_input:
return []
# 如果已经是列表,直接处理
if isinstance(keywords_input, list):
return [str(k).strip() for k in keywords_input if str(k).strip()]
# 转换为字符串处理
keywords_str = str(keywords_input).strip()
if not keywords_str:
return []
try:
# 尝试作为JSON对象解析支持 {"keywords": [...]} 格式)
json_data = json.loads(keywords_str)
@@ -815,7 +811,7 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [str(k).strip() for k in json_data if str(k).strip()]
except (json.JSONDecodeError, ValueError):
pass
try:
# 尝试使用 ast.literal_eval 解析支持Python字面量格式
parsed = ast.literal_eval(keywords_str)
@@ -823,15 +819,15 @@ def parse_keywords_string(keywords_input) -> list[str]:
return [str(k).strip() for k in parsed if str(k).strip()]
except (ValueError, SyntaxError):
pass
# 尝试不同的分隔符
separators = ['/', ',', ' ', '|', ';']
separators = ["/", ",", " ", "|", ";"]
for separator in separators:
if separator in keywords_str:
keywords_list = [k.strip() for k in keywords_str.split(separator) if k.strip()]
if len(keywords_list) > 1: # 确保分割有效
return keywords_list
# 如果没有分隔符,返回单个关键词
return [keywords_str] if keywords_str else []
return [keywords_str] if keywords_str else []