修复relationship_build
This commit is contained in:
@@ -127,20 +127,20 @@ class InstantMemory:
|
|||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请根据以下发言内容,判断是否需要提取记忆
|
请根据以下发言内容,判断是否需要提取记忆
|
||||||
{target}
|
{target}
|
||||||
请用json格式输出,包含以下字段:
|
请用json格式输出,包含以下字段:
|
||||||
其中,time的要求是:
|
其中,time的要求是:
|
||||||
可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD
|
可以选择具体日期时间,格式为YYYY-MM-DD HH:MM:SS,或者大致时间,格式为YYYY-MM-DD
|
||||||
可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前
|
可以选择相对时间,例如:今天,昨天,前天,5天前,1个月前
|
||||||
可以选择留空进行模糊搜索
|
可以选择留空进行模糊搜索
|
||||||
{{
|
{{
|
||||||
"need_memory": 1,
|
"need_memory": 1,
|
||||||
"keywords": "希望获取的记忆关键词,用/划分",
|
"keywords": "希望获取的记忆关键词,用/划分",
|
||||||
"time": "希望获取的记忆大致时间"
|
"time": "希望获取的记忆大致时间"
|
||||||
}}
|
}}
|
||||||
请只输出json格式,不要输出其他多余内容
|
请只输出json格式,不要输出其他多余内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
|
|||||||
@@ -586,7 +586,10 @@ async def build_readable_messages_with_list(
|
|||||||
允许通过参数控制格式化行为。
|
允许通过参数控制格式化行为。
|
||||||
"""
|
"""
|
||||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate
|
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
|
||||||
|
replace_bot_name,
|
||||||
|
timestamp_mode,
|
||||||
|
truncate,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
|
||||||
@@ -653,19 +656,7 @@ def build_readable_messages(
|
|||||||
if not messages:
|
if not messages:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
copy_messages: List[MessageAndActionModel] = [
|
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
|
||||||
MessageAndActionModel(
|
|
||||||
msg.time,
|
|
||||||
msg.user_info.user_id,
|
|
||||||
msg.user_info.platform,
|
|
||||||
msg.user_info.user_nickname,
|
|
||||||
msg.user_info.user_cardname,
|
|
||||||
msg.processed_plain_text,
|
|
||||||
msg.display_message,
|
|
||||||
msg.chat_info.platform,
|
|
||||||
)
|
|
||||||
for msg in messages
|
|
||||||
]
|
|
||||||
|
|
||||||
if show_actions and copy_messages:
|
if show_actions and copy_messages:
|
||||||
# 获取所有消息的时间范围
|
# 获取所有消息的时间范围
|
||||||
@@ -924,22 +915,3 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
return list(person_ids_set) # 将集合转换为列表返回
|
return list(person_ids_set) # 将集合转换为列表返回
|
||||||
|
|
||||||
|
|
||||||
def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]:
|
|
||||||
"""
|
|
||||||
将 DatabaseMessages 列表转换为 MessageAndActionModel 列表。
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
MessageAndActionModel(
|
|
||||||
time=msg.time,
|
|
||||||
user_id=msg.user_info.user_id,
|
|
||||||
user_platform=msg.user_info.platform,
|
|
||||||
user_nickname=msg.user_info.user_nickname,
|
|
||||||
user_cardname=msg.user_info.user_cardname,
|
|
||||||
processed_plain_text=msg.processed_plain_text,
|
|
||||||
display_message=msg.display_message,
|
|
||||||
chat_info_platform=msg.chat_info.platform,
|
|
||||||
)
|
|
||||||
for msg in message
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from . import BaseDataModel
|
from . import BaseDataModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .database_data_model import DatabaseMessages
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageAndActionModel(BaseDataModel):
|
class MessageAndActionModel(BaseDataModel):
|
||||||
|
chat_id: str = field(default_factory=str)
|
||||||
time: float = field(default_factory=float)
|
time: float = field(default_factory=float)
|
||||||
user_id: str = field(default_factory=str)
|
user_id: str = field(default_factory=str)
|
||||||
user_platform: str = field(default_factory=str)
|
user_platform: str = field(default_factory=str)
|
||||||
@@ -15,3 +20,17 @@ class MessageAndActionModel(BaseDataModel):
|
|||||||
chat_info_platform: str = field(default_factory=str)
|
chat_info_platform: str = field(default_factory=str)
|
||||||
is_action_record: bool = field(default=False)
|
is_action_record: bool = field(default=False)
|
||||||
action_name: Optional[str] = None
|
action_name: Optional[str] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
|
||||||
|
return cls(
|
||||||
|
chat_id=message.chat_id,
|
||||||
|
time=message.time,
|
||||||
|
user_id=message.user_info.user_id,
|
||||||
|
user_platform=message.user_info.platform,
|
||||||
|
user_nickname=message.user_info.user_nickname,
|
||||||
|
user_cardname=message.user_info.user_cardname,
|
||||||
|
processed_plain_text=message.processed_plain_text,
|
||||||
|
display_message=message.display_message,
|
||||||
|
chat_info_platform=message.chat_info.platform,
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import traceback
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
from typing import List, Dict, Any
|
import asyncio
|
||||||
|
from typing import List, Dict, Any, TYPE_CHECKING
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
@@ -15,7 +16,9 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
num_new_messages_since,
|
num_new_messages_since,
|
||||||
)
|
)
|
||||||
import asyncio
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("relationship_builder")
|
logger = get_logger("relationship_builder")
|
||||||
|
|
||||||
@@ -429,7 +432,7 @@ class RelationshipBuilder:
|
|||||||
if dropped_count > 0:
|
if dropped_count > 0:
|
||||||
logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
|
logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
|
||||||
|
|
||||||
processed_messages = []
|
processed_messages: List["DatabaseMessages"] = []
|
||||||
|
|
||||||
# 对筛选后的消息段进行排序,确保时间顺序
|
# 对筛选后的消息段进行排序,确保时间顺序
|
||||||
segments_to_process.sort(key=lambda x: x["start_time"])
|
segments_to_process.sort(key=lambda x: x["start_time"])
|
||||||
@@ -449,17 +452,18 @@ class RelationshipBuilder:
|
|||||||
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
|
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
|
||||||
if processed_messages:
|
if processed_messages:
|
||||||
# 创建一个特殊的间隔消息
|
# 创建一个特殊的间隔消息
|
||||||
gap_message = {
|
gap_message = DatabaseMessages(
|
||||||
"time": start_time - 0.1, # 稍微早于段开始时间
|
time=start_time - 0.1,
|
||||||
"user_id": "system",
|
user_id="system",
|
||||||
"user_platform": "system",
|
user_platform="system",
|
||||||
"user_nickname": "系统",
|
user_nickname="系统",
|
||||||
"user_cardname": "",
|
user_cardname="",
|
||||||
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
display_message=f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
||||||
"is_action_record": True,
|
is_action_record=True,
|
||||||
"chat_info_platform": segment_messages[0].chat_info.platform or "",
|
chat_info_platform=segment_messages[0].chat_info.platform or "",
|
||||||
"chat_id": chat_id,
|
chat_id=chat_id,
|
||||||
}
|
)
|
||||||
|
|
||||||
processed_messages.append(gap_message)
|
processed_messages.append(gap_message)
|
||||||
|
|
||||||
# 添加该段的所有消息
|
# 添加该段的所有消息
|
||||||
@@ -467,11 +471,11 @@ class RelationshipBuilder:
|
|||||||
|
|
||||||
if processed_messages:
|
if processed_messages:
|
||||||
# 按时间排序所有消息(包括间隔标识)
|
# 按时间排序所有消息(包括间隔标识)
|
||||||
processed_messages.sort(key=lambda x: x["time"])
|
processed_messages.sort(key=lambda x: x.time)
|
||||||
|
|
||||||
logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新")
|
logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新")
|
||||||
relationship_manager = get_relationship_manager()
|
relationship_manager = get_relationship_manager()
|
||||||
|
|
||||||
build_frequency = 0.3 * global_config.relationship.relation_frequency
|
build_frequency = 0.3 * global_config.relationship.relation_frequency
|
||||||
if random.random() < build_frequency:
|
if random.random() < build_frequency:
|
||||||
# 调用原有的更新方法
|
# 调用原有的更新方法
|
||||||
|
|||||||
@@ -3,16 +3,18 @@ import traceback
|
|||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List, TYPE_CHECKING
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from .person_info import Person
|
from .person_info import Person
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
|
|
||||||
logger = get_logger("relation")
|
logger = get_logger("relation")
|
||||||
|
|
||||||
|
|
||||||
@@ -177,7 +179,7 @@ class RelationshipManager:
|
|||||||
|
|
||||||
return person
|
return person
|
||||||
|
|
||||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
|
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List["DatabaseMessages"]):
|
||||||
"""更新用户印象
|
"""更新用户印象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -192,8 +194,6 @@ class RelationshipManager:
|
|||||||
# nickname = person.nickname
|
# nickname = person.nickname
|
||||||
know_times: float = person.know_times
|
know_times: float = person.know_times
|
||||||
|
|
||||||
user_messages = bot_engaged_messages
|
|
||||||
|
|
||||||
# 匿名化消息
|
# 匿名化消息
|
||||||
# 创建用户名称映射
|
# 创建用户名称映射
|
||||||
name_mapping = {}
|
name_mapping = {}
|
||||||
@@ -201,7 +201,7 @@ class RelationshipManager:
|
|||||||
user_count = 1
|
user_count = 1
|
||||||
|
|
||||||
# 遍历消息,构建映射
|
# 遍历消息,构建映射
|
||||||
for msg in user_messages:
|
for msg in bot_engaged_messages:
|
||||||
if msg.user_info.user_id == "system":
|
if msg.user_info.user_id == "system":
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
@@ -233,7 +233,7 @@ class RelationshipManager:
|
|||||||
current_user = chr(ord(current_user) + 1)
|
current_user = chr(ord(current_user) + 1)
|
||||||
|
|
||||||
readable_messages = build_readable_messages(
|
readable_messages = build_readable_messages(
|
||||||
messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
|
messages=bot_engaged_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for original_name, mapped_name in name_mapping.items():
|
for original_name, mapped_name in name_mapping.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user