PFC 兼容新版数据模型
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
import time
|
||||
from typing import Tuple, Optional # 增加了 Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from .observation_info import ObservationInfo, dict_to_database_message
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
|
||||
logger = get_logger("pfc_action_planner")
|
||||
@@ -271,19 +273,17 @@ class ActionPlanner:
|
||||
# 获取聊天历史记录 (chat_history_text)
|
||||
try:
|
||||
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if not chat_history_text:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
chat_history_text = observation_info.chat_history_str or "还没有聊天记录。\n"
|
||||
else:
|
||||
chat_history_text = "还没有聊天记录。\n"
|
||||
|
||||
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
|
||||
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
# Convert dict format to DatabaseMessages objects
|
||||
db_messages = [dict_to_database_message(m) for m in new_messages_list]
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import time
|
||||
import asyncio
|
||||
import datetime
|
||||
import time
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.services.message_service import build_readable_messages, get_messages_before_time_in_chat
|
||||
|
||||
# from .message_storage import MongoDBMessageStorage
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
|
||||
# from src.config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
@@ -83,7 +84,7 @@ class Conversation:
|
||||
raise
|
||||
try:
|
||||
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat( #
|
||||
initial_messages = get_messages_before_time_in_chat(
|
||||
chat_id=self.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=30, # 加载最近30条作为初始上下文,可以调整
|
||||
@@ -95,22 +96,23 @@ class Conversation:
|
||||
read_mark=0.0,
|
||||
)
|
||||
if initial_messages:
|
||||
# 将 DatabaseMessages 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
|
||||
# 将 SessionMessage 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
|
||||
initial_messages_dict: list[dict] = []
|
||||
for msg in initial_messages:
|
||||
user_info = msg.message_info.user_info
|
||||
msg_dict = {
|
||||
"message_id": msg.message_id,
|
||||
"time": msg.time,
|
||||
"chat_id": msg.chat_id,
|
||||
"time": msg.timestamp.timestamp(),
|
||||
"chat_id": msg.session_id,
|
||||
"processed_plain_text": msg.processed_plain_text,
|
||||
"display_message": msg.display_message,
|
||||
"is_mentioned": msg.is_mentioned,
|
||||
"is_command": msg.is_command,
|
||||
"user_info": {
|
||||
"user_id": msg.user_info.user_id if msg.user_info else "",
|
||||
"user_nickname": msg.user_info.user_nickname if msg.user_info else "",
|
||||
"user_cardname": msg.user_info.user_cardname if msg.user_info else None,
|
||||
"platform": msg.user_info.platform if msg.user_info else "",
|
||||
"user_id": user_info.user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"user_cardname": user_info.user_cardname,
|
||||
"platform": msg.platform,
|
||||
},
|
||||
}
|
||||
initial_messages_dict.append(msg_dict)
|
||||
|
||||
@@ -1,40 +1,52 @@
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from maim_message import UserInfo
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.mai_message_data_model import MessageInfo, UserInfo as MaiUserInfo
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .chat_observer import ChatObserver
|
||||
from .chat_states import NotificationHandler, NotificationType, Notification
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import traceback # 导入 traceback 用于调试
|
||||
|
||||
logger = get_logger("observation_info")
|
||||
|
||||
|
||||
def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
|
||||
"""Convert PFC dict format to DatabaseMessages object
|
||||
def dict_to_session_message(msg_dict: Dict[str, Any]) -> SessionMessage:
|
||||
"""Convert PFC dict format to SessionMessage object.
|
||||
|
||||
Args:
|
||||
msg_dict: Message in PFC dict format with nested user_info
|
||||
|
||||
Returns:
|
||||
DatabaseMessages object compatible with build_readable_messages()
|
||||
SessionMessage object compatible with build_readable_messages()
|
||||
"""
|
||||
user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})
|
||||
|
||||
return DatabaseMessages(
|
||||
timestamp = msg_dict.get("time", 0.0)
|
||||
platform = user_info_dict.get("platform", "")
|
||||
message = SessionMessage(
|
||||
message_id=msg_dict.get("message_id", ""),
|
||||
time=msg_dict.get("time", 0.0),
|
||||
chat_id=msg_dict.get("chat_id", ""),
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
display_message=msg_dict.get("display_message", ""),
|
||||
is_mentioned=msg_dict.get("is_mentioned", False),
|
||||
is_command=msg_dict.get("is_command", False),
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
user_platform=user_info_dict.get("platform", ""),
|
||||
timestamp=datetime.fromtimestamp(timestamp),
|
||||
platform=platform,
|
||||
)
|
||||
message.message_info = MessageInfo(
|
||||
user_info=MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
)
|
||||
message.session_id = msg_dict.get("chat_id", "")
|
||||
message.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
message.display_message = msg_dict.get("display_message", "")
|
||||
message.is_mentioned = msg_dict.get("is_mentioned", False)
|
||||
message.is_command = msg_dict.get("is_command", False)
|
||||
message.initialized = True
|
||||
return message
|
||||
|
||||
|
||||
class ObservationInfoHandler(NotificationHandler):
|
||||
@@ -393,10 +405,10 @@ class ObservationInfo:
|
||||
# 更新历史记录字符串 (只使用最近一部分生成,例如20条)
|
||||
history_slice_for_str = self.chat_history[-20:]
|
||||
try:
|
||||
# Convert dict format to DatabaseMessages objects
|
||||
db_messages = [dict_to_database_message(m) for m in history_slice_for_str]
|
||||
# Convert dict format to SessionMessage objects.
|
||||
session_messages = [dict_to_session_message(m) for m in history_slice_for_str]
|
||||
self.chat_history_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0, # read_mark 可能需要根据逻辑调整
|
||||
|
||||
@@ -6,8 +6,9 @@ import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .pfc_utils import get_items_from_json
|
||||
from .conversation_info import ConversationInfo
|
||||
from .observation_info import ObservationInfo, dict_to_database_message
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -103,9 +104,9 @@ class GoalAnalyzer:
|
||||
|
||||
if observation_info.new_messages_count > 0:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
db_messages = [dict_to_database_message(m) for m in new_messages_list]
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
@@ -238,9 +239,9 @@ class GoalAnalyzer:
|
||||
|
||||
async def analyze_conversation(self, goal, reasoning):
|
||||
messages = self.chat_observer.get_cached_messages()
|
||||
db_messages = [dict_to_database_message(m) for m in messages]
|
||||
session_messages = [dict_to_session_message(m) for m in messages]
|
||||
chat_history_text = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
|
||||
@@ -5,9 +5,10 @@ from src.config.config import global_config, model_config
|
||||
import random
|
||||
from .chat_observer import ChatObserver
|
||||
from .reply_checker import ReplyChecker
|
||||
from .observation_info import ObservationInfo, dict_to_database_message
|
||||
from src.services.message_service import build_readable_messages
|
||||
|
||||
from .observation_info import ObservationInfo, dict_to_session_message
|
||||
from .conversation_info import ConversationInfo
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
|
||||
logger = get_logger("reply_generator")
|
||||
|
||||
@@ -163,7 +164,7 @@ class ReplyGenerator:
|
||||
knowledge = knowledge_item.get("knowledge", "无知识内容")
|
||||
source = knowledge_item.get("source", "未知来源")
|
||||
# 只取知识内容的前 2000 个字
|
||||
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_snippet = f"{knowledge[:2000]}..." if len(knowledge) > 2000 else knowledge
|
||||
knowledge_info_str += (
|
||||
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
|
||||
)
|
||||
@@ -186,9 +187,9 @@ class ReplyGenerator:
|
||||
chat_history_text = observation_info.chat_history_str
|
||||
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
|
||||
new_messages_list = observation_info.unprocessed_messages
|
||||
db_messages = [dict_to_database_message(m) for m in new_messages_list]
|
||||
session_messages = [dict_to_session_message(m) for m in new_messages_list]
|
||||
new_messages_str = build_readable_messages(
|
||||
db_messages,
|
||||
session_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
|
||||
@@ -14,12 +14,11 @@ from typing import Any, Dict, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.config.config import model_config, global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.services import message_service as message_api
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.person_info.person_info import Person
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
@@ -34,7 +33,7 @@ HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorize
|
||||
class MessageBatch:
|
||||
"""消息批次(用于触发话题检查的原始消息累积)"""
|
||||
|
||||
messages: List[DatabaseMessages]
|
||||
messages: List[SessionMessage]
|
||||
start_time: float
|
||||
end_time: float
|
||||
|
||||
@@ -101,7 +100,7 @@ class ChatHistorySummarizer:
|
||||
def _get_chat_display_name(self) -> str:
|
||||
"""获取聊天显示名称"""
|
||||
try:
|
||||
chat_name = _chat_manager.get_session_name(self.chat_id)
|
||||
chat_name = _chat_manager.get_session_name(self.session_id)
|
||||
if chat_name:
|
||||
return chat_name
|
||||
# 如果获取失败,使用简化的chat_id显示
|
||||
@@ -268,7 +267,7 @@ class ChatHistorySummarizer:
|
||||
# 创建新批次
|
||||
self.current_batch = MessageBatch(
|
||||
messages=new_messages,
|
||||
start_time=new_messages[0].time if new_messages else current_time,
|
||||
start_time=new_messages[0].timestamp.timestamp() if new_messages else current_time,
|
||||
end_time=current_time,
|
||||
)
|
||||
logger.debug(f"{self.log_prefix} 新建聊天检查批次: {len(new_messages)} 条消息")
|
||||
@@ -340,7 +339,7 @@ class ChatHistorySummarizer:
|
||||
self.last_topic_check_time = current_time
|
||||
self._persist_topic_cache()
|
||||
|
||||
async def _run_topic_check_and_update_cache(self, messages: List[DatabaseMessages]):
|
||||
async def _run_topic_check_and_update_cache(self, messages: List[SessionMessage]):
|
||||
"""
|
||||
执行一次“话题检查”:
|
||||
1. 首先确认这段消息里是否有 Bot 发言,没有则直接丢弃本次批次;
|
||||
@@ -355,8 +354,8 @@ class ChatHistorySummarizer:
|
||||
if not messages:
|
||||
return
|
||||
|
||||
start_time = messages[0].time
|
||||
end_time = messages[-1].time
|
||||
start_time = messages[0].timestamp.timestamp()
|
||||
end_time = messages[-1].timestamp.timestamp()
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} 开始话题检查 | 消息数: {len(messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
|
||||
@@ -365,13 +364,9 @@ class ChatHistorySummarizer:
|
||||
# 1. 检查当前批次内是否有 bot 发言(只检查当前批次,不往前推)
|
||||
# 原因:我们要记录的是 bot 参与过的对话片段,如果当前批次内 bot 没有发言,
|
||||
# 说明 bot 没有参与这段对话,不应该记录
|
||||
has_bot_message = False
|
||||
|
||||
for msg in messages:
|
||||
# 使用统一的 is_bot_self 函数判断是否是机器人自己(支持多平台,包括 WebUI)
|
||||
if is_bot_self(msg.user_info.platform, msg.user_info.user_id):
|
||||
has_bot_message = True
|
||||
break
|
||||
has_bot_message = any(
|
||||
is_bot_self(msg.platform, msg.message_info.user_info.user_id) for msg in messages
|
||||
)
|
||||
|
||||
if not has_bot_message:
|
||||
logger.info(
|
||||
@@ -575,7 +570,7 @@ class ChatHistorySummarizer:
|
||||
return topic_mapping
|
||||
|
||||
def _build_numbered_messages_for_llm(
|
||||
self, messages: List[DatabaseMessages]
|
||||
self, messages: List[SessionMessage]
|
||||
) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]:
|
||||
"""
|
||||
将消息转为带编号的字符串,供 LLM 选择使用。
|
||||
@@ -594,7 +589,7 @@ class ChatHistorySummarizer:
|
||||
for idx, msg in enumerate(messages, start=1):
|
||||
# 使用 build_readable_messages 生成可读文本
|
||||
try:
|
||||
text = build_readable_messages(
|
||||
text = message_api.build_readable_messages(
|
||||
messages=[msg],
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
@@ -609,12 +604,8 @@ class ChatHistorySummarizer:
|
||||
# 获取发言人昵称
|
||||
participants: Set[str] = set()
|
||||
try:
|
||||
platform = (
|
||||
getattr(msg, "user_platform", None)
|
||||
or (msg.user_info.platform if msg.user_info else None)
|
||||
or msg.chat_info.platform
|
||||
)
|
||||
user_id = msg.user_info.user_id if msg.user_info else None
|
||||
platform = msg.platform
|
||||
user_id = msg.message_info.user_info.user_id
|
||||
if platform and user_id:
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
if person.person_name:
|
||||
|
||||
Reference in New Issue
Block a user