import time import asyncio from typing import List, Any, Optional from collections import OrderedDict from dataclasses import dataclass from src.common.logger import get_logger from src.config.config import global_config from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.bw_learner.expression_learner import expression_learner_manager from src.bw_learner.jargon_miner import miner_manager from src.person_info.person_info import Person logger = get_logger("bw_learner") @dataclass class PersonInfo: """参与聊天的人物信息""" user_id: str user_platform: str user_nickname: str user_cardname: Optional[str] person_name: str last_seen_time: float # 最后发言时间 def get_unique_key(self) -> str: """获取唯一标识(用于去重)""" return f"{self.user_platform}:{self.user_id}" class MessageRecorder: """ 统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner """ def __init__(self, chat_id: str) -> None: self.chat_id = chat_id self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id # 维护每个chat的上次提取时间 self.last_extraction_time: float = time.time() # 提取锁,防止并发执行 self._extraction_lock = asyncio.Lock() # 维护参与该chat_id的人物列表(最多30个,使用OrderedDict保持插入顺序) # key: f"{platform}:{user_id}", value: PersonInfo self._person_list: OrderedDict[str, PersonInfo] = OrderedDict() self._max_person_count = 30 # 获取 expression 和 jargon 的配置参数 self._init_parameters() # 获取 expression_learner 和 jargon_miner 实例 self.expression_learner = expression_learner_manager.get_expression_learner(chat_id) self.jargon_miner = miner_manager.get_miner(chat_id) def _init_parameters(self) -> None: """初始化提取参数""" # 获取 expression 配置 _, self.enable_expression_learning, self.enable_jargon_learning = ( global_config.expression.get_expression_config_for_chat(self.chat_id) ) self.min_messages_for_extraction = 30 self.min_extraction_interval = 60 logger.debug( f"MessageRecorder 初始化: chat_id={self.chat_id}, " f"min_messages={self.min_messages_for_extraction}, " f"min_interval={self.min_extraction_interval}" ) def should_trigger_extraction(self) -> bool: """ 检查是否应该触发消息提取 Returns: bool: 是否应该触发提取 """ # 检查时间间隔 time_diff = time.time() - self.last_extraction_time if time_diff < self.min_extraction_interval: return False # 检查消息数量 recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_extraction_time, timestamp_end=time.time(), ) if not recent_messages or len(recent_messages) < self.min_messages_for_extraction: return False return True async def extract_and_distribute(self) -> None: """ 提取消息并分发给 expression_learner 和 jargon_miner """ # 使用异步锁防止并发执行 async with self._extraction_lock: # 在锁内检查,避免并发触发 if not self.should_trigger_extraction(): return # 检查 chat_stream 是否存在 if not self.chat_stream: return # 记录本次提取的时间窗口,避免重复提取 extraction_start_time = self.last_extraction_time extraction_end_time = time.time() # 立即更新提取时间,防止并发触发 self.last_extraction_time = extraction_end_time try: logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发") # 拉取提取窗口内的消息 messages = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=extraction_start_time, timestamp_end=extraction_end_time, ) if not messages: logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取") return # 按时间排序,确保顺序一致 messages = sorted(messages, key=lambda msg: msg.time or 0) # 更新参与聊天的人物列表 self._update_person_list(messages) logger.info(f"聊天流 {self.chat_name} 的人物列表: {self._person_list}") logger.info( f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息," f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" ) # 分别触发 expression_learner 和 jargon_miner 的处理 # 传递提取的消息,避免它们重复获取 # 触发 expression 学习(如果启用) if self.enable_expression_learning: asyncio.create_task( self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages) ) # 触发 jargon 提取(如果启用),传递消息 # if self.enable_jargon_learning: # asyncio.create_task( # self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) # ) except Exception as e: logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") import traceback traceback.print_exc() # 即使失败也保持时间戳更新,避免频繁重试 async def _trigger_expression_learning( self, timestamp_start: float, timestamp_end: float, messages: List[Any] ) -> None: """ 触发 expression 学习,使用指定的消息列表 Args: timestamp_start: 开始时间戳 timestamp_end: 结束时间戳 messages: 消息列表 """ try: # 传递消息和过滤函数给 ExpressionLearner learnt_style = await self.expression_learner.learn_and_store( messages=messages, person_name_filter=self.contains_person_name ) if learnt_style: logger.info(f"聊天流 {self.chat_name} 表达学习完成") else: logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果") except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}") import traceback traceback.print_exc() async def _trigger_jargon_extraction( self, timestamp_start: float, timestamp_end: float, messages: List[Any] ) -> None: """ 触发 jargon 提取,使用指定的消息列表 Args: timestamp_start: 开始时间戳 timestamp_end: 结束时间戳 messages: 消息列表 """ try: # 传递消息和过滤函数给 JargonMiner await self.jargon_miner.run_once( messages=messages, person_name_filter=self.contains_person_name ) except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}") import traceback traceback.print_exc() def _update_person_list(self, messages: List[Any]) -> None: """ 从消息中提取人物信息并更新人物列表 Args: messages: 消息列表 """ for msg in messages: # 获取消息发送者信息 # 消息对象可能是 DatabaseMessages,它有 user_info 属性 if hasattr(msg, 'user_info'): # DatabaseMessages 类型 user_info = msg.user_info user_id = getattr(user_info, 'user_id', None) or '' user_platform = getattr(user_info, 'platform', None) or '' user_nickname = getattr(user_info, 'user_nickname', None) or '' user_cardname = getattr(user_info, 'user_cardname', None) else: # 直接属性访问 user_id = getattr(msg, 'user_id', None) or '' user_platform = getattr(msg, 'user_platform', None) or '' user_nickname = getattr(msg, 'user_nickname', None) or '' user_cardname = getattr(msg, 'user_cardname', None) msg_time = getattr(msg, 'time', time.time()) # 检查必要信息 if not user_id or not user_platform: continue # 获取 person_name try: person = Person(platform=user_platform, user_id=str(user_id)) person_name = person.person_name or user_nickname or (user_cardname if user_cardname else "未知用户") except Exception as e: logger.info(f"获取person_name失败: {e}, 使用nickname") person_name = user_nickname or (user_cardname if user_cardname else "未知用户") # 生成唯一key unique_key = f"{user_platform}:{user_id}" # 如果已存在,更新最后发言时间 if unique_key in self._person_list: self._person_list[unique_key].last_seen_time = msg_time # 移动到末尾(表示最近活跃) self._person_list.move_to_end(unique_key) else: # 如果超过最大数量,移除最早的(最前面的) if len(self._person_list) >= self._max_person_count: oldest_key = next(iter(self._person_list)) del self._person_list[oldest_key] logger.info(f"人物列表已满,移除最早的人物: {oldest_key}") # 添加新人物 person_info = PersonInfo( user_id=str(user_id), user_platform=user_platform, user_nickname=user_nickname or "", user_cardname=user_cardname, person_name=person_name, last_seen_time=msg_time ) self._person_list[unique_key] = person_info logger.info(f"添加新人物到列表: {unique_key}, person_name={person_name}") def contains_person_name(self, content: str) -> bool: """ 检查内容是否包含任何参与聊天的人物的名称或昵称 Args: content: 要检查的内容 Returns: bool: 如果包含任何人物名称或昵称,返回True """ if not content or not self._person_list: return False content_lower = content.strip().lower() if not content_lower: return False # 检查所有人物 for person_info in self._person_list.values(): # 检查 person_name if person_info.person_name: person_name_lower = person_info.person_name.strip().lower() if person_name_lower and person_name_lower in content_lower: logger.debug(f"内容包含person_name: {person_info.person_name} in {content}") return True # 检查 user_nickname if person_info.user_nickname: nickname_lower = person_info.user_nickname.strip().lower() if nickname_lower and nickname_lower in content_lower: logger.debug(f"内容包含nickname: {person_info.user_nickname} in {content}") return True # 检查 user_cardname(群昵称) if person_info.user_cardname: cardname_lower = person_info.user_cardname.strip().lower() if cardname_lower and cardname_lower in content_lower: logger.debug(f"内容包含cardname: {person_info.user_cardname} in {content}") return True return False class MessageRecorderManager: """MessageRecorder 管理器""" def __init__(self) -> None: self._recorders: dict[str, MessageRecorder] = {} def get_recorder(self, chat_id: str) -> MessageRecorder: """获取或创建指定 chat_id 的 MessageRecorder""" if chat_id not in self._recorders: self._recorders[chat_id] = MessageRecorder(chat_id) return self._recorders[chat_id] # 全局管理器实例 recorder_manager = MessageRecorderManager() async def extract_and_distribute_messages(chat_id: str) -> None: """ 统一的消息提取和分发入口函数 Args: chat_id: 聊天流ID """ recorder = recorder_manager.get_recorder(chat_id) await recorder.extract_and_distribute()