diff --git a/AGENTS.md b/AGENTS.md
index e3f6b96f..4fac1284 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -10,6 +10,7 @@
 - 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。
 3. 标准库和第三方库的导入应该放在本地模块导入的前面。
 4. 各个导入块之间应该使用一个空行进行分隔。
+5. 对于现有的代码,如果导入顺序不符合上述规范,在重构代码时应该调整导入顺序以符合规范。
 
 # 代码规范
 ## 注释规范
diff --git a/src/bw_learner/expression_learner_old.py b/src/bw_learner/expression_learner_old.py
deleted file mode 100644
index a18fd985..00000000
--- a/src/bw_learner/expression_learner_old.py
+++ /dev/null
@@ -1,596 +0,0 @@
-import time
-import json
-import os
-import re
-import asyncio
-from typing import List, Optional, Tuple, Any, Dict
-from src.common.logger import get_logger
-from src.common.database.database_model import Expression
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import model_config, global_config
-from src.chat.utils.chat_message_builder import (
-    build_anonymous_messages,
-)
-from src.prompt.prompt_manager import prompt_manager
-from src.chat.message_receive.chat_stream import get_chat_manager
-from src.bw_learner.learner_utils_old import (
-    filter_message_content,
-    is_bot_message,
-    build_context_paragraph,
-    contains_bot_self_name,
-    calculate_similarity,
-    parse_expression_response,
-)
-from src.bw_learner.jargon_miner_old import miner_manager
-from src.bw_learner.expression_auto_check_task import (
-    single_expression_check,
-)
-
-
-# MAX_EXPRESSION_COUNT = 300
-
-logger = get_logger("expressor")
-
-
-class ExpressionLearner:
-    def __init__(self, chat_id: str) -> None:
-        self.express_learn_model: LLMRequest = LLMRequest(
-            model_set=model_config.model_task_config.utils, request_type="expression.learner"
-        )
-        self.summary_model: LLMRequest = LLMRequest(
-            model_set=model_config.model_task_config.tool_use, request_type="expression.summary"
-        )
-        self.check_model: Optional[LLMRequest] = None  # 检查用的 LLM 实例,延迟初始化
-        self.chat_id = chat_id
-        self.chat_stream = get_chat_manager().get_stream(chat_id)
-        self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
-
-        # 学习锁,防止并发执行学习任务
-        self._learning_lock = asyncio.Lock()
-
-    async def learn_and_store(
-        self,
-        messages: List[Any],
-    ) -> Optional[List[Tuple[str, str]]]:
-        """
-        学习并存储表达方式
-
-        Args:
-            messages: 外部传入的消息列表(必需)
-        """
-        if not messages:
-            return None
-
-        random_msg = messages
-
-        # 学习用(开启行编号,便于溯源)
-        random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
-
-        prompt_template = prompt_manager.get_prompt("learn_style")
-        prompt_template.add_context("bot_name", global_config.bot.nickname)
-        prompt_template.add_context("chat_str", random_msg_str)
-
-        prompt = await prompt_manager.render_prompt(prompt_template)
-
-        # print(f"random_msg_str:{random_msg_str}")
-        # logger.info(f"学习{type_str}的prompt: {prompt}")
-
-        try:
-            response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
-        except Exception as e:
-            logger.error(f"学习表达方式失败,模型生成出错: {e}")
-            return None
-
-        # 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
-        expressions: List[Tuple[str, str, str]]
-        jargon_entries: List[Tuple[str, str]]  # (content, source_id)
-        expressions, jargon_entries = parse_expression_response(response)
-
-        # 从缓存中检查 jargon 是否出现在 messages 中
-        cached_jargon_entries = self._check_cached_jargons_in_messages(random_msg)
-        if cached_jargon_entries:
-            # 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
-            existing_contents = {content for content, _ in jargon_entries}
-            for content, source_id in cached_jargon_entries:
-                if content not in existing_contents:
-                    jargon_entries.append((content, source_id))
-                    existing_contents.add(content)
-                    logger.info(f"从缓存中检查到黑话: {content}")
-
-        # 检查表达方式数量,如果超过20个则放弃本次表达学习
-        if len(expressions) > 20:
-            logger.info(f"表达方式提取数量超过20个(实际{len(expressions)}个),放弃本次表达学习")
-            expressions = []
-
-        # 检查黑话数量,如果超过30个则放弃本次黑话学习
-        if len(jargon_entries) > 30:
-            logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习")
-            jargon_entries = []
-
-        # 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
-        if jargon_entries:
-            await self._process_jargon_entries(jargon_entries, random_msg)
-
-        # 如果没有表达方式,直接返回
-        if not expressions:
-            logger.info("解析后没有可用的表达方式")
-            return []
-
-        logger.info(f"学习的prompt: {prompt}")
-        logger.info(f"学习的expressions: {expressions}")
-        logger.info(f"学习的jargon_entries: {jargon_entries}")
-        logger.info(f"学习的response: {response}")
-
-        # 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
-        learnt_expressions = self._filter_expressions(expressions, random_msg)
-
-        if not learnt_expressions:
-            logger.info("没有学习到表达风格")
-            return []
-
-        # 展示学到的表达方式
-        learnt_expressions_str = ""
-        for situation, style in learnt_expressions:
-            learnt_expressions_str += f"{situation}->{style}\n"
-        logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
-
-        current_time = time.time()
-
-        # 存储到数据库 Expression 表
-        for situation, style in learnt_expressions:
-            await self._upsert_expression_record(
-                situation=situation,
-                style=style,
-                current_time=current_time,
-            )
-
-        return learnt_expressions
-
-    def _filter_expressions(
-        self,
-        expressions: List[Tuple[str, str, str]],
-        messages: List[Any],
-    ) -> List[Tuple[str, str]]:
-        """
-        过滤表达方式,移除不符合条件的条目
-
-        Args:
-            expressions: 表达方式列表,每个元素是 (situation, style, source_id)
-            messages: 原始消息列表,用于溯源和验证
-
-        Returns:
-            过滤后的表达方式列表,每个元素是 (situation, style)
-        """
-        filtered_expressions: List[Tuple[str, str]] = []  # (situation, style)
-
-        # 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
-        banned_names = set()
-        bot_nickname = (global_config.bot.nickname or "").strip()
-        if bot_nickname:
-            banned_names.add(bot_nickname)
-        alias_names = global_config.bot.alias_names or []
-        for alias in alias_names:
-            alias = alias.strip()
-            if alias:
-                banned_names.add(alias)
-        banned_casefold = {name.casefold() for name in banned_names if name}
-
-        for situation, style, source_id in expressions:
-            source_id_str = (source_id or "").strip()
-            if not source_id_str.isdigit():
-                # 无效的来源行编号,跳过
-                continue
-
-            line_index = int(source_id_str) - 1  # build_anonymous_messages 的编号从 1 开始
-            if line_index < 0 or line_index >= len(messages):
-                # 超出范围,跳过
-                continue
-
-            # 当前行的原始内容
-            current_msg = messages[line_index]
-
-            # 过滤掉从bot自己发言中提取到的表达方式
-            if is_bot_message(current_msg):
-                continue
-
-            context = filter_message_content(current_msg.processed_plain_text or "")
-            if not context:
-                continue
-
-            # 过滤掉包含 SELF 的内容(不学习)
-            if "SELF" in (situation or "") or "SELF" in (style or "") or "SELF" in context:
-                logger.info(f"跳过包含 SELF 的表达方式: situation={situation}, style={style}, source_id={source_id}")
-                continue
-
-            # 过滤掉 style 与机器人名称/昵称重复的表达
-            normalized_style = (style or "").strip()
-            if normalized_style and normalized_style.casefold() in banned_casefold:
-                logger.debug(
-                    f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}"
-                )
-                continue
-
-            # 过滤掉包含 "表情:" 或 "表情:" 的内容
-            if (
-                "表情:" in (situation or "")
-                or "表情:" in (situation or "")
-                or "表情:" in (style or "")
-                or "表情:" in (style or "")
-                or "表情:" in context
-                or "表情:" in
context - ): - logger.info(f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}") - continue - - # 过滤掉包含 "[图片" 的内容 - if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context: - logger.info(f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}") - continue - - filtered_expressions.append((situation, style)) - - return filtered_expressions - - async def _upsert_expression_record( - self, - situation: str, - style: str, - current_time: float, - ) -> None: - # 检查是否有相似的 situation(相似度 >= 0.75,检查 content_list) - # 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理 - expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75) - - if expr_obj: - # 根据相似度决定是否使用 LLM 总结 - # 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结 - use_llm_summary = similarity < 1.0 - await self._update_existing_expression( - expr_obj=expr_obj, - situation=situation, - current_time=current_time, - use_llm_summary=use_llm_summary, - ) - return - - # 没有找到匹配的记录,创建新记录 - await self._create_expression_record( - situation=situation, - style=style, - current_time=current_time, - ) - - async def _create_expression_record( - self, - situation: str, - style: str, - current_time: float, - ) -> None: - content_list = [situation] - # 创建新记录时,直接使用原始的 situation,不进行总结 - formatted_situation = situation - - Expression.create( - situation=formatted_situation, - style=style, - content_list=json.dumps(content_list, ensure_ascii=False), - count=1, - last_active_time=current_time, - chat_id=self.chat_id, - create_date=current_time, - ) - - async def _update_existing_expression( - self, - expr_obj: Expression, - situation: str, - current_time: float, - use_llm_summary: bool = True, - ) -> None: - """ - 更新现有 Expression 记录(situation 完全匹配或相似的情况) - 将新的 situation 添加到 content_list,不合并 style - - Args: - use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True - """ - # 更新 content_list(添加新的 situation) - content_list = self._parse_content_list(expr_obj.content_list) - content_list.append(situation) - expr_obj.content_list = json.dumps(content_list, ensure_ascii=False) - - # 更新其他字段 - expr_obj.count = (expr_obj.count or 0) + 1 - expr_obj.checked = False # count 增加时重置 checked 为 False - expr_obj.last_active_time = current_time - - if use_llm_summary: - # 相似匹配时,使用 LLM 重新组合 situation - new_situation = await self._compose_situation_text( - content_list=content_list, - fallback=expr_obj.situation, - ) - expr_obj.situation = new_situation - - expr_obj.save() - - # count 增加后,立即进行一次检查 - await self._check_expression_immediately(expr_obj) - - def _parse_content_list(self, stored_list: Optional[str]) -> List[str]: - if not stored_list: - return [] - try: - data = json.loads(stored_list) - except json.JSONDecodeError: - return [] - return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else [] - - async def _find_similar_situation_expression( - self, situation: str, similarity_threshold: float = 0.75 - ) -> Tuple[Optional[Expression], float]: - """ - 查找具有相似 situation 的 Expression 记录 - 检查 content_list 中的每一项 - - Args: - situation: 要查找的 situation - similarity_threshold: 相似度阈值,默认 0.75 - - Returns: - Tuple[Optional[Expression], float]: - - 找到的最相似的 Expression 对象,如果没有找到则返回 None - - 相似度值(如果找到匹配,范围在 similarity_threshold 到 1.0 之间) - """ - # 查询同一 chat_id 的所有记录 - all_expressions = Expression.select().where(Expression.chat_id == self.chat_id) - - best_match = None - best_similarity = 0.0 - - for expr in all_expressions: - # 检查 content_list 中的每一项 - content_list = 
self._parse_content_list(expr.content_list) - for existing_situation in content_list: - similarity = calculate_similarity(situation, existing_situation) - if similarity >= similarity_threshold and similarity > best_similarity: - best_similarity = similarity - best_match = expr - - if best_match: - logger.debug( - f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'" - ) - - return best_match, best_similarity - - async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str: - sanitized = [c.strip() for c in content_list if c.strip()] - if not sanitized: - return fallback - - prompt = ( - "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话," - "长度不超过20个字,保留共同特点:\n" - f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。" - ) - - try: - summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2) - summary = summary.strip() - if summary: - return summary - except Exception as e: - logger.error(f"概括表达情境失败: {e}") - return "/".join(sanitized) if sanitized else fallback - - async def _init_check_model(self) -> None: - """初始化检查用的 LLM 实例""" - if self.check_model is None: - try: - self.check_model = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="expression.check" - ) - logger.debug("检查用 LLM 实例初始化成功") - except Exception as e: - logger.error(f"创建检查用 LLM 实例失败: {e}") - - async def _check_expression_immediately(self, expr_obj: Expression) -> None: - """ - 立即检查表达方式(在 count 增加后调用) - - Args: - expr_obj: 要检查的表达方式对象 - """ - try: - # 检查是否启用自动检查 - if not global_config.expression.expression_self_reflect: - logger.debug("表达方式自动检查未启用,跳过立即检查") - return - - # 初始化检查用的 LLM - await self._init_check_model() - if self.check_model is None: - logger.warning("检查用 LLM 实例初始化失败,跳过立即检查") - return - - # 执行 LLM 评估 - suitable, reason, error = await single_expression_check(expr_obj.situation, expr_obj.style) - - # 更新数据库 - expr_obj.checked = True - expr_obj.rejected = not suitable # 通过则 rejected=False,不通过则 rejected=True - expr_obj.save() - - status = "通过" if suitable else "不通过" - logger.info( - f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | " - f"Situation: {expr_obj.situation[:30]}... | " - f"Style: {expr_obj.style[:30]}... | " - f"Reason: {reason[:50] if reason else '无'}..." 
- ) - - if error: - logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}") - - except Exception as e: - logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True) - # 检查失败时,保持 checked=False,等待后续自动检查任务处理 - - def _check_cached_jargons_in_messages(self, messages: List[Any]) -> List[Tuple[str, str]]: - """ - 检查缓存中的 jargon 是否出现在 messages 中 - - Args: - messages: 消息列表 - - Returns: - List[Tuple[str, str]]: 匹配到的黑话条目列表,每个元素是 (content, source_id) - """ - if not messages: - return [] - - # 获取 jargon_miner 实例 - jargon_miner = miner_manager.get_miner(self.chat_id) - - # 获取缓存中的所有 jargon - cached_jargons = jargon_miner.get_cached_jargons() - if not cached_jargons: - return [] - - matched_entries: List[Tuple[str, str]] = [] - - # 遍历 messages,检查缓存中的 jargon 是否出现 - for i, msg in enumerate(messages): - # 跳过机器人自己的消息 - if is_bot_message(msg): - continue - - # 获取消息文本 - msg_text = (getattr(msg, "processed_plain_text", None) or "").strip() - - if not msg_text: - continue - - # 检查每个缓存中的 jargon 是否出现在消息文本中 - for jargon in cached_jargons: - if not jargon or not jargon.strip(): - continue - - jargon_content = jargon.strip() - - # 使用正则匹配,考虑单词边界(类似 jargon_explainer 中的逻辑) - pattern = re.escape(jargon_content) - # 对于中文,使用更宽松的匹配;对于英文/数字,使用单词边界 - if re.search(r"[\u4e00-\u9fff]", jargon_content): - # 包含中文,使用更宽松的匹配 - search_pattern = pattern - else: - # 纯英文/数字,使用单词边界 - search_pattern = r"\b" + pattern + r"\b" - - if re.search(search_pattern, msg_text, re.IGNORECASE): - # 找到匹配,构建条目(source_id 从 1 开始,因为 build_anonymous_messages 的编号从 1 开始) - source_id = str(i + 1) - matched_entries.append((jargon_content, source_id)) - - return matched_entries - - async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None: - """ - 处理从 expression learner 提取的黑话条目,路由到 jargon_miner - - Args: - jargon_entries: 黑话条目列表,每个元素是 (content, source_id) - messages: 消息列表,用于构建上下文 - """ - if not jargon_entries or not messages: - return - - # 获取 jargon_miner 实例 - jargon_miner = miner_manager.get_miner(self.chat_id) - - # 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致 - entries: List[Dict[str, List[str]]] = [] - - for content, source_id in jargon_entries: - content = content.strip() - if not content: - continue - - # 过滤掉包含 SELF 的黑话,不学习 - if "SELF" in content: - logger.info(f"跳过包含 SELF 的黑话: {content}") - continue - - # 检查是否包含机器人名称 - if contains_bot_self_name(content): - logger.info(f"跳过包含机器人昵称/别名的黑话: {content}") - continue - - # 解析 source_id - source_id_str = (source_id or "").strip() - if not source_id_str.isdigit(): - logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}") - continue - - # build_anonymous_messages 的编号从 1 开始 - line_index = int(source_id_str) - 1 - if line_index < 0 or line_index >= len(messages): - logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}") - continue - - # 检查是否是机器人自己的消息 - target_msg = messages[line_index] - if is_bot_message(target_msg): - logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}") - continue - - # 构建上下文段落 - context_paragraph = build_context_paragraph(messages, line_index) - if not context_paragraph: - logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}") - continue - - entries.append({"content": content, "raw_content": [context_paragraph]}) - - if not entries: - return - - # 调用 jargon_miner 处理这些条目 - await jargon_miner.process_extracted_entries(entries) - - -class ExpressionLearnerManager: - def __init__(self): - self.expression_learners = {} - - 
-        self._ensure_expression_directories()
-
-    def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
-        if chat_id not in self.expression_learners:
-            self.expression_learners[chat_id] = ExpressionLearner(chat_id)
-        return self.expression_learners[chat_id]
-
-    def _ensure_expression_directories(self):
-        """
-        确保表达方式相关的目录结构存在
-        """
-        base_dir = os.path.join("data", "expression")
-        directories_to_create = [
-            base_dir,
-            os.path.join(base_dir, "learnt_style"),
-            os.path.join(base_dir, "learnt_grammar"),
-        ]
-
-        for directory in directories_to_create:
-            try:
-                os.makedirs(directory, exist_ok=True)
-                logger.debug(f"确保目录存在: {directory}")
-            except Exception as e:
-                logger.error(f"创建目录失败 {directory}: {e}")
-
-
-expression_learner_manager = ExpressionLearnerManager()
diff --git a/src/bw_learner/jargon_explainer.py b/src/bw_learner/jargon_explainer.py
new file mode 100644
index 00000000..1d392cd2
--- /dev/null
+++ b/src/bw_learner/jargon_explainer.py
@@ -0,0 +1,86 @@
+import json
+from typing import Optional, Dict, List
+
+from sqlmodel import select, func as fn
+
+from src.common.database.database import get_db_session
+from src.common.database.database_model import Jargon
+from src.common.logger import get_logger
+from src.config.config import global_config
+
+logger = get_logger("jargon_explainer")
+
+
+def search_jargon(
+    keyword: str,
+    chat_id: Optional[str] = None,
+    limit: int = 10,
+    case_sensitive: bool = False,
+    fuzzy: bool = True,
+) -> List[Dict[str, str]]:
+    """
+    搜索 jargon,支持大小写不敏感和模糊搜索
+
+    Args:
+        keyword: 搜索关键词
+        chat_id: 可选的聊天 ID(session_id)
+            - 如果开启了 all_global:此参数被忽略,查询所有 is_global=True 的记录
+            - 如果关闭了 all_global:如果提供则优先搜索该聊天或 global 的 jargon
+        limit: 返回结果数量限制,默认 10
+        case_sensitive: 是否大小写敏感,默认 False(不敏感)
+        fuzzy: 是否模糊搜索,默认 True(使用 LIKE 匹配)
+
+    Returns:
+        List[Dict[str, str]]: 包含 content, meaning 的字典列表
+    """
+    if not keyword or not keyword.strip():
+        return []
+
+    keyword = keyword.strip()
+
+    # 构建搜索条件
+    if case_sensitive:  # 大小写敏感
+        search_condition = Jargon.content.contains(keyword) if fuzzy else Jargon.content == keyword  # type: ignore
+    else:
+        keyword_lower = keyword.lower()
+        search_condition = (
+            fn.LOWER(Jargon.content).contains(keyword_lower) if fuzzy else fn.LOWER(Jargon.content) == keyword_lower
+        )
+
+    # 根据 all_global 配置决定查询逻辑,同时限制结果数量(先多取一些,因为后面可能过滤)
+    if global_config.expression.all_global_jargon:
+        # 开启 all_global:所有记录都是全局的,查询所有 is_global=True 的记录(无视 chat_id)
+        query = select(Jargon).where(search_condition, Jargon.is_global).order_by(Jargon.count.desc()).limit(limit * 2)  # type: ignore
+    else:
+        # 关闭 all_global:查询所有记录,chat_id 过滤在 Python 层面进行
+        query = select(Jargon).where(search_condition).order_by(Jargon.count.desc()).limit(limit * 2)  # type: ignore
+
+    # 执行查询并返回结果
+    results: List[Dict[str, str]] = []
+    with get_db_session() as session:
+        jargons = session.exec(query).all()
+
+        for jargon in jargons:
+            # 如果提供了 chat_id 且 all_global=False,需要检查 session_id_dict 是否包含目标 chat_id
+            if chat_id and not global_config.expression.all_global_jargon and not jargon.is_global:
+                try:  # 解析 session_id_dict
+                    session_id_dict = json.loads(jargon.session_id_dict) if jargon.session_id_dict else {}
+                except (json.JSONDecodeError, TypeError):
+                    session_id_dict = {}
+                    logger.warning(
+                        f"解析 session_id_dict 失败,jargon_id={jargon.id},原始数据:{jargon.session_id_dict}"
+                    )
+
+                # 检查是否包含目标 chat_id
+                if chat_id not in session_id_dict:
+                    continue
+
+            # 只返回有 meaning 的记录
+            if not (jargon.meaning or "").strip():
+                continue
+
+            results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
+
+            # 达到限制数量后停止
+            if len(results) >= limit:
+                break
+
+    return results
diff --git a/src/bw_learner/jargon_miner_old.py b/src/bw_learner/jargon_miner_old.py
deleted file mode 100644
index d6495291..00000000
--- a/src/bw_learner/jargon_miner_old.py
+++ /dev/null
@@ -1,589 +0,0 @@
-import json
-import asyncio
-import random
-from collections import OrderedDict
-from typing import List, Dict, Optional, Callable
-from json_repair import repair_json
-from sqlalchemy import func as fn
-
-from src.common.logger import get_logger
-from src.common.database.database_model import Jargon
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import model_config, global_config
-from src.chat.message_receive.chat_stream import get_chat_manager
-from src.prompt.prompt_manager import prompt_manager
-from src.bw_learner.learner_utils_old import (
-    parse_chat_id_list,
-    chat_id_list_contains,
-    update_chat_id_list,
-)
-
-
-logger = get_logger("jargon")
-
-
-def _is_single_char_jargon(content: str) -> bool:
-    """
-    判断是否是单字黑话(单个汉字、英文或数字)
-
-    Args:
-        content: 词条内容
-
-    Returns:
-        bool: 如果是单字黑话返回True,否则返回False
-    """
-    if not content or len(content) != 1:
-        return False
-
-    char = content[0]
-    # 判断是否是单个汉字、单个英文字母或单个数字
-    return (
-        "\u4e00" <= char <= "\u9fff"  # 汉字
-        or "a" <= char <= "z"  # 小写字母
-        or "A" <= char <= "Z"  # 大写字母
-        or "0" <= char <= "9"  # 数字
-    )
-
-
-def _should_infer_meaning(jargon_obj: Jargon) -> bool:
-    """
-    判断是否需要进行含义推断
-    在 count 达到 2, 4, 8, 12, 24, 60, 100 时进行推断
-    并且count必须大于last_inference_count,避免重启后重复判定
-    如果is_complete为True,不再进行推断
-    """
-    # 如果已完成所有推断,不再推断
-    if jargon_obj.is_complete:
-        return False
-
-    count = jargon_obj.count or 0
-    last_inference = jargon_obj.last_inference_count or 0
-
-    # 阈值列表:2, 4, 8, 12, 24, 60, 100
-    thresholds = [2, 4, 8, 12, 24, 60, 100]
-
-    if count < thresholds[0]:
-        return False
-
-    # 如果count没有超过上次判定值,不需要判定
-    if count <= last_inference:
-        return False
-
-    # 找到第一个大于last_inference的阈值
-    next_threshold = None
-    for threshold in thresholds:
-        if threshold > last_inference:
-            next_threshold = threshold
-            break
-
-    # 如果没有找到下一个阈值,说明已经超过100,不应该再推断
-    if next_threshold is None:
-        return False
-
-    # 检查count是否达到或超过这个阈值
-    return count >= next_threshold
-
-
-class JargonMiner:
-    def __init__(self, chat_id: str) -> None:
-        self.chat_id = chat_id
-
-        self.llm = LLMRequest(
-            model_set=model_config.model_task_config.utils,
-            request_type="jargon.extract",
-        )
-
-        self.llm_inference = LLMRequest(
-            model_set=model_config.model_task_config.utils,
-            request_type="jargon.inference",
-        )
-
-        # 初始化stream_name作为类属性,避免重复提取
-        chat_manager = get_chat_manager()
-        stream_name = chat_manager.get_stream_name(self.chat_id)
-        self.stream_name = stream_name if stream_name else self.chat_id
-        self.cache_limit = 50
-        self.cache: OrderedDict[str, None] = OrderedDict()
-
-        # 黑话提取锁,防止并发执行
-        self._extraction_lock = asyncio.Lock()
-
-    def _add_to_cache(self, content: str) -> None:
-        """将提取到的黑话加入缓存,保持LRU语义"""
-        if not content:
-            return
-
-        key = content.strip()
-        if not key:
-            return
-
-        # 单字黑话(单个汉字、英文或数字)不记录到缓存
-        if _is_single_char_jargon(key):
-            return
-
-        if key in self.cache:
-            self.cache.move_to_end(key)
-        else:
-            self.cache[key] = None
-            if len(self.cache) > self.cache_limit:
-                self.cache.popitem(last=False)
-
-    def get_cached_jargons(self) -> List[str]:
-        """获取缓存中的所有黑话列表"""
-        return list(self.cache.keys())
-
-    async def _infer_meaning_by_id(self, jargon_id: int) -> None:
-        """通过ID加载对象并推断"""
-        try:
-            jargon_obj = Jargon.get_by_id(jargon_id)
-            # 再次检查is_complete,因为可能在异步任务执行时已被标记为完成
- if jargon_obj.is_complete: - logger.debug(f"jargon {jargon_obj.content} 已完成所有推断,跳过") - return - await self.infer_meaning(jargon_obj) - except Exception as e: - logger.error(f"通过ID推断jargon失败: {e}") - - async def infer_meaning(self, jargon_obj: Jargon) -> None: - """ - 对jargon进行含义推断 - """ - try: - content = jargon_obj.content - raw_content_str = jargon_obj.raw_content or "" - - # 解析raw_content列表 - raw_content_list = [] - if raw_content_str: - try: - raw_content_list = ( - json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str - ) - if not isinstance(raw_content_list, list): - raw_content_list = [raw_content_list] if raw_content_list else [] - except (json.JSONDecodeError, TypeError): - raw_content_list = [raw_content_str] if raw_content_str else [] - - if not raw_content_list: - logger.warning(f"jargon {content} 没有raw_content,跳过推断") - return - - # 获取当前count和上一次的meaning - current_count = jargon_obj.count or 0 - previous_meaning = jargon_obj.meaning or "" - - # 当count为24, 60时,随机移除一半的raw_content项目 - if current_count in [24, 60] and len(raw_content_list) > 1: - # 计算要保留的数量(至少保留1个) - keep_count = max(1, len(raw_content_list) // 2) - raw_content_list = random.sample(raw_content_list, keep_count) - logger.info( - f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目" - ) - - # 步骤1: 基于raw_content和content推断 - raw_content_text = "\n".join(raw_content_list) - - # 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考 - previous_meaning_section = "" - previous_meaning_instruction = "" - if current_count in [24, 60, 100] and previous_meaning: - previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}" - previous_meaning_instruction = ( - "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" - ) - - prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context") - prompt1_template.add_context("bot_name", global_config.bot.nickname) - prompt1_template.add_context("content", str(content)) - prompt1_template.add_context("raw_content_list", raw_content_text) - prompt1_template.add_context("previous_meaning_section", previous_meaning_section) - prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction) - prompt1 = await prompt_manager.render_prompt(prompt1_template) - - response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3) - if not response1: - logger.warning(f"jargon {content} 推断1失败:无响应") - return - - # 解析推断1结果 - inference1 = None - try: - resp1 = response1.strip() - if resp1.startswith("{") and resp1.endswith("}"): - inference1 = json.loads(resp1) - else: - repaired = repair_json(resp1) - inference1 = json.loads(repaired) if isinstance(repaired, str) else repaired - if not isinstance(inference1, dict): - logger.warning(f"jargon {content} 推断1结果格式错误") - return - except Exception as e: - logger.error(f"jargon {content} 推断1解析失败: {e}") - return - - # 检查推断1是否表示信息不足无法推断 - no_info = inference1.get("no_info", False) - meaning1 = inference1.get("meaning", "").strip() - if no_info or not meaning1: - logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新") - # 更新最后一次判定的count值,避免在同一阈值重复尝试 - jargon_obj.last_inference_count = jargon_obj.count or 0 - jargon_obj.save() - return - - # 步骤2: 仅基于content推断 - prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only") - prompt2_template.add_context("content", str(content)) - prompt2 = await prompt_manager.render_prompt(prompt2_template) - - response2, _ = await self.llm_inference.generate_response_async(prompt2, 
temperature=0.3) - if not response2: - logger.warning(f"jargon {content} 推断2失败:无响应") - return - - # 解析推断2结果 - inference2 = None - try: - resp2 = response2.strip() - if resp2.startswith("{") and resp2.endswith("}"): - inference2 = json.loads(resp2) - else: - repaired = repair_json(resp2) - inference2 = json.loads(repaired) if isinstance(repaired, str) else repaired - if not isinstance(inference2, dict): - logger.warning(f"jargon {content} 推断2结果格式错误") - return - except Exception as e: - logger.error(f"jargon {content} 推断2解析失败: {e}") - return - - # logger.info(f"jargon {content} 推断2提示词: {prompt2}") - # logger.info(f"jargon {content} 推断2结果: {response2}") - # logger.info(f"jargon {content} 推断1提示词: {prompt1}") - # logger.info(f"jargon {content} 推断1结果: {response1}") - - if global_config.debug.show_jargon_prompt: - logger.info(f"jargon {content} 推断2提示词: {prompt2}") - logger.info(f"jargon {content} 推断2结果: {response2}") - logger.info(f"jargon {content} 推断1提示词: {prompt1}") - logger.info(f"jargon {content} 推断1结果: {response1}") - else: - logger.debug(f"jargon {content} 推断2提示词: {prompt2}") - logger.debug(f"jargon {content} 推断2结果: {response2}") - logger.debug(f"jargon {content} 推断1提示词: {prompt1}") - logger.debug(f"jargon {content} 推断1结果: {response1}") - - # 步骤3: 比较两个推断结果 - prompt3_template = prompt_manager.get_prompt("jargon_compare_inference") - prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False)) - prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False)) - prompt3 = await prompt_manager.render_prompt(prompt3_template) - - if global_config.debug.show_jargon_prompt: - logger.info(f"jargon {content} 比较提示词: {prompt3}") - - response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3) - if not response3: - logger.warning(f"jargon {content} 比较失败:无响应") - return - - # 解析比较结果 - comparison = None - try: - resp3 = response3.strip() - if resp3.startswith("{") and resp3.endswith("}"): - comparison = json.loads(resp3) - else: - repaired = repair_json(resp3) - comparison = json.loads(repaired) if isinstance(repaired, str) else repaired - if not isinstance(comparison, dict): - logger.warning(f"jargon {content} 比较结果格式错误") - return - except Exception as e: - logger.error(f"jargon {content} 比较解析失败: {e}") - return - - # 判断是否为黑话 - is_similar = comparison.get("is_similar", False) - is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话 - - # 更新数据库记录 - jargon_obj.is_jargon = is_jargon - jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else "" - # 更新最后一次判定的count值,避免重启后重复判定 - jargon_obj.last_inference_count = jargon_obj.count or 0 - - # 如果count>=100,标记为完成,不再进行推断 - if (jargon_obj.count or 0) >= 100: - jargon_obj.is_complete = True - - jargon_obj.save() - logger.debug( - f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}" - ) - - # 固定输出推断结果,格式化为可读形式 - if is_jargon: - # 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx - meaning = jargon_obj.meaning or "无详细说明" - is_global = jargon_obj.is_global - if is_global: - logger.info(f"[黑话]{content}的含义是 {meaning}") - else: - logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}") - else: - # 不是黑话,输出格式:[聊天名]xxx 不是黑话 - logger.info(f"[{self.stream_name}]{content} 不是黑话") - - except Exception as e: - logger.error(f"jargon推断失败: {e}") - import traceback - - traceback.print_exc() - - async def process_extracted_entries( - self, entries: List[Dict[str, List[str]]], person_name_filter: 
Optional[Callable[[str], bool]] = None - ) -> None: - """ - 处理已提取的黑话条目(从 expression_learner 路由过来的) - - Args: - entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]} - person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称 - """ - if not entries: - return - - try: - # 去重并合并raw_content(按 content 聚合) - merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict() - for entry in entries: - content_key = entry["content"] - - # 检查是否包含人物名称 - # logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}") - # logger.info(f"person_name_filter: {person_name_filter}") - if person_name_filter and person_name_filter(content_key): - logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}") - continue - - raw_list = entry.get("raw_content", []) or [] - if content_key in merged_entries: - merged_entries[content_key]["raw_content"].extend(raw_list) - else: - merged_entries[content_key] = { - "content": content_key, - "raw_content": list(raw_list), - } - - uniq_entries = [] - for merged_entry in merged_entries.values(): - raw_content_list = merged_entry["raw_content"] - if raw_content_list: - merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list)) - uniq_entries.append(merged_entry) - - saved = 0 - updated = 0 - for entry in uniq_entries: - content = entry["content"] - raw_content_list = entry["raw_content"] # 已经是列表 - - try: - # 查询所有content匹配的记录 - query = Jargon.select().where(Jargon.content == content) - - # 查找匹配的记录 - matched_obj = None - for obj in query: - if global_config.expression.all_global_jargon: - # 开启all_global:所有content匹配的记录都可以 - matched_obj = obj - break - else: - # 关闭all_global:需要检查chat_id列表是否包含目标chat_id - chat_id_list = parse_chat_id_list(obj.chat_id) - if chat_id_list_contains(chat_id_list, self.chat_id): - matched_obj = obj - break - - if matched_obj: - obj = matched_obj - try: - obj.count = (obj.count or 0) + 1 - except Exception: - obj.count = 1 - - # 合并raw_content列表:读取现有列表,追加新值,去重 - existing_raw_content = [] - if obj.raw_content: - try: - existing_raw_content = ( - json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content - ) - if not isinstance(existing_raw_content, list): - existing_raw_content = [existing_raw_content] if existing_raw_content else [] - except (json.JSONDecodeError, TypeError): - existing_raw_content = [obj.raw_content] if obj.raw_content else [] - - # 合并并去重 - merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list)) - obj.raw_content = json.dumps(merged_list, ensure_ascii=False) - - # 更新chat_id列表:增加当前chat_id的计数 - chat_id_list = parse_chat_id_list(obj.chat_id) - updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1) - obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False) - - # 开启all_global时,确保记录标记为is_global=True - if global_config.expression.all_global_jargon: - obj.is_global = True - # 关闭all_global时,保持原有is_global不变(不修改) - - obj.save() - - # 检查是否需要推断(达到阈值且超过上次判定值) - if _should_infer_meaning(obj): - # 异步触发推断,不阻塞主流程 - # 重新加载对象以确保数据最新 - jargon_id = obj.id - asyncio.create_task(self._infer_meaning_by_id(jargon_id)) - - updated += 1 - else: - # 没找到匹配记录,创建新记录 - if global_config.expression.all_global_jargon: - # 开启all_global:新记录默认为is_global=True - is_global_new = True - else: - # 关闭all_global:新记录is_global=False - is_global_new = False - - # 使用新格式创建chat_id列表:[[chat_id, count]] - chat_id_list = [[self.chat_id, 1]] - chat_id_json = json.dumps(chat_id_list, ensure_ascii=False) - - Jargon.create( - content=content, - raw_content=json.dumps(raw_content_list, 
ensure_ascii=False), - chat_id=chat_id_json, - is_global=is_global_new, - count=1, - ) - saved += 1 - except Exception as e: - logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}") - continue - finally: - self._add_to_cache(content) - - # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) - if uniq_entries: - # 收集所有提取的jargon内容 - jargon_list = [entry["content"] for entry in uniq_entries] - jargon_str = ",".join(jargon_list) - - # 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色) - logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}") - - if saved or updated: - logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}") - except Exception as e: - logger.error(f"处理已提取的黑话条目失败: {e}") - - -class JargonMinerManager: - def __init__(self) -> None: - self._miners: dict[str, JargonMiner] = {} - - def get_miner(self, chat_id: str) -> JargonMiner: - if chat_id not in self._miners: - self._miners[chat_id] = JargonMiner(chat_id) - return self._miners[chat_id] - - -miner_manager = JargonMinerManager() - - -def search_jargon( - keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True -) -> List[Dict[str, str]]: - """ - 搜索jargon,支持大小写不敏感和模糊搜索 - - Args: - keyword: 搜索关键词 - chat_id: 可选的聊天ID - - 如果开启了all_global:此参数被忽略,查询所有is_global=True的记录 - - 如果关闭了all_global:如果提供则优先搜索该聊天或global的jargon - limit: 返回结果数量限制,默认10 - case_sensitive: 是否大小写敏感,默认False(不敏感) - fuzzy: 是否模糊搜索,默认True(使用LIKE匹配) - - Returns: - List[Dict[str, str]]: 包含content, meaning的字典列表 - """ - if not keyword or not keyword.strip(): - return [] - - keyword = keyword.strip() - - # 构建查询(选择所有需要的字段,以便后续过滤) - query = Jargon.select() - - # 构建搜索条件 - if case_sensitive: - # 大小写敏感 - if fuzzy: - # 模糊搜索 - search_condition = Jargon.content.contains(keyword) - else: - # 精确匹配 - search_condition = Jargon.content == keyword - else: - # 大小写不敏感 - if fuzzy: - # 模糊搜索(使用LOWER函数) - search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) - else: - # 精确匹配(使用LOWER函数) - search_condition = fn.LOWER(Jargon.content) == keyword.lower() - - query = query.where(search_condition) - - # 根据all_global配置决定查询逻辑 - if global_config.expression.all_global_jargon: - # 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id) - query = query.where(Jargon.is_global) - # 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式 - - # 注意:meaning的过滤移到Python层面,因为我们需要先过滤chat_id - - # 按count降序排序,优先返回出现频率高的 - query = query.order_by(Jargon.count.desc()) - - # 限制结果数量(先多取一些,因为后面可能过滤) - query = query.limit(limit * 2) - - # 执行查询并返回结果,过滤chat_id - results = [] - for jargon in query: - # 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id - if chat_id and not global_config.expression.all_global_jargon: - chat_id_list = parse_chat_id_list(jargon.chat_id) - # 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含 - if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id): - continue - - # 只返回有meaning的记录 - if not jargon.meaning or jargon.meaning.strip() == "": - continue - - results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""}) - - # 达到限制数量后停止 - if len(results) >= limit: - break - - return results diff --git a/src/bw_learner/learner_utils_old.py b/src/bw_learner/learner_utils_old.py index ce3ea379..42fda4b3 100644 --- a/src/bw_learner/learner_utils_old.py +++ b/src/bw_learner/learner_utils_old.py @@ -16,77 +16,6 @@ from json_repair import repair_json logger = get_logger("learner_utils") -def filter_message_content(content: Optional[str]) -> str: - """ - 过滤消息内容,移除回复、@、图片等格式 - - Args: - content: 
原始消息内容 - - Returns: - str: 过滤后的内容 - """ - if not content: - return "" - - # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 - content = re.sub(r"\[回复.*?\],说:\s*", "", content) - # 移除@<...>格式的内容 - content = re.sub(r"@<[^>]*>", "", content) - # 移除[picid:...]格式的图片ID - content = re.sub(r"\[picid:[^\]]*\]", "", content) - # 移除[表情包:...]格式的内容 - content = re.sub(r"\[表情包:[^\]]*\]", "", content) - - return content.strip() - - -def calculate_similarity(text1: str, text2: str) -> float: - """ - 计算两个文本的相似度,返回0-1之间的值 - 使用SequenceMatcher计算相似度 - - Args: - text1: 第一个文本 - text2: 第二个文本 - - Returns: - float: 相似度值,范围0-1 - """ - return difflib.SequenceMatcher(None, text1, text2).ratio() - - -def calculate_style_similarity(style1: str, style2: str) -> float: - """ - 计算两个 style 的相似度,返回0-1之间的值 - 在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py) - - Args: - style1: 第一个 style - style2: 第二个 style - - Returns: - float: 相似度值,范围0-1 - """ - if not style1 or not style2: - return 0.0 - - # 移除"使用"和"句式"这两个词 - def remove_ignored_words(text: str) -> str: - """移除需要忽略的词""" - text = text.replace("使用", "") - text = text.replace("句式", "") - return text.strip() - - cleaned_style1 = remove_ignored_words(style1) - cleaned_style2 = remove_ignored_words(style2) - - # 如果清理后文本为空,返回0 - if not cleaned_style1 or not cleaned_style2: - return 0.0 - - return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio() - def _compute_weights(population: List[Dict]) -> List[float]: """ @@ -275,224 +204,224 @@ def contains_bot_self_name(content: str) -> bool: return any(name in target for name in candidates) -def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: - """ - 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出 - """ - if not messages or center_index < 0 or center_index >= len(messages): - return None +# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: +# """ +# 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出 +# """ +# if not messages or center_index < 0 or center_index >= len(messages): +# return None - context_start = max(0, center_index - 3) - context_end = min(len(messages), center_index + 1 + 3) - context_messages = messages[context_start:context_end] +# context_start = max(0, center_index - 3) +# context_end = min(len(messages), center_index + 1 + 3) +# context_messages = messages[context_start:context_end] - if not context_messages: - return None +# if not context_messages: +# return None - try: - paragraph = build_readable_messages( - messages=context_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - truncate=False, - show_actions=False, - show_pic=True, - message_id_list=None, - remove_emoji_stickers=False, - pic_single=True, - ) - except Exception as e: - logger.warning(f"构建上下文段落失败: {e}") - return None +# try: +# paragraph = build_readable_messages( +# messages=context_messages, +# replace_bot_name=True, +# timestamp_mode="relative", +# read_mark=0.0, +# truncate=False, +# show_actions=False, +# show_pic=True, +# message_id_list=None, +# remove_emoji_stickers=False, +# pic_single=True, +# ) +# except Exception as e: +# logger.warning(f"构建上下文段落失败: {e}") +# return None - paragraph = paragraph.strip() - return paragraph or None +# paragraph = paragraph.strip() +# return paragraph or None -def is_bot_message(msg: Any) -> bool: - """判断消息是否来自机器人自身""" - if msg is None: - return False +# def is_bot_message(msg: Any) -> bool: +# """判断消息是否来自机器人自身""" +# if msg is None: +# return False - bot_config = getattr(global_config, "bot", None) - if 
not bot_config: - return False +# bot_config = getattr(global_config, "bot", None) +# if not bot_config: +# return False - platform = ( - str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "") - .strip() - .lower() - ) - user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() +# platform = ( +# str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "") +# .strip() +# .lower() +# ) +# user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() - if not platform or not user_id: - return False +# if not platform or not user_id: +# return False - platform_accounts = {} - try: - platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or []) - except Exception: - platform_accounts = {} +# platform_accounts = {} +# try: +# platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or []) +# except Exception: +# platform_accounts = {} - bot_accounts: Dict[str, str] = {} - qq_account = str(getattr(bot_config, "qq_account", "") or "").strip() - if qq_account: - bot_accounts["qq"] = qq_account +# bot_accounts: Dict[str, str] = {} +# qq_account = str(getattr(bot_config, "qq_account", "") or "").strip() +# if qq_account: +# bot_accounts["qq"] = qq_account - telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip() - if telegram_account: - bot_accounts["telegram"] = telegram_account +# telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip() +# if telegram_account: +# bot_accounts["telegram"] = telegram_account - for plat, account in platform_accounts.items(): - if account and plat not in bot_accounts: - bot_accounts[plat] = account +# for plat, account in platform_accounts.items(): +# if account and plat not in bot_accounts: +# bot_accounts[plat] = account - bot_account = bot_accounts.get(platform) - return bool(bot_account and user_id == bot_account) +# bot_account = bot_accounts.get(platform) +# return bool(bot_account and user_id == bot_account) -def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: - """ - 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 +# def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: +# """ +# 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 - 期望的 JSON 结构: - [ - {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式 - {"content": "词条", "source_id": "12"}, // 黑话 - ... - ] +# 期望的 JSON 结构: +# [ +# {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式 +# {"content": "词条", "source_id": "12"}, // 黑话 +# ... 
+# ] - Returns: - Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: - 第一个列表是表达方式 (situation, style, source_id) - 第二个列表是黑话 (content, source_id) - """ - if not response: - return [], [] +# Returns: +# Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: +# 第一个列表是表达方式 (situation, style, source_id) +# 第二个列表是黑话 (content, source_id) +# """ +# if not response: +# return [], [] - raw = response.strip() +# raw = response.strip() - # 尝试提取 ```json 代码块 - json_block_pattern = r"```json\s*(.*?)\s*```" - match = re.search(json_block_pattern, raw, re.DOTALL) - if match: - raw = match.group(1).strip() - else: - # 去掉可能存在的通用 ``` 包裹 - raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) - raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) - raw = raw.strip() +# # 尝试提取 ```json 代码块 +# json_block_pattern = r"```json\s*(.*?)\s*```" +# match = re.search(json_block_pattern, raw, re.DOTALL) +# if match: +# raw = match.group(1).strip() +# else: +# # 去掉可能存在的通用 ``` 包裹 +# raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) +# raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) +# raw = raw.strip() - parsed = None - expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) - jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) +# parsed = None +# expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) +# jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) - try: - # 优先尝试直接解析 - if raw.startswith("[") and raw.endswith("]"): - parsed = json.loads(raw) - else: - repaired = repair_json(raw) - if isinstance(repaired, str): - parsed = json.loads(repaired) - else: - parsed = repaired - except Exception as parse_error: - # 如果解析失败,尝试修复中文引号问题 - # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 - try: +# try: +# # 优先尝试直接解析 +# if raw.startswith("[") and raw.endswith("]"): +# parsed = json.loads(raw) +# else: +# repaired = repair_json(raw) +# if isinstance(repaired, str): +# parsed = json.loads(repaired) +# else: +# parsed = repaired +# except Exception as parse_error: +# # 如果解析失败,尝试修复中文引号问题 +# # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 +# try: - def fix_chinese_quotes_in_json(text): - """使用状态机修复 JSON 字符串值中的中文引号""" - result = [] - i = 0 - in_string = False - escape_next = False +# def fix_chinese_quotes_in_json(text): +# """使用状态机修复 JSON 字符串值中的中文引号""" +# result = [] +# i = 0 +# in_string = False +# escape_next = False - while i < len(text): - char = text[i] +# while i < len(text): +# char = text[i] - if escape_next: - # 当前字符是转义字符后的字符,直接添加 - result.append(char) - escape_next = False - i += 1 - continue +# if escape_next: +# # 当前字符是转义字符后的字符,直接添加 +# result.append(char) +# escape_next = False +# i += 1 +# continue - if char == "\\": - # 转义字符 - result.append(char) - escape_next = True - i += 1 - continue +# if char == "\\": +# # 转义字符 +# result.append(char) +# escape_next = True +# i += 1 +# continue - if char == '"' and not escape_next: - # 遇到英文引号,切换字符串状态 - in_string = not in_string - result.append(char) - i += 1 - continue +# if char == '"' and not escape_next: +# # 遇到英文引号,切换字符串状态 +# in_string = not in_string +# result.append(char) +# i += 1 +# continue - if in_string: - # 在字符串值内部,将中文引号替换为转义的英文引号 - if char == '"': # 中文左引号 U+201C - result.append('\\"') - elif char == '"': # 中文右引号 U+201D - result.append('\\"') - else: - result.append(char) - else: - # 不在字符串内,直接添加 - result.append(char) +# if in_string: +# # 在字符串值内部,将中文引号替换为转义的英文引号 +# if char == '"': # 中文左引号 U+201C +# result.append('\\"') +# elif char == '"': # 中文右引号 U+201D +# result.append('\\"') +# else: +# 
result.append(char) +# else: +# # 不在字符串内,直接添加 +# result.append(char) - i += 1 +# i += 1 - return "".join(result) +# return "".join(result) - fixed_raw = fix_chinese_quotes_in_json(raw) +# fixed_raw = fix_chinese_quotes_in_json(raw) - # 再次尝试解析 - if fixed_raw.startswith("[") and fixed_raw.endswith("]"): - parsed = json.loads(fixed_raw) - else: - repaired = repair_json(fixed_raw) - if isinstance(repaired, str): - parsed = json.loads(repaired) - else: - parsed = repaired - except Exception as fix_error: - logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") - logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") - logger.error(f"解析表达风格 JSON 失败,原始响应:{response}") - logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") - return [], [] +# # 再次尝试解析 +# if fixed_raw.startswith("[") and fixed_raw.endswith("]"): +# parsed = json.loads(fixed_raw) +# else: +# repaired = repair_json(fixed_raw) +# if isinstance(repaired, str): +# parsed = json.loads(repaired) +# else: +# parsed = repaired +# except Exception as fix_error: +# logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") +# logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") +# logger.error(f"解析表达风格 JSON 失败,原始响应:{response}") +# logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") +# return [], [] - if isinstance(parsed, dict): - parsed_list = [parsed] - elif isinstance(parsed, list): - parsed_list = parsed - else: - logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") - return [], [] +# if isinstance(parsed, dict): +# parsed_list = [parsed] +# elif isinstance(parsed, list): +# parsed_list = parsed +# else: +# logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") +# return [], [] - for item in parsed_list: - if not isinstance(item, dict): - continue +# for item in parsed_list: +# if not isinstance(item, dict): +# continue - # 检查是否是表达方式条目(有 situation 和 style) - situation = str(item.get("situation", "")).strip() - style = str(item.get("style", "")).strip() - source_id = str(item.get("source_id", "")).strip() +# # 检查是否是表达方式条目(有 situation 和 style) +# situation = str(item.get("situation", "")).strip() +# style = str(item.get("style", "")).strip() +# source_id = str(item.get("source_id", "")).strip() - if situation and style and source_id: - # 表达方式条目 - expressions.append((situation, style, source_id)) - elif item.get("content"): - # 黑话条目(有 content 字段) - content = str(item.get("content", "")).strip() - source_id = str(item.get("source_id", "")).strip() - if content and source_id: - jargon_entries.append((content, source_id)) +# if situation and style and source_id: +# # 表达方式条目 +# expressions.append((situation, style, source_id)) +# elif item.get("content"): +# # 黑话条目(有 content 字段) +# content = str(item.get("content", "")).strip() +# source_id = str(item.get("source_id", "")).strip() +# if content and source_id: +# jargon_entries.append((content, source_id)) - return expressions, jargon_entries +# return expressions, jargon_entries diff --git a/src/bw_learner/message_recorder_old.py b/src/bw_learner/message_recorder_old.py deleted file mode 100644 index a5d90cc9..00000000 --- a/src/bw_learner/message_recorder_old.py +++ /dev/null @@ -1,179 +0,0 @@ -import time -import asyncio -from typing import List, Any -from src.common.logger import get_logger -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive -from 
src.chat.utils.common_utils import TempMethodsExpression -from src.bw_learner.expression_learner_old import expression_learner_manager -from src.bw_learner.jargon_miner_old import miner_manager - -logger = get_logger("bw_learner") - - -class MessageRecorder: - """ - 统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner - """ - - def __init__(self, chat_id: str) -> None: - self.chat_id = chat_id - self.chat_stream = _chat_manager.get_session_by_session_id(chat_id) - self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id - - # 维护每个chat的上次提取时间 - self.last_extraction_time: float = time.time() - - # 提取锁,防止并发执行 - self._extraction_lock = asyncio.Lock() - - # 获取 expression 和 jargon 的配置参数 - self._init_parameters() - - # 获取 expression_learner 和 jargon_miner 实例 - self.expression_learner = expression_learner_manager.get_expression_learner(chat_id) - self.jargon_miner = miner_manager.get_miner(chat_id) - - def _init_parameters(self) -> None: - """初始化提取参数""" - # 获取 expression 配置 - _, self.enable_expression_learning, self.enable_jargon_learning = ( - TempMethodsExpression.get_expression_config_for_chat(self.chat_id) - ) - self.min_messages_for_extraction = 30 - self.min_extraction_interval = 60 - - logger.debug( - f"MessageRecorder 初始化: chat_id={self.chat_id}, " - f"min_messages={self.min_messages_for_extraction}, " - f"min_interval={self.min_extraction_interval}" - ) - - def should_trigger_extraction(self) -> bool: - """ - 检查是否应该触发消息提取 - - Returns: - bool: 是否应该触发提取 - """ - # 检查时间间隔 - time_diff = time.time() - self.last_extraction_time - if time_diff < self.min_extraction_interval: - return False - - # 检查消息数量 - recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=self.last_extraction_time, - timestamp_end=time.time(), - ) - - if not recent_messages or len(recent_messages) < self.min_messages_for_extraction: - return False - - return True - - async def extract_and_distribute(self) -> None: - """ - 提取消息并分发给 expression_learner 和 jargon_miner - """ - # 使用异步锁防止并发执行 - async with self._extraction_lock: - # 在锁内检查,避免并发触发 - if not self.should_trigger_extraction(): - return - - # 检查 chat_stream 是否存在 - if not self.chat_stream: - return - - # 记录本次提取的时间窗口,避免重复提取 - extraction_start_time = self.last_extraction_time - extraction_end_time = time.time() - - # 立即更新提取时间,防止并发触发 - self.last_extraction_time = extraction_end_time - - try: - # logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发") - - # 拉取提取窗口内的消息 - messages = get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=extraction_start_time, - timestamp_end=extraction_end_time, - ) - - if not messages: - logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取") - return - - # 按时间排序,确保顺序一致 - messages = sorted(messages, key=lambda msg: msg.time or 0) - - logger.info( - f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息," - f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" - ) - - # 触发 expression_learner 和 jargon_miner 的处理 - if self.enable_expression_learning: - asyncio.create_task(self._trigger_expression_learning(messages)) - - except Exception as e: - logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") - import traceback - - traceback.print_exc() - # 即使失败也保持时间戳更新,避免频繁重试 - - async def _trigger_expression_learning(self, messages: List[Any]) -> None: - """ - 触发 expression 学习,使用指定的消息列表 - - Args: - timestamp_start: 开始时间戳 - timestamp_end: 结束时间戳 - messages: 消息列表 - """ - try: - # 传递消息给 ExpressionLearner(必需参数) - learnt_style = await 
self.expression_learner.learn_and_store(messages=messages) - - if learnt_style: - logger.info(f"聊天流 {self.chat_name} 表达学习完成") - else: - logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果") - except Exception as e: - logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}") - import traceback - - traceback.print_exc() - - -class MessageRecorderManager: - """MessageRecorder 管理器""" - - def __init__(self) -> None: - self._recorders: dict[str, MessageRecorder] = {} - - def get_recorder(self, chat_id: str) -> MessageRecorder: - """获取或创建指定 chat_id 的 MessageRecorder""" - if chat_id not in self._recorders: - self._recorders[chat_id] = MessageRecorder(chat_id) - return self._recorders[chat_id] - - -# 全局管理器实例 -recorder_manager = MessageRecorderManager() - - -async def extract_and_distribute_messages(chat_id: str) -> None: - """ - 统一的消息提取和分发入口函数 - - Args: - chat_id: 聊天流ID - """ - recorder = recorder_manager.get_recorder(chat_id) - await recorder.extract_and_distribute()
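按照本次 AGENTS.md 新增的第 5 条(配合原有第 3、4 条),重构时应将标准库与第三方库的导入放在本地模块导入之前,并用空行分隔各导入块。下面是一个符合该规范的最小导入示例(仅作示意,模块名取自本补丁中已出现的导入):

import json
from typing import Optional

from sqlmodel import select

from src.common.database.database import get_db_session
from src.common.logger import get_logger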
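新增的 src/bw_learner/jargon_explainer.py 以 sqlmodel 会话查询实现了 search_jargon,替代 jargon_miner_old.py 中随该文件一并删除的同名函数。以下是一个最小调用示意(关键词与 chat_id 均为假设值,并假定数据库已初始化):

from src.bw_learner.jargon_explainer import search_jargon

# 默认大小写不敏感 + 模糊匹配;未开启 all_global_jargon 时,chat_id 用于按会话过滤
matches = search_jargon("某词条", chat_id="chat_123", limit=5)
for entry in matches:
    print(f"{entry['content']} -> {entry['meaning']}")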