diff --git a/src/bw_learner/expression_learner.py b/src/bw_learner/expression_learner.py index 37ac5a68..ba94d231 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/bw_learner/expression_learner.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple import asyncio import difflib import json +import re from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config @@ -14,12 +15,13 @@ from src.common.database.database_model import Expression from src.common.database.database import get_db_session from src.common.data_models.expression_data_model import MaiExpression from src.common.utils.utils_message import MessageUtils +from src.common.utils.system_utils import is_bot_self from .expression_utils import check_expression_suitability, parse_expression_response if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage - from .jargon_miner import JargonMiner + from .jargon_miner import JargonMiner, JargonEntry logger = get_logger("expressor") @@ -53,49 +55,278 @@ class ExpressionLearner: if not self._messages_cache: logger.debug("没有消息可供学习,跳过学习过程") return + + # 构建可读消息 readable_message, _, _ = await MessageUtils.build_readable_message( self._messages_cache, anonymize=True, show_lineno=True, extract_pictures=True, + replace_bot_name=True, + target_bot_name="SELF", ) - self._messages_cache.clear() # 学习后清空缓存 + + # 准备提示词 prompt_template = prompt_manager.get_prompt("learn_style") prompt_template.add_context("bot_name", global_config.bot.nickname) prompt_template.add_context("chat_str", readable_message) - prompt = await prompt_manager.render_prompt(prompt_template) + # 调用 LLM 学习表达方式 try: response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3) except Exception as e: - logger.error(f"学习表达方式失败,模型生成出错: {e}") - return None + logger.error(f"学习表达方式失败,模型生成出错:{e}") + return # 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号) expressions: List[Tuple[str, str, str]] jargon_entries: List[Tuple[str, str]] # (content, source_id) expressions, jargon_entries = parse_expression_response(response) - # TODO: 完成学习 - # TODO: 从缓存检查 jargon 是否出现在 message 中 - # TODO: 检查表达方式/黑话数量 - # TODO: 处理黑话条目 - # TODO: 过滤 - # TODO: 存储 - + # 从缓存中检查 jargon 是否出现在 messages 中 + if cached_jargon_entries := self._check_cached_jargons_in_messages(jargon_miner): + # 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过) + existing_contents = {content for content, _ in jargon_entries} + for content, source_id in cached_jargon_entries: + if content not in existing_contents: + jargon_entries.append((content, source_id)) + existing_contents.add(content) + logger.info(f"从缓存中检查到黑话:{content}") + + # 检查表达方式数量,如果超过 20 个则放弃本次表达学习 + if len(expressions) > 20: + logger.info(f"表达方式提取数量超过 20 个(实际{len(expressions)}个),放弃本次表达学习") + expressions = [] + + # 检查黑话数量,如果超过 30 个则放弃本次黑话学习 + if len(jargon_entries) > 30: + logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习") + jargon_entries = [] + + # 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话) + # TODO: 检测是否开启了 + if jargon_entries: + await self._process_jargon_entries(jargon_entries, jargon_miner) + + # 如果没有表达方式,直接返回 + if not expressions: + logger.info("解析后没有可用的表达方式") + return + + logger.info(f"学习的 expressions: {expressions}") + logger.info(f"学习的 jargon_entries: {jargon_entries}") + + # 过滤表达方式,根据 source_id 溯源并应用各种过滤规则 + learnt_expressions = self._filter_expressions(expressions) + + if not learnt_expressions: + logger.info("没有学习到表达风格") + return + + # 展示学到的表达方式 + learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions) + logger.info(f"在 {self.session_id} 学习到表达风格:\n{learnt_expressions_str}") + + # 存储到数据库 Expression 表 + for situation, style in learnt_expressions: + await self._upsert_expression_to_db(situation, style) + # ====== 黑话相关 ====== - def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None): + def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None) -> List[Tuple[str, str]]: + """ + 检查缓存中的 jargon 是否出现在 messages 中 + + Args: + jargon_miner: JargonMiner 实例,用于获取缓存的黑话 + + Returns: + List[Tuple[str, str]]: 匹配到的黑话条目列表,每个元素是 (content, source_id) + """ if not jargon_miner: return [] - # 获取缓存的所有jargon实例 + + # 获取缓存的所有 jargon 实例 cached_jargons = jargon_miner.get_cached_jargons() if not cached_jargons: return [] + matched_entries: List[Tuple[str, str]] = [] - + for i, msg in enumerate(self._messages_cache): - if + # 跳过机器人自己的消息 + if is_bot_self(msg.message_info.user_info.user_id, msg.platform): + continue + + # 获取消息文本 + msg_text = (msg.processed_plain_text or "").strip() + + if not msg_text: + continue + + # 检查每个缓存中的 jargon 是否出现在消息文本中 + for jargon in cached_jargons: + if not jargon or not jargon.strip(): + continue + + jargon_content = jargon.strip() + + # 使用正则匹配,考虑单词边界(类似 jargon_explainer 中的逻辑) + pattern = re.escape(jargon_content) + # 对于中文,使用更宽松的匹配;对于英文/数字,使用单词边界 + if re.search(r"[\u4e00-\u9fff]", jargon_content): + # 包含中文,使用更宽松的匹配 + search_pattern = pattern + else: + # 纯英文/数字,使用单词边界 + search_pattern = r"\b" + pattern + r"\b" + + if re.search(search_pattern, msg_text, re.IGNORECASE): + # 找到匹配,构建条目(source_id 从 1 开始,因为 build_readable_message 的编号从 1 开始) + source_id = str(i + 1) + matched_entries.append((jargon_content, source_id)) + + return matched_entries + + async def _process_jargon_entries( + self, jargon_entries: List[Tuple[str, str]], jargon_miner: Optional["JargonMiner"] = None + ): + """ + 处理从 expression learner 提取的黑话条目,路由到 jargon_miner + + Args: + jargon_entries: 黑话条目列表,每个元素是 (content, source_id) + jargon_miner: JargonMiner 实例 + """ + if not jargon_entries or not self._messages_cache: + return + + if not jargon_miner: + logger.warning("缺少 JargonMiner 实例,无法处理黑话条目") + return + + # 构建黑话条目格式 + entries: List["JargonEntry"] = [] + + for content, source_id in jargon_entries: + content = content.strip() + if not content: + continue + + # 过滤掉包含 SELF 的黑话,不学习 + if "SELF" in content: + logger.info(f"跳过包含 SELF 的黑话:{content}") + continue + + # TODO: 多平台兼容 + # 检查是否包含机器人名称 + bot_nickname = global_config.bot.nickname + if bot_nickname and bot_nickname in content: + logger.info(f"跳过包含机器人昵称的黑话:{content}") + continue + + # 解析 source_id + if not source_id.isdigit(): + logger.warning(f"黑话条目 source_id 无效:content={content}, source_id={source_id}") + continue + + # build_readable_message 的编号从 1 开始 + line_index = int(source_id) - 1 + if line_index < 0 or line_index >= len(self._messages_cache): + logger.warning(f"黑话条目 source_id 超出范围:content={content}, source_id={source_id}") + continue + + # 检查是否是机器人自己的消息 + target_msg = self._messages_cache[line_index] + if is_bot_self(target_msg.message_info.user_info.user_id, target_msg.platform): + logger.info(f"跳过引用机器人自身消息的黑话:content={content}, source_id={source_id}") + continue + + # 构建上下文段落(取前后各 3 条消息) + start_idx = max(0, line_index - 3) + end_idx = min(len(self._messages_cache), line_index + 4) + context_msgs = self._messages_cache[start_idx:end_idx] + + context_paragraph = "\n".join( + [f"[{i + 1}] {msg.processed_plain_text or ''}" for i, msg in enumerate(context_msgs)] + ) + + if not context_paragraph: + logger.warning(f"黑话条目上下文为空:content={content}, source_id={source_id}") + continue + + entries.append({"content": content, "raw_content": {context_paragraph}}) # type: ignore + + if not entries: + return + + await jargon_miner.process_extracted_entries(entries) + logger.info(f"成功处理 {len(entries)} 个黑话条目") + + # ====== 过滤方法 ====== + def _filter_expressions(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str]]: + """ + 过滤表达方式,移除不符合条件的条目 + + Args: + expressions: 表达方式列表,每个元素是 (situation, style, source_id) + + Returns: + 过滤后的表达方式列表,每个元素是 (situation, style) + """ + filtered_expressions: List[Tuple[str, str]] = [] + + # 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达) + # TODO: 完善这里的机器人名称检测逻辑(考虑别名、不同平台的名称等) + banned_names: set[str] = set() + bot_nickname = global_config.bot.nickname + if bot_nickname: + banned_names.add(bot_nickname) + alias_names = global_config.bot.alias_names or [] + for alias in alias_names: + if alias_stripped := alias.strip(): + banned_names.add(alias_stripped) + banned_casefold = {name.casefold() for name in banned_names if name} + + for situation, style, source_id in expressions: + source_id_str = source_id.strip() + if not source_id_str.isdigit(): + continue # 无效的来源行编号,跳过 + line_index = int(source_id_str) - 1 # build_readable_message 的编号从 1 开始 + if line_index < 0 or line_index >= len(self._messages_cache): + continue # 超出范围,跳过 + # 当前行的原始消息 + current_msg = self._messages_cache[line_index] + # 过滤掉从 bot 自己发言中提取到的表达方式 + if is_bot_self(current_msg.message_info.user_info.user_id, current_msg.platform): + continue + # 过滤掉无上下文的表达方式 + context = (current_msg.processed_plain_text or "").strip() + if not context: + continue + # 过滤掉包含 SELF 的内容(不学习) + if "SELF" in situation or "SELF" in style or "SELF" in context: + logger.info(f"跳过包含 SELF 的表达方式:situation={situation}, style={style}, source_id={source_id}") + continue + # 过滤掉 style 与机器人名称/昵称重复的表达 + normalized_style = (style or "").strip() + if normalized_style and normalized_style.casefold() in banned_casefold: + logger.debug( + f"跳过 style 与机器人名称重复的表达方式:situation={situation}, style={style}, source_id={source_id}" + ) + continue + # 过滤掉包含 "[表情" 的内容 + if "[表情包" in situation or "[表情包" in style or "[表情包" in context: + logger.info(f"跳过包含表情标记的表达方式:situation={situation}, style={style}, source_id={source_id}") + continue + # 过滤掉包含 "[图片" 的内容 + if "[图片" in situation or "[图片" in style or "[图片" in context: + logger.info(f"跳过包含图片标记的表达方式:situation={situation}, style={style}, source_id={source_id}") + continue + + filtered_expressions.append((situation, style)) + + return filtered_expressions # ====== DB 操作相关 ====== async def _upsert_expression_to_db(self, situation: str, style: str): diff --git a/src/bw_learner/jargon_miner.py b/src/bw_learner/jargon_miner.py index 9f4270ad..2fbf8a2e 100644 --- a/src/bw_learner/jargon_miner.py +++ b/src/bw_learner/jargon_miner.py @@ -197,7 +197,7 @@ class JargonMiner: logger.info(f"[{self.session_name}]{content} 不是黑话") async def process_extracted_entries( - self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] + self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None ): """ 处理已提取的黑话条目(从 expression_learner 路由过来的) diff --git a/src/common/utils/system_utils.py b/src/common/utils/system_utils.py index d956cb58..bfec3933 100644 --- a/src/common/utils/system_utils.py +++ b/src/common/utils/system_utils.py @@ -1,7 +1,7 @@ -# TODO: 这个函数的实现非常临时,后续需要替换为更完善的实现,比如直接从配置文件中读取机器人自己的ID,或者通过API获取机器人自己的信息等 +# TODO: 这个函数的实现非常临时,后续需要替换为更完善的实现,比如直接从配置文件中读取机器人自己的 ID,或者通过 API 获取机器人自己的信息等 def is_bot_self(user_id: str, platform: str) -> bool: """ - 判断用户ID是否是机器人自己 + 判断用户 ID 是否是机器人自己 临时方法,后续会替换为更完善的实现 """ diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index 8c54e3ad..cd46c37e 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -223,7 +223,7 @@ class MessageUtils: processed_plain_texts.append("图片信息和表情信息:") processed_plain_texts.extend(f"[图片{img_id}: {desc}]" for img_id, desc in img_map.values()) processed_plain_texts.append("") # 图片和表情之间添加一个换行,避免连在一起 - processed_plain_texts.extend(f"[表情{emoji_id}: {desc}]" for emoji_id, desc in emoji_map.values()) + processed_plain_texts.extend(f"[表情包{emoji_id}: {desc}]" for emoji_id, desc in emoji_map.values()) processed_plain_texts.extend(("", "聊天记录信息:")) # 获取动作记录文本列表