diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index 48122f88..f8e91b5c 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -127,20 +127,20 @@ class InstantMemory: from json_repair import repair_json prompt = f""" - 请根据以下发言内容,判断是否需要提取记忆 - {target} - 请用json格式输出,包含以下字段: - 其中,time的要求是: - 可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD - 可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前 - 可以选择留空进行模糊搜索 - {{ - "need_memory": 1, - "keywords": "希望获取的记忆关键词,用/划分", - "time": "希望获取的记忆大致时间" - }} - 请只输出json格式,不要输出其他多余内容 - """ +请根据以下发言内容,判断是否需要提取记忆 +{target} +请用json格式输出,包含以下字段: +其中,time的要求是: +可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD +可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前 +可以选择留空进行模糊搜索 +{{ + "need_memory": 1, + "keywords": "希望获取的记忆关键词,用/划分", + "time": "希望获取的记忆大致时间" +}} +请只输出json格式,不要输出其他多余内容 +""" try: response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) if global_config.debug.show_prompt: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index a79088da..51edd045 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -586,7 +586,10 @@ async def build_readable_messages_with_list( 允许通过参数控制格式化行为。 """ formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( - convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate + [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages], + replace_bot_name, + timestamp_mode, + truncate, ) if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): @@ -653,19 +656,7 @@ def build_readable_messages( if not messages: return "" - copy_messages: List[MessageAndActionModel] = [ - MessageAndActionModel( - msg.time, - msg.user_info.user_id, - msg.user_info.platform, - msg.user_info.user_nickname, - msg.user_info.user_cardname, - msg.processed_plain_text, - msg.display_message, - msg.chat_info.platform, - ) - for msg in messages - ] + copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages] if show_actions and copy_messages: # 获取所有消息的时间范围 @@ -924,22 +915,3 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set.add(person_id) return list(person_ids_set) # 将集合转换为列表返回 - - -def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]: - """ - 将 DatabaseMessages 列表转换为 MessageAndActionModel 列表。 - """ - return [ - MessageAndActionModel( - time=msg.time, - user_id=msg.user_info.user_id, - user_platform=msg.user_info.platform, - user_nickname=msg.user_info.user_nickname, - user_cardname=msg.user_info.user_cardname, - processed_plain_text=msg.processed_plain_text, - display_message=msg.display_message, - chat_info_platform=msg.chat_info.platform, - ) - for msg in message - ] diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py index 0fa87ba0..8e0b7786 100644 --- a/src/common/data_models/message_data_model.py +++ b/src/common/data_models/message_data_model.py @@ -1,10 +1,15 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING from dataclasses import dataclass, field from . import BaseDataModel +if TYPE_CHECKING: + from .database_data_model import DatabaseMessages + + @dataclass class MessageAndActionModel(BaseDataModel): + chat_id: str = field(default_factory=str) time: float = field(default_factory=float) user_id: str = field(default_factory=str) user_platform: str = field(default_factory=str) @@ -15,3 +20,17 @@ class MessageAndActionModel(BaseDataModel): chat_info_platform: str = field(default_factory=str) is_action_record: bool = field(default=False) action_name: Optional[str] = None + + @classmethod + def from_DatabaseMessages(cls, message: "DatabaseMessages"): + return cls( + chat_id=message.chat_id, + time=message.time, + user_id=message.user_info.user_id, + user_platform=message.user_info.platform, + user_nickname=message.user_info.user_nickname, + user_cardname=message.user_info.user_cardname, + processed_plain_text=message.processed_plain_text, + display_message=message.display_message, + chat_info_platform=message.chat_info.platform, + ) diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py index 9bf484f0..7d2591ff 100644 --- a/src/person_info/relationship_builder.py +++ b/src/person_info/relationship_builder.py @@ -3,7 +3,8 @@ import traceback import os import pickle import random -from typing import List, Dict, Any +import asyncio +from typing import List, Dict, Any, TYPE_CHECKING from src.config.config import global_config from src.common.logger import get_logger from src.person_info.relationship_manager import get_relationship_manager @@ -15,7 +16,9 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, num_new_messages_since, ) -import asyncio + +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages logger = get_logger("relationship_builder") @@ -429,7 +432,7 @@ class RelationshipBuilder: if dropped_count > 0: logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段") - processed_messages = [] + processed_messages: List["DatabaseMessages"] = [] # 对筛选后的消息段进行排序,确保时间顺序 segments_to_process.sort(key=lambda x: x["start_time"]) @@ -449,17 +452,18 @@ class RelationshipBuilder: # 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识 if processed_messages: # 创建一个特殊的间隔消息 - gap_message = { - "time": start_time - 0.1, # 稍微早于段开始时间 - "user_id": "system", - "user_platform": "system", - "user_nickname": "系统", - "user_cardname": "", - "display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...", - "is_action_record": True, - "chat_info_platform": segment_messages[0].chat_info.platform or "", - "chat_id": chat_id, - } + gap_message = DatabaseMessages( + time=start_time - 0.1, + user_id="system", + user_platform="system", + user_nickname="系统", + user_cardname="", + display_message=f"...(中间省略一些消息){start_date} 之后的消息如下...", + is_action_record=True, + chat_info_platform=segment_messages[0].chat_info.platform or "", + chat_id=chat_id, + ) + processed_messages.append(gap_message) # 添加该段的所有消息 @@ -467,11 +471,11 @@ class RelationshipBuilder: if processed_messages: # 按时间排序所有消息(包括间隔标识) - processed_messages.sort(key=lambda x: x["time"]) + processed_messages.sort(key=lambda x: x.time) logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新") relationship_manager = get_relationship_manager() - + build_frequency = 0.3 * global_config.relationship.relation_frequency if random.random() < build_frequency: # 调用原有的更新方法 diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 916162a8..0b5da6d5 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -3,16 +3,18 @@ import traceback from json_repair import repair_json from datetime import datetime -from typing import List +from typing import List, TYPE_CHECKING from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from .person_info import Person +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + logger = get_logger("relation") @@ -177,7 +179,7 @@ class RelationshipManager: return person - async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]): + async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List["DatabaseMessages"]): """更新用户印象 Args: @@ -192,8 +194,6 @@ class RelationshipManager: # nickname = person.nickname know_times: float = person.know_times - user_messages = bot_engaged_messages - # 匿名化消息 # 创建用户名称映射 name_mapping = {} @@ -201,7 +201,7 @@ class RelationshipManager: user_count = 1 # 遍历消息,构建映射 - for msg in user_messages: + for msg in bot_engaged_messages: if msg.user_info.user_id == "system": continue try: @@ -233,7 +233,7 @@ class RelationshipManager: current_user = chr(ord(current_user) + 1) readable_messages = build_readable_messages( - messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True + messages=bot_engaged_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True ) for original_name, mapped_name in name_mapping.items():