PFC 兼容新版数据模型
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from .observation_info import ObservationInfo, dict_to_database_message
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
@@ -271,19 +273,17 @@ class ActionPlanner:
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
chat_history_text = observation_info.chat_history_str or "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
# Convert dict format to DatabaseMessages objects
|
||||
db_messages = [dict_to_database_message(m) for m in new_messages_list]
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# from src.config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
@@ -83,7 +84,7 @@ class Conversation:
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat( #
|
||||
initial_messages = get_messages_before_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
@@ -95,22 +96,23 @@ class Conversation:
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将 DatabaseMessages 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
|
||||
# 将 SessionMessage 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
|
||||
initial_messages_dict: list[dict] = []
|
||||
for msg in initial_messages:
|
||||
user_info = msg.message_info.user_info
|
||||
msg_dict = {
|
||||
"message_id": msg.message_id,
|
||||
"time": msg.time,
|
||||
"chat_id": msg.chat_id,
|
||||
"time": msg.timestamp.timestamp(),
|
||||
"chat_id": msg.session_id,
|
||||
"processed_plain_text": msg.processed_plain_text,
|
||||
"display_message": msg.display_message,
|
||||
"is_mentioned": msg.is_mentioned,
|
||||
"is_command": msg.is_command,
|
||||
"user_info": {
|
||||
"user_id": msg.user_info.user_id if msg.user_info else "",
|
||||
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
||||
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
||||
"platform": msg.user_info.platform if msg.user_info else "",
|
||||
"user_id": user_info.user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"user_cardname": user_info.user_cardname,
|
||||
"platform": msg.platform,
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
@@ -1,40 +1,52 @@
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo as MaiUserInfo
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
|
||||
|
||||
def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
|
||||
"""Convert PFC dict format to DatabaseMessages object
|
||||
def dict_to_session_message(msg_dict: Dict[str, Any]) -> SessionMessage:
|
||||
"""Convert PFC dict format to SessionMessage object.
|
||||
|
||||
Args:
|
||||
msg_dict: Message in PFC dict format with nested user_info
|
||||
|
||||
Returns:
|
||||
DatabaseMessages object compatible with build_readable_messages()
|
||||
SessionMessage object compatible with build_readable_messages()
|
||||
"""
|
||||
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
|
||||
|
||||
return DatabaseMessages(
|
||||
timestamp = msg_dict.get("time", 0.0)
|
||||
platform = user_info_dict.get("platform", "")
|
||||
message = SessionMessage(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
time=msg_dict.get("time", 0.0),
|
||||
chat_id=msg_dict.get("chat_id", ""),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
display_message=msg_dict.get("display_message", ""),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False),
|
||||
is_command=msg_dict.get("is_command", False),
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
user_platform=user_info_dict.get("platform", ""),
|
||||
timestamp=datetime.fromtimestamp(timestamp),
|
||||
platform=platform,
|
||||
)
|
||||
message.message_info = MessageInfo(
|
||||
user_info=MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
)
|
||||
message.session_id = msg_dict.get("chat_id", "")
|
||||
message.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
message.display_message = msg_dict.get("display_message", "")
|
||||
message.is_mentioned = msg_dict.get("is_mentioned", False)
|
||||
message.is_command = msg_dict.get("is_command", False)
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
@@ -393,10 +405,10 @@ class ObservationInfo:
|
||||
# 更新历史记录字符串 (只使用最近一部分生成,例如20条)
|
||||
history_slice_for_str = self.chat_history[-20:]
|
||||
try:
|
||||
# Convert dict format to DatabaseMessages objects
|
||||
db_messages = [dict_to_database_message(m) for m in history_slice_for_str]
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in history_slice_for_str]
|
||||
self.chat_history_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0, # read_mark 可能需要根据逻辑调整
|
||||
|
||||
@@ -6,8 +6,9 @@ import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from .conversation_info import ConversationInfo
|
||||
from .observation_info import ObservationInfo, dict_to_database_message
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -103,9 +104,9 @@ class GoalAnalyzer:
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
db_messages = [dict_to_database_message(m) for m in new_messages_list]
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
@@ -238,9 +239,9 @@ class GoalAnalyzer:
|
||||
|
||||
async def analyze_conversation(self, goal, reasoning):
|
||||
messages = self.chat_observer.get_cached_messages()
|
||||
db_messages = [dict_to_database_message(m) for m in messages]
|
||||
session_messages = [dict_to_session_message(m) for m in messages]
|
||||
chat_history_text = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
|
||||
@@ -5,9 +5,10 @@ from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from .observation_info import ObservationInfo, dict_to_database_message
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_logger("reply_generator")
|
||||
|
||||
@@ -163,7 +164,7 @@ class ReplyGenerator:
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_snippet = f"{knowledge[:2000]}..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
|
||||
)
|
||||
@@ -186,9 +187,9 @@ class ReplyGenerator:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
db_messages = [dict_to_database_message(m) for m in new_messages_list]
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
|
||||
Reference in New Issue
Block a user