修复relationship_build
This commit is contained in:
@@ -3,7 +3,8 @@ import traceback
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
from typing import List, Dict, Any
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, TYPE_CHECKING
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
@@ -15,7 +16,9 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
num_new_messages_since,
|
||||
)
|
||||
import asyncio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("relationship_builder")
|
||||
|
||||
@@ -429,7 +432,7 @@ class RelationshipBuilder:
|
||||
if dropped_count > 0:
|
||||
logger.debug(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
|
||||
|
||||
processed_messages = []
|
||||
processed_messages: List["DatabaseMessages"] = []
|
||||
|
||||
# 对筛选后的消息段进行排序,确保时间顺序
|
||||
segments_to_process.sort(key=lambda x: x["start_time"])
|
||||
@@ -449,17 +452,18 @@ class RelationshipBuilder:
|
||||
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
|
||||
if processed_messages:
|
||||
# 创建一个特殊的间隔消息
|
||||
gap_message = {
|
||||
"time": start_time - 0.1, # 稍微早于段开始时间
|
||||
"user_id": "system",
|
||||
"user_platform": "system",
|
||||
"user_nickname": "系统",
|
||||
"user_cardname": "",
|
||||
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
||||
"is_action_record": True,
|
||||
"chat_info_platform": segment_messages[0].chat_info.platform or "",
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
gap_message = DatabaseMessages(
|
||||
time=start_time - 0.1,
|
||||
user_id="system",
|
||||
user_platform="system",
|
||||
user_nickname="系统",
|
||||
user_cardname="",
|
||||
display_message=f"...(中间省略一些消息){start_date} 之后的消息如下...",
|
||||
is_action_record=True,
|
||||
chat_info_platform=segment_messages[0].chat_info.platform or "",
|
||||
chat_id=chat_id,
|
||||
)
|
||||
|
||||
processed_messages.append(gap_message)
|
||||
|
||||
# 添加该段的所有消息
|
||||
@@ -467,11 +471,11 @@ class RelationshipBuilder:
|
||||
|
||||
if processed_messages:
|
||||
# 按时间排序所有消息(包括间隔标识)
|
||||
processed_messages.sort(key=lambda x: x["time"])
|
||||
processed_messages.sort(key=lambda x: x.time)
|
||||
|
||||
logger.debug(f"为 {person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新")
|
||||
relationship_manager = get_relationship_manager()
|
||||
|
||||
|
||||
build_frequency = 0.3 * global_config.relationship.relation_frequency
|
||||
if random.random() < build_frequency:
|
||||
# 调用原有的更新方法
|
||||
|
||||
@@ -3,16 +3,18 @@ import traceback
|
||||
|
||||
from json_repair import repair_json
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from typing import List, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from .person_info import Person
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
logger = get_logger("relation")
|
||||
|
||||
|
||||
@@ -177,7 +179,7 @@ class RelationshipManager:
|
||||
|
||||
return person
|
||||
|
||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
|
||||
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List["DatabaseMessages"]):
|
||||
"""更新用户印象
|
||||
|
||||
Args:
|
||||
@@ -192,8 +194,6 @@ class RelationshipManager:
|
||||
# nickname = person.nickname
|
||||
know_times: float = person.know_times
|
||||
|
||||
user_messages = bot_engaged_messages
|
||||
|
||||
# 匿名化消息
|
||||
# 创建用户名称映射
|
||||
name_mapping = {}
|
||||
@@ -201,7 +201,7 @@ class RelationshipManager:
|
||||
user_count = 1
|
||||
|
||||
# 遍历消息,构建映射
|
||||
for msg in user_messages:
|
||||
for msg in bot_engaged_messages:
|
||||
if msg.user_info.user_id == "system":
|
||||
continue
|
||||
try:
|
||||
@@ -233,7 +233,7 @@ class RelationshipManager:
|
||||
current_user = chr(ord(current_user) + 1)
|
||||
|
||||
readable_messages = build_readable_messages(
|
||||
messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
|
||||
messages=bot_engaged_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
|
||||
)
|
||||
|
||||
for original_name, mapped_name in name_mapping.items():
|
||||
|
||||
Reference in New Issue
Block a user