diff --git a/AGENTS.md b/AGENTS.md index 226bff82..6577c787 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,4 +9,13 @@ - 对于同一个文件夹下的模块导入,使用相对导入,排列顺序按照**不发生import错误的前提下**,随便排列。 - 对于不同文件夹下的模块导入,使用绝对导入。这些导入应该以`from src`开头,并且按照**不发生import错误的前提下**,尽量使得第二层的文件夹名称相同的导入放在一起;第二层文件夹名称排列随机。 3. 标准库和第三方库的导入应该放在本地模块导入的前面。 -4. 各个导入块之间应该使用一个空行进行分隔。 \ No newline at end of file +4. 各个导入块之间应该使用一个空行进行分隔。 + +# 代码规范 +## 注释规范 +1. 尽量保持良好的注释 +2. 如果原来的代码中有注释,则重构的时候,除非这部分代码被删除,否则相同功能的代码应该保留注释(可以对注释进行修改以保持准确性,但不应该删除注释)。 +3. 如果原来的代码中没有注释,则重构的时候,如果某个功能块的代码较长或者逻辑较为复杂,则应该添加注释来解释这部分代码的功能和逻辑。 +## 类型注解规范 +1. 重构代码时,如果原来的代码中有类型注解,则相同功能的代码应该保留类型注解(可以对类型注解进行修改以保持准确性,但不应该删除类型注解)。 +2. 重构代码时,如果原来的代码中没有类型注解,则重构的时候,如果某个函数的功能较为复杂或者参数较多,则应该添加类型注解来提高代码的可读性和可维护性。(对于简单的变量,可以不添加类型注解) \ No newline at end of file diff --git a/src/bw_learner/expression_learner.py b/src/bw_learner/expression_learner.py index a0a3fe34..0dc4c726 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/bw_learner/expression_learner.py @@ -1,84 +1,68 @@ -import time -import json -import os -import re +from datetime import datetime +from sqlmodel import select +from typing import TYPE_CHECKING, List, Optional, Tuple + import asyncio -from typing import List, Optional, Tuple, Any, Dict -from src.common.logger import get_logger -from src.common.database.database_model import Expression +import difflib +import json + from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config -from src.chat.utils.chat_message_builder import ( - build_anonymous_messages, -) from src.prompt.prompt_manager import prompt_manager -from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.bw_learner.learner_utils import ( - filter_message_content, - is_bot_message, - build_context_paragraph, - contains_bot_self_name, - calculate_similarity, - parse_expression_response, -) -from src.bw_learner.jargon_miner import miner_manager -from src.bw_learner.expression_auto_check_task import ( - 
single_expression_check, -) +from src.common.logger import get_logger +from src.common.database.database_model import Expression +from src.common.database.database import get_db_session +from src.common.data_models.expression_data_model import MaiExpression +from src.common.utils.utils_message import MessageUtils +from .expression_utils import check_expression_suitability, parse_expression_response + +if TYPE_CHECKING: + from src.chat.message_receive.message import SessionMessage -# MAX_EXPRESSION_COUNT = 300 logger = get_logger("expressor") +# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 +express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner") +summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary") +check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check") + class ExpressionLearner: - def __init__(self, chat_id: str) -> None: - self.express_learn_model: LLMRequest = LLMRequest( - model_set=model_config.model_task_config.utils, request_type="expression.learner" - ) - self.summary_model: LLMRequest = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="expression.summary" - ) - self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化 - self.chat_id = chat_id - self.chat_stream = _chat_manager.get_session_by_session_id(chat_id) - self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id + def __init__(self, session_id: str) -> None: + self.session_id = session_id # 学习锁,防止并发执行学习任务 self._learning_lock = asyncio.Lock() - async def learn_and_store( - self, - messages: List[Any], - ) -> Optional[List[Tuple[str, str, str]]]: - """ - 学习并存储表达方式 + # 消息缓存 + self._messages_cache: List["SessionMessage"] = [] - Args: - messages: 外部传入的消息列表(必需) - num: 学习数量 - timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time - """ - if not messages: - return None - - random_msg = messages - - # 
学习用(开启行编号,便于溯源) - random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True) + async def add_messages(self, messages: List["SessionMessage"]) -> None: + """添加消息到缓存""" + self._messages_cache.extend(messages) + async def learn(self): + """学习主流程""" + if not self._messages_cache: + logger.debug("没有消息可供学习,跳过学习过程") + return + readable_message, _ = await MessageUtils.build_readable_message( + self._messages_cache, + anonymize=True, + show_lineno=True, + extract_pictures=True, + ) + self._messages_cache.clear() # 学习后清空缓存 prompt_template = prompt_manager.get_prompt("learn_style") prompt_template.add_context("bot_name", global_config.bot.nickname) - prompt_template.add_context("chat_str", random_msg_str) + prompt_template.add_context("chat_str", readable_message) prompt = await prompt_manager.render_prompt(prompt_template) - # print(f"random_msg_str:{random_msg_str}") - # logger.info(f"学习{type_str}的prompt: {prompt}") - try: - response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) + response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3) except Exception as e: logger.error(f"学习表达方式失败,模型生成出错: {e}") return None @@ -87,510 +71,147 @@ class ExpressionLearner: expressions: List[Tuple[str, str, str]] jargon_entries: List[Tuple[str, str]] # (content, source_id) expressions, jargon_entries = parse_expression_response(response) + # TODO: 完成学习 - # 从缓存中检查 jargon 是否出现在 messages 中 - cached_jargon_entries = self._check_cached_jargons_in_messages(random_msg) - if cached_jargon_entries: - # 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过) - existing_contents = {content for content, _ in jargon_entries} - for content, source_id in cached_jargon_entries: - if content not in existing_contents: - jargon_entries.append((content, source_id)) - existing_contents.add(content) - logger.info(f"从缓存中检查到黑话: {content}") - - # 检查表达方式数量,如果超过10个则放弃本次表达学习 - if len(expressions) > 20: - 
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习") - expressions = [] - - # 检查黑话数量,如果超过30个则放弃本次黑话学习 - if len(jargon_entries) > 30: - logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习") - jargon_entries = [] - - # 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话) - if jargon_entries: - await self._process_jargon_entries(jargon_entries, random_msg) - - # 如果没有表达方式,直接返回 - if not expressions: - logger.info("解析后没有可用的表达方式") - return [] - - logger.info(f"学习的prompt: {prompt}") - logger.info(f"学习的expressions: {expressions}") - logger.info(f"学习的jargon_entries: {jargon_entries}") - logger.info(f"学习的response: {response}") - - # 过滤表达方式,根据 source_id 溯源并应用各种过滤规则 - learnt_expressions = self._filter_expressions(expressions, random_msg) - - if learnt_expressions is None: - logger.info("没有学习到表达风格") - return [] - - # 展示学到的表达方式 - learnt_expressions_str = "" - for situation, style in learnt_expressions: - learnt_expressions_str += f"{situation}->{style}\n" - logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") - - current_time = time.time() - - # 存储到数据库 Expression 表 - for situation, style in learnt_expressions: - await self._upsert_expression_record( - situation=situation, - style=style, - current_time=current_time, - ) - - return learnt_expressions - - def _filter_expressions( - self, - expressions: List[Tuple[str, str, str]], - messages: List[Any], - ) -> List[Tuple[str, str, str]]: - """ - 过滤表达方式,移除不符合条件的条目 - - Args: - expressions: 表达方式列表,每个元素是 (situation, style, source_id) - messages: 原始消息列表,用于溯源和验证 - - Returns: - 过滤后的表达方式列表,每个元素是 (situation, style, context) - """ - filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context) - - # 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达) - banned_names = set() - bot_nickname = (global_config.bot.nickname or "").strip() - if bot_nickname: - banned_names.add(bot_nickname) - alias_names = global_config.bot.alias_names or [] - for alias in alias_names: - alias = alias.strip() - if alias: - 
banned_names.add(alias) - banned_casefold = {name.casefold() for name in banned_names if name} - - for situation, style, source_id in expressions: - source_id_str = (source_id or "").strip() - if not source_id_str.isdigit(): - # 无效的来源行编号,跳过 - continue - - line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始 - if line_index < 0 or line_index >= len(messages): - # 超出范围,跳过 - continue - - # 当前行的原始内容 - current_msg = messages[line_index] - - # 过滤掉从bot自己发言中提取到的表达方式 - if is_bot_message(current_msg): - continue - - context = filter_message_content(current_msg.processed_plain_text or "") - if not context: - continue - - # 过滤掉包含 SELF 的内容(不学习) - if "SELF" in (situation or "") or "SELF" in (style or "") or "SELF" in context: - logger.info(f"跳过包含 SELF 的表达方式: situation={situation}, style={style}, source_id={source_id}") - continue - - # 过滤掉 style 与机器人名称/昵称重复的表达 - normalized_style = (style or "").strip() - if normalized_style and normalized_style.casefold() in banned_casefold: - logger.debug( - f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}" - ) - continue - - # 过滤掉包含 "表情:" 或 "表情:" 的内容 - if ( - "表情:" in (situation or "") - or "表情:" in (situation or "") - or "表情:" in (style or "") - or "表情:" in (style or "") - or "表情:" in context - or "表情:" in context - ): - logger.info(f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}") - continue - - # 过滤掉包含 "[图片" 的内容 - if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context: - logger.info(f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}") - continue - - filtered_expressions.append((situation, style)) - - return filtered_expressions - - async def _upsert_expression_record( - self, - situation: str, - style: str, - current_time: float, - ) -> None: - # 检查是否有相似的 situation(相似度 >= 0.75,检查 content_list) - # 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理 - expr_obj, similarity = await 
self._find_similar_situation_expression(situation, similarity_threshold=0.75) - - if expr_obj: + async def _upsert_expression_to_db(self, situation: str, style: str): + expr, similarity = self._find_similar_expression(situation) or (None, 0) + if expr: # 根据相似度决定是否使用 LLM 总结 # 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结 use_llm_summary = similarity < 1.0 - await self._update_existing_expression( - expr_obj=expr_obj, - situation=situation, - current_time=current_time, - use_llm_summary=use_llm_summary, - ) + await self._update_existing_expression(expr, situation, use_llm_summary=use_llm_summary) return - # 没有找到匹配的记录,创建新记录 - await self._create_expression_record( - situation=situation, - style=style, - current_time=current_time, - ) + self._create_expression(situation, style) - async def _create_expression_record( - self, - situation: str, - style: str, - current_time: float, - ) -> None: + def _create_expression(self, situation: str, style: str): content_list = [situation] - # 创建新记录时,直接使用原始的 situation,不进行总结 - formatted_situation = situation + try: + with get_db_session() as db: + new_expr = Expression( + situation=situation, + style=style, + content_list=json.dumps(content_list), + count=1, + session_id=self.session_id, + last_active_time=datetime.now(), + ) + db.add(new_expr) + except Exception as e: + logger.error(f"创建表达方式失败: {e}") - Expression.create( - situation=formatted_situation, - style=style, - content_list=json.dumps(content_list, ensure_ascii=False), - count=1, - last_active_time=current_time, - chat_id=self.chat_id, - create_date=current_time, - ) - - async def _update_existing_expression( - self, - expr_obj: Expression, - situation: str, - current_time: float, - use_llm_summary: bool = True, - ) -> None: - """ - 更新现有 Expression 记录(situation 完全匹配或相似的情况) - 将新的 situation 添加到 content_list,不合并 style - - Args: - use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True - """ - # 更新 content_list(添加新的 situation) - content_list = self._parse_content_list(expr_obj.content_list) - 
content_list.append(situation) - expr_obj.content_list = json.dumps(content_list, ensure_ascii=False) - - # 更新其他字段 - expr_obj.count = (expr_obj.count or 0) + 1 - expr_obj.checked = False # count 增加时重置 checked 为 False - expr_obj.last_active_time = current_time + async def _update_existing_expression(self, expr: "MaiExpression", situation: str, use_llm_summary: bool = True): + expr.content.append(situation) + expr.count += 1 + expr.checked = False # count 增加时重置 checked 为 False + expr.last_active_time = datetime.now() if use_llm_summary: # 相似匹配时,使用 LLM 重新组合 situation - new_situation = await self._compose_situation_text( - content_list=content_list, - fallback=expr_obj.situation, - ) - expr_obj.situation = new_situation + new_situation = await self._compose_situation_text(expr.content) + if new_situation: + expr.situation = new_situation - expr_obj.save() + try: + with get_db_session() as session: + if expr.item_id is None: + raise ValueError("表达方式对象缺少 item_id,无法更新数据库记录") + statement = select(Expression).filter_by(id=expr.item_id).limit(1) + if db_expr := session.exec(statement).first(): + db_expr.content_list = json.dumps(expr.content) + db_expr.count = expr.count + db_expr.checked = expr.checked + db_expr.last_active_time = expr.last_active_time + db_expr.situation = expr.situation # 更新 situation + session.add(db_expr) + else: + logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新") + except Exception as e: + logger.error(f"更新表达方式失败: {e}") # count 增加后,立即进行一次检查 - await self._check_expression_immediately(expr_obj) - - def _parse_content_list(self, stored_list: Optional[str]) -> List[str]: - if not stored_list: - return [] - try: - data = json.loads(stored_list) - except json.JSONDecodeError: - return [] - return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else [] - - async def _find_similar_situation_expression( - self, situation: str, similarity_threshold: float = 0.75 - ) -> Tuple[Optional[Expression], float]: - """ - 查找具有相似 
situation 的 Expression 记录 - 检查 content_list 中的每一项 - - Args: - situation: 要查找的 situation - similarity_threshold: 相似度阈值,默认 0.75 - - Returns: - Tuple[Optional[Expression], float]: - - 找到的最相似的 Expression 对象,如果没有找到则返回 None - - 相似度值(如果找到匹配,范围在 similarity_threshold 到 1.0 之间) - """ - # 查询同一 chat_id 的所有记录 - all_expressions = Expression.select().where(Expression.chat_id == self.chat_id) - - best_match = None - best_similarity = 0.0 - - for expr in all_expressions: - # 检查 content_list 中的每一项 - content_list = self._parse_content_list(expr.content_list) - for existing_situation in content_list: - similarity = calculate_similarity(situation, existing_situation) - if similarity >= similarity_threshold and similarity > best_similarity: - best_similarity = similarity - best_match = expr - - if best_match: - logger.debug( - f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'" - ) - - return best_match, best_similarity - - async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str: - sanitized = [c.strip() for c in content_list if c.strip()] - if not sanitized: - return fallback + await self._check_expression(expr) + async def _compose_situation_text(self, content_list: List[str]) -> Optional[str]: + texts = [c.strip() for c in content_list if c.strip()] + if not texts: + return None + description = "\n".join(f"- {s}" for s in texts[-10:]) # 只取最近10条进行概括 prompt = ( - "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话," - "长度不超过20个字,保留共同特点:\n" - f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。" + "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,长度不超过20个字,保留共同特点:\n" + f"{description}\n" + "只输出概括内容。" ) try: - summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2) - summary = summary.strip() - if summary: + summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2) + if summary := summary.strip(): return summary except Exception as e: - logger.error(f"概括表达情境失败: {e}") - return 
"/".join(sanitized) if sanitized else fallback + logger.error(f"使用 LLM 生成表达方式概括失败: {e}") + return None - async def _init_check_model(self) -> None: - """初始化检查用的 LLM 实例""" - if self.check_model is None: - try: - self.check_model = LLMRequest( - model_set=model_config.model_task_config.tool_use, request_type="expression.check" - ) - logger.debug("检查用 LLM 实例初始化成功") - except Exception as e: - logger.error(f"创建检查用 LLM 实例失败: {e}") - - async def _check_expression_immediately(self, expr_obj: Expression) -> None: + async def _check_expression(self, expr: "MaiExpression"): """ - 立即检查表达方式(在 count 增加后调用) + 检查表达方式(在 count 增加后调用) Args: - expr_obj: 要检查的表达方式对象 + expr (MaiExpression): 要检查的表达方式对象 """ + if not global_config.expression.expression_self_reflect: + logger.debug("表达方式自我反思功能未启用,跳过检查") + return + + suitable, reason, error = await check_expression_suitability(expr.situation, expr.style) + if error: + logger.error(f"检查表达方式时发生错误: {error}") + return + expr.checked = True + expr.rejected = not suitable + try: - # 检查是否启用自动检查 - if not global_config.expression.expression_self_reflect: - logger.debug("表达方式自动检查未启用,跳过立即检查") - return + with get_db_session() as session: + statement = select(Expression).filter_by(id=expr.item_id).limit(1) + if db_expr := session.exec(statement).first(): + db_expr.checked = expr.checked + db_expr.rejected = expr.rejected + session.add(db_expr) + else: + logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新检查结果") + except Exception as e: + logger.error(f"更新表达方式检查结果失败: {e}") - # 初始化检查用的 LLM - await self._init_check_model() - if self.check_model is None: - logger.warning("检查用 LLM 实例初始化失败,跳过立即检查") - return + status = "通过" if suitable else "不通过" + logger.info( + f"表达方式检查完成 [ID: {expr.item_id}] - {status} | " + f"Situation: {expr.situation[:30]}... | " + f"Style: {expr.style[:30]}... | " + f"Reason: {reason[:50] if reason else '无'}..." 
+ ) - # 执行 LLM 评估 - suitable, reason, error = await single_expression_check(expr_obj.situation, expr_obj.style) + def _find_similar_expression( + self, situation: str, similarity_threshold: float = 0.75 + ) -> Optional[Tuple[MaiExpression, float]]: + """在数据库中查找相似的表达方式""" + try: + with get_db_session() as session: + statement = select(Expression).filter_by(session_id=self.session_id) + expressions = session.exec(statement).all() - # 更新数据库 - expr_obj.checked = True - expr_obj.rejected = not suitable # 通过则 rejected=False,不通过则 rejected=True - expr_obj.save() + best_match: Optional[Expression] = None + best_similarity = 0.0 - status = "通过" if suitable else "不通过" - logger.info( - f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | " - f"Situation: {expr_obj.situation[:30]}... | " - f"Style: {expr_obj.style[:30]}... | " - f"Reason: {reason[:50] if reason else '无'}..." - ) - - if error: - logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}") + for expr in expressions: + content_list = json.loads(expr.content_list) + for situation in content_list: + similarity = difflib.SequenceMatcher(None, situation, expr.situation).ratio() + if similarity > similarity_threshold and similarity > best_similarity: + best_similarity = similarity + best_match = expr + if best_match: + logger.debug(f"找到相似表达方式情景 [ID: {best_match.id}],相似度: {best_similarity:.2f}") + return MaiExpression.from_db_instance(best_match), best_similarity except Exception as e: - logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True) - # 检查失败时,保持 checked=False,等待后续自动检查任务处理 - - def _check_cached_jargons_in_messages(self, messages: List[Any]) -> List[Tuple[str, str]]: - """ - 检查缓存中的 jargon 是否出现在 messages 中 - - Args: - messages: 消息列表 - - Returns: - List[Tuple[str, str]]: 匹配到的黑话条目列表,每个元素是 (content, source_id) - """ - if not messages: - return [] - - # 获取 jargon_miner 实例 - jargon_miner = miner_manager.get_miner(self.chat_id) - - # 获取缓存中的所有 jargon - cached_jargons = jargon_miner.get_cached_jargons() - if not 
cached_jargons: - return [] - - matched_entries: List[Tuple[str, str]] = [] - - # 遍历 messages,检查缓存中的 jargon 是否出现 - for i, msg in enumerate(messages): - # 跳过机器人自己的消息 - if is_bot_message(msg): - continue - - # 获取消息文本 - msg_text = (getattr(msg, "processed_plain_text", None) or "").strip() - - if not msg_text: - continue - - # 检查每个缓存中的 jargon 是否出现在消息文本中 - for jargon in cached_jargons: - if not jargon or not jargon.strip(): - continue - - jargon_content = jargon.strip() - - # 使用正则匹配,考虑单词边界(类似 jargon_explainer 中的逻辑) - pattern = re.escape(jargon_content) - # 对于中文,使用更宽松的匹配;对于英文/数字,使用单词边界 - if re.search(r"[\u4e00-\u9fff]", jargon_content): - # 包含中文,使用更宽松的匹配 - search_pattern = pattern - else: - # 纯英文/数字,使用单词边界 - search_pattern = r"\b" + pattern + r"\b" - - if re.search(search_pattern, msg_text, re.IGNORECASE): - # 找到匹配,构建条目(source_id 从 1 开始,因为 build_anonymous_messages 的编号从 1 开始) - source_id = str(i + 1) - matched_entries.append((jargon_content, source_id)) - - return matched_entries - - async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None: - """ - 处理从 expression learner 提取的黑话条目,路由到 jargon_miner - - Args: - jargon_entries: 黑话条目列表,每个元素是 (content, source_id) - messages: 消息列表,用于构建上下文 - """ - if not jargon_entries or not messages: - return - - # 获取 jargon_miner 实例 - jargon_miner = miner_manager.get_miner(self.chat_id) - - # 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致 - entries: List[Dict[str, List[str]]] = [] - - for content, source_id in jargon_entries: - content = content.strip() - if not content: - continue - - # 过滤掉包含 SELF 的黑话,不学习 - if "SELF" in content: - logger.info(f"跳过包含 SELF 的黑话: {content}") - continue - - # 检查是否包含机器人名称 - if contains_bot_self_name(content): - logger.info(f"跳过包含机器人昵称/别名的黑话: {content}") - continue - - # 解析 source_id - source_id_str = (source_id or "").strip() - if not source_id_str.isdigit(): - logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}") - continue - - # 
build_anonymous_messages 的编号从 1 开始 - line_index = int(source_id_str) - 1 - if line_index < 0 or line_index >= len(messages): - logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}") - continue - - # 检查是否是机器人自己的消息 - target_msg = messages[line_index] - if is_bot_message(target_msg): - logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}") - continue - - # 构建上下文段落 - context_paragraph = build_context_paragraph(messages, line_index) - if not context_paragraph: - logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}") - continue - - entries.append({"content": content, "raw_content": [context_paragraph]}) - - if not entries: - return - - # 调用 jargon_miner 处理这些条目 - await jargon_miner.process_extracted_entries(entries) - - -class ExpressionLearnerManager: - def __init__(self): - self.expression_learners = {} - - self._ensure_expression_directories() - - def get_expression_learner(self, chat_id: str) -> ExpressionLearner: - if chat_id not in self.expression_learners: - self.expression_learners[chat_id] = ExpressionLearner(chat_id) - return self.expression_learners[chat_id] - - def _ensure_expression_directories(self): - """ - 确保表达方式相关的目录结构存在 - """ - base_dir = os.path.join("data", "expression") - directories_to_create = [ - base_dir, - os.path.join(base_dir, "learnt_style"), - os.path.join(base_dir, "learnt_grammar"), - ] - - for directory in directories_to_create: - try: - os.makedirs(directory, exist_ok=True) - logger.debug(f"确保目录存在: {directory}") - except Exception as e: - logger.error(f"创建目录失败 {directory}: {e}") - - -expression_learner_manager = ExpressionLearnerManager() + logger.error(f"查找相似表达方式失败: {e}") + return None diff --git a/src/bw_learner/expression_learner_old.py b/src/bw_learner/expression_learner_old.py new file mode 100644 index 00000000..a18fd985 --- /dev/null +++ b/src/bw_learner/expression_learner_old.py @@ -0,0 +1,596 @@ +import time +import json +import os +import re +import asyncio 
+from typing import List, Optional, Tuple, Any, Dict +from src.common.logger import get_logger +from src.common.database.database_model import Expression +from src.llm_models.utils_model import LLMRequest +from src.config.config import model_config, global_config +from src.chat.utils.chat_message_builder import ( + build_anonymous_messages, +) +from src.prompt.prompt_manager import prompt_manager +from src.chat.message_receive.chat_stream import get_chat_manager +from src.bw_learner.learner_utils_old import ( + filter_message_content, + is_bot_message, + build_context_paragraph, + contains_bot_self_name, + calculate_similarity, + parse_expression_response, +) +from src.bw_learner.jargon_miner_old import miner_manager +from src.bw_learner.expression_auto_check_task import ( + single_expression_check, +) + + +# MAX_EXPRESSION_COUNT = 300 + +logger = get_logger("expressor") + + +class ExpressionLearner: + def __init__(self, chat_id: str) -> None: + self.express_learn_model: LLMRequest = LLMRequest( + model_set=model_config.model_task_config.utils, request_type="expression.learner" + ) + self.summary_model: LLMRequest = LLMRequest( + model_set=model_config.model_task_config.tool_use, request_type="expression.summary" + ) + self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化 + self.chat_id = chat_id + self.chat_stream = get_chat_manager().get_stream(chat_id) + self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id + + # 学习锁,防止并发执行学习任务 + self._learning_lock = asyncio.Lock() + + async def learn_and_store( + self, + messages: List[Any], + ) -> Optional[List[Tuple[str, str, str]]]: + """ + 学习并存储表达方式 + + Args: + messages: 外部传入的消息列表(必需) + num: 学习数量 + timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time + """ + if not messages: + return None + + random_msg = messages + + # 学习用(开启行编号,便于溯源) + random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True) + + prompt_template = prompt_manager.get_prompt("learn_style") + 
prompt_template.add_context("bot_name", global_config.bot.nickname) + prompt_template.add_context("chat_str", random_msg_str) + + prompt = await prompt_manager.render_prompt(prompt_template) + + # print(f"random_msg_str:{random_msg_str}") + # logger.info(f"学习{type_str}的prompt: {prompt}") + + try: + response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) + except Exception as e: + logger.error(f"学习表达方式失败,模型生成出错: {e}") + return None + + # 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号) + expressions: List[Tuple[str, str, str]] + jargon_entries: List[Tuple[str, str]] # (content, source_id) + expressions, jargon_entries = parse_expression_response(response) + + # 从缓存中检查 jargon 是否出现在 messages 中 + cached_jargon_entries = self._check_cached_jargons_in_messages(random_msg) + if cached_jargon_entries: + # 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过) + existing_contents = {content for content, _ in jargon_entries} + for content, source_id in cached_jargon_entries: + if content not in existing_contents: + jargon_entries.append((content, source_id)) + existing_contents.add(content) + logger.info(f"从缓存中检查到黑话: {content}") + + # 检查表达方式数量,如果超过10个则放弃本次表达学习 + if len(expressions) > 20: + logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习") + expressions = [] + + # 检查黑话数量,如果超过30个则放弃本次黑话学习 + if len(jargon_entries) > 30: + logger.info(f"黑话提取数量超过30个(实际{len(jargon_entries)}个),放弃本次黑话学习") + jargon_entries = [] + + # 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话) + if jargon_entries: + await self._process_jargon_entries(jargon_entries, random_msg) + + # 如果没有表达方式,直接返回 + if not expressions: + logger.info("解析后没有可用的表达方式") + return [] + + logger.info(f"学习的prompt: {prompt}") + logger.info(f"学习的expressions: {expressions}") + logger.info(f"学习的jargon_entries: {jargon_entries}") + logger.info(f"学习的response: {response}") + + # 过滤表达方式,根据 source_id 溯源并应用各种过滤规则 + learnt_expressions = self._filter_expressions(expressions, random_msg) + + if learnt_expressions is None: + 
logger.info("没有学习到表达风格") + return [] + + # 展示学到的表达方式 + learnt_expressions_str = "" + for situation, style in learnt_expressions: + learnt_expressions_str += f"{situation}->{style}\n" + logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") + + current_time = time.time() + + # 存储到数据库 Expression 表 + for situation, style in learnt_expressions: + await self._upsert_expression_record( + situation=situation, + style=style, + current_time=current_time, + ) + + return learnt_expressions + + def _filter_expressions( + self, + expressions: List[Tuple[str, str, str]], + messages: List[Any], + ) -> List[Tuple[str, str, str]]: + """ + 过滤表达方式,移除不符合条件的条目 + + Args: + expressions: 表达方式列表,每个元素是 (situation, style, source_id) + messages: 原始消息列表,用于溯源和验证 + + Returns: + 过滤后的表达方式列表,每个元素是 (situation, style, context) + """ + filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context) + + # 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达) + banned_names = set() + bot_nickname = (global_config.bot.nickname or "").strip() + if bot_nickname: + banned_names.add(bot_nickname) + alias_names = global_config.bot.alias_names or [] + for alias in alias_names: + alias = alias.strip() + if alias: + banned_names.add(alias) + banned_casefold = {name.casefold() for name in banned_names if name} + + for situation, style, source_id in expressions: + source_id_str = (source_id or "").strip() + if not source_id_str.isdigit(): + # 无效的来源行编号,跳过 + continue + + line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始 + if line_index < 0 or line_index >= len(messages): + # 超出范围,跳过 + continue + + # 当前行的原始内容 + current_msg = messages[line_index] + + # 过滤掉从bot自己发言中提取到的表达方式 + if is_bot_message(current_msg): + continue + + context = filter_message_content(current_msg.processed_plain_text or "") + if not context: + continue + + # 过滤掉包含 SELF 的内容(不学习) + if "SELF" in (situation or "") or "SELF" in (style or "") or "SELF" in context: + logger.info(f"跳过包含 SELF 的表达方式: 
situation={situation}, style={style}, source_id={source_id}") + continue + + # 过滤掉 style 与机器人名称/昵称重复的表达 + normalized_style = (style or "").strip() + if normalized_style and normalized_style.casefold() in banned_casefold: + logger.debug( + f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}" + ) + continue + + # 过滤掉包含 "表情:" 或 "表情:" 的内容 + if ( + "表情:" in (situation or "") + or "表情:" in (situation or "") + or "表情:" in (style or "") + or "表情:" in (style or "") + or "表情:" in context + or "表情:" in context + ): + logger.info(f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}") + continue + + # 过滤掉包含 "[图片" 的内容 + if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context: + logger.info(f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}") + continue + + filtered_expressions.append((situation, style)) + + return filtered_expressions + + async def _upsert_expression_record( + self, + situation: str, + style: str, + current_time: float, + ) -> None: + # 检查是否有相似的 situation(相似度 >= 0.75,检查 content_list) + # 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理 + expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75) + + if expr_obj: + # 根据相似度决定是否使用 LLM 总结 + # 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结 + use_llm_summary = similarity < 1.0 + await self._update_existing_expression( + expr_obj=expr_obj, + situation=situation, + current_time=current_time, + use_llm_summary=use_llm_summary, + ) + return + + # 没有找到匹配的记录,创建新记录 + await self._create_expression_record( + situation=situation, + style=style, + current_time=current_time, + ) + + async def _create_expression_record( + self, + situation: str, + style: str, + current_time: float, + ) -> None: + content_list = [situation] + # 创建新记录时,直接使用原始的 situation,不进行总结 + formatted_situation = situation + + Expression.create( + situation=formatted_situation, + style=style, + content_list=json.dumps(content_list, 
ensure_ascii=False), + count=1, + last_active_time=current_time, + chat_id=self.chat_id, + create_date=current_time, + ) + + async def _update_existing_expression( + self, + expr_obj: Expression, + situation: str, + current_time: float, + use_llm_summary: bool = True, + ) -> None: + """ + 更新现有 Expression 记录(situation 完全匹配或相似的情况) + 将新的 situation 添加到 content_list,不合并 style + + Args: + use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True + """ + # 更新 content_list(添加新的 situation) + content_list = self._parse_content_list(expr_obj.content_list) + content_list.append(situation) + expr_obj.content_list = json.dumps(content_list, ensure_ascii=False) + + # 更新其他字段 + expr_obj.count = (expr_obj.count or 0) + 1 + expr_obj.checked = False # count 增加时重置 checked 为 False + expr_obj.last_active_time = current_time + + if use_llm_summary: + # 相似匹配时,使用 LLM 重新组合 situation + new_situation = await self._compose_situation_text( + content_list=content_list, + fallback=expr_obj.situation, + ) + expr_obj.situation = new_situation + + expr_obj.save() + + # count 增加后,立即进行一次检查 + await self._check_expression_immediately(expr_obj) + + def _parse_content_list(self, stored_list: Optional[str]) -> List[str]: + if not stored_list: + return [] + try: + data = json.loads(stored_list) + except json.JSONDecodeError: + return [] + return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else [] + + async def _find_similar_situation_expression( + self, situation: str, similarity_threshold: float = 0.75 + ) -> Tuple[Optional[Expression], float]: + """ + 查找具有相似 situation 的 Expression 记录 + 检查 content_list 中的每一项 + + Args: + situation: 要查找的 situation + similarity_threshold: 相似度阈值,默认 0.75 + + Returns: + Tuple[Optional[Expression], float]: + - 找到的最相似的 Expression 对象,如果没有找到则返回 None + - 相似度值(如果找到匹配,范围在 similarity_threshold 到 1.0 之间) + """ + # 查询同一 chat_id 的所有记录 + all_expressions = Expression.select().where(Expression.chat_id == self.chat_id) + + best_match = None + best_similarity = 
0.0 + + for expr in all_expressions: + # 检查 content_list 中的每一项 + content_list = self._parse_content_list(expr.content_list) + for existing_situation in content_list: + similarity = calculate_similarity(situation, existing_situation) + if similarity >= similarity_threshold and similarity > best_similarity: + best_similarity = similarity + best_match = expr + + if best_match: + logger.debug( + f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'" + ) + + return best_match, best_similarity + + async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str: + sanitized = [c.strip() for c in content_list if c.strip()] + if not sanitized: + return fallback + + prompt = ( + "请阅读以下多个聊天情境描述,并将它们概括成一句简短的话," + "长度不超过20个字,保留共同特点:\n" + f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。" + ) + + try: + summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2) + summary = summary.strip() + if summary: + return summary + except Exception as e: + logger.error(f"概括表达情境失败: {e}") + return "/".join(sanitized) if sanitized else fallback + + async def _init_check_model(self) -> None: + """初始化检查用的 LLM 实例""" + if self.check_model is None: + try: + self.check_model = LLMRequest( + model_set=model_config.model_task_config.tool_use, request_type="expression.check" + ) + logger.debug("检查用 LLM 实例初始化成功") + except Exception as e: + logger.error(f"创建检查用 LLM 实例失败: {e}") + + async def _check_expression_immediately(self, expr_obj: Expression) -> None: + """ + 立即检查表达方式(在 count 增加后调用) + + Args: + expr_obj: 要检查的表达方式对象 + """ + try: + # 检查是否启用自动检查 + if not global_config.expression.expression_self_reflect: + logger.debug("表达方式自动检查未启用,跳过立即检查") + return + + # 初始化检查用的 LLM + await self._init_check_model() + if self.check_model is None: + logger.warning("检查用 LLM 实例初始化失败,跳过立即检查") + return + + # 执行 LLM 评估 + suitable, reason, error = await single_expression_check(expr_obj.situation, expr_obj.style) + + # 
更新数据库 + expr_obj.checked = True + expr_obj.rejected = not suitable # 通过则 rejected=False,不通过则 rejected=True + expr_obj.save() + + status = "通过" if suitable else "不通过" + logger.info( + f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | " + f"Situation: {expr_obj.situation[:30]}... | " + f"Style: {expr_obj.style[:30]}... | " + f"Reason: {reason[:50] if reason else '无'}..." + ) + + if error: + logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}") + + except Exception as e: + logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True) + # 检查失败时,保持 checked=False,等待后续自动检查任务处理 + + def _check_cached_jargons_in_messages(self, messages: List[Any]) -> List[Tuple[str, str]]: + """ + 检查缓存中的 jargon 是否出现在 messages 中 + + Args: + messages: 消息列表 + + Returns: + List[Tuple[str, str]]: 匹配到的黑话条目列表,每个元素是 (content, source_id) + """ + if not messages: + return [] + + # 获取 jargon_miner 实例 + jargon_miner = miner_manager.get_miner(self.chat_id) + + # 获取缓存中的所有 jargon + cached_jargons = jargon_miner.get_cached_jargons() + if not cached_jargons: + return [] + + matched_entries: List[Tuple[str, str]] = [] + + # 遍历 messages,检查缓存中的 jargon 是否出现 + for i, msg in enumerate(messages): + # 跳过机器人自己的消息 + if is_bot_message(msg): + continue + + # 获取消息文本 + msg_text = (getattr(msg, "processed_plain_text", None) or "").strip() + + if not msg_text: + continue + + # 检查每个缓存中的 jargon 是否出现在消息文本中 + for jargon in cached_jargons: + if not jargon or not jargon.strip(): + continue + + jargon_content = jargon.strip() + + # 使用正则匹配,考虑单词边界(类似 jargon_explainer 中的逻辑) + pattern = re.escape(jargon_content) + # 对于中文,使用更宽松的匹配;对于英文/数字,使用单词边界 + if re.search(r"[\u4e00-\u9fff]", jargon_content): + # 包含中文,使用更宽松的匹配 + search_pattern = pattern + else: + # 纯英文/数字,使用单词边界 + search_pattern = r"\b" + pattern + r"\b" + + if re.search(search_pattern, msg_text, re.IGNORECASE): + # 找到匹配,构建条目(source_id 从 1 开始,因为 build_anonymous_messages 的编号从 1 开始) + source_id = str(i + 1) + matched_entries.append((jargon_content, source_id)) + + return 
matched_entries + + async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None: + """ + 处理从 expression learner 提取的黑话条目,路由到 jargon_miner + + Args: + jargon_entries: 黑话条目列表,每个元素是 (content, source_id) + messages: 消息列表,用于构建上下文 + """ + if not jargon_entries or not messages: + return + + # 获取 jargon_miner 实例 + jargon_miner = miner_manager.get_miner(self.chat_id) + + # 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致 + entries: List[Dict[str, List[str]]] = [] + + for content, source_id in jargon_entries: + content = content.strip() + if not content: + continue + + # 过滤掉包含 SELF 的黑话,不学习 + if "SELF" in content: + logger.info(f"跳过包含 SELF 的黑话: {content}") + continue + + # 检查是否包含机器人名称 + if contains_bot_self_name(content): + logger.info(f"跳过包含机器人昵称/别名的黑话: {content}") + continue + + # 解析 source_id + source_id_str = (source_id or "").strip() + if not source_id_str.isdigit(): + logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}") + continue + + # build_anonymous_messages 的编号从 1 开始 + line_index = int(source_id_str) - 1 + if line_index < 0 or line_index >= len(messages): + logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}") + continue + + # 检查是否是机器人自己的消息 + target_msg = messages[line_index] + if is_bot_message(target_msg): + logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}") + continue + + # 构建上下文段落 + context_paragraph = build_context_paragraph(messages, line_index) + if not context_paragraph: + logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}") + continue + + entries.append({"content": content, "raw_content": [context_paragraph]}) + + if not entries: + return + + # 调用 jargon_miner 处理这些条目 + await jargon_miner.process_extracted_entries(entries) + + +class ExpressionLearnerManager: + def __init__(self): + self.expression_learners = {} + + self._ensure_expression_directories() + + def get_expression_learner(self, chat_id: str) -> 
ExpressionLearner: + if chat_id not in self.expression_learners: + self.expression_learners[chat_id] = ExpressionLearner(chat_id) + return self.expression_learners[chat_id] + + def _ensure_expression_directories(self): + """ + 确保表达方式相关的目录结构存在 + """ + base_dir = os.path.join("data", "expression") + directories_to_create = [ + base_dir, + os.path.join(base_dir, "learnt_style"), + os.path.join(base_dir, "learnt_grammar"), + ] + + for directory in directories_to_create: + try: + os.makedirs(directory, exist_ok=True) + logger.debug(f"确保目录存在: {directory}") + except Exception as e: + logger.error(f"创建目录失败 {directory}: {e}") + + +expression_learner_manager = ExpressionLearnerManager() diff --git a/src/bw_learner/expression_reflect_tracker.py b/src/bw_learner/expression_reflect_tracker.py index 319c8459..25658df7 100644 --- a/src/bw_learner/expression_reflect_tracker.py +++ b/src/bw_learner/expression_reflect_tracker.py @@ -22,6 +22,7 @@ judge_model = LLMRequest(model_set=model_config.model_task_config.tool_use, requ logger = get_logger("reflect_tracker") + class ReflectTracker: def __init__(self, session_id: str): self.session_id = session_id @@ -41,8 +42,8 @@ class ReflectTracker: self.expression = expression self.tracking = True self.tracking_start_time = time.time() - - def _reset_tracker(self): + + def reset_tracker(self): """重置追踪状态""" self.expression = None self.tracking = False @@ -66,122 +67,7 @@ class ReflectTracker: # 检查是否超时(无论是消息数量还是时间) if time.time() - self.tracking_start_time > self.max_duration: - self._reset_tracker() - return True - - # 获取消息列表 - msg_list = get_raw_msg_by_timestamp_with_chat( - chat_id=self.session_id, - timestamp_start=self.tracking_start_time, - timestamp_end=time.time(), - ) - - current_msg_count = len(msg_list) - - # 检查消息数量是否超限 - if current_msg_count > self.max_msg_count: - logger.info(f"ReflectTracker for expr {expr.item_id} timed out (message count).") - self._reset_tracker() + self.reset_tracker() return True - # 如果没有新消息,跳过本次检查 - if 
current_msg_count <= self.last_check_msg_count: - return False - - self.last_check_msg_count = current_msg_count - - # 构建上下文 - context_block = build_readable_messages( - msg_list, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - show_actions=False, - ) - - # LLM 判断 - try: - prompt_template = prompt_manager.get_prompt("reflect_judge") - prompt_template.add_context("situation", str(expr.situation)) - prompt_template.add_context("style", str(expr.style)) - prompt_template.add_context("context_block", context_block) - prompt = await prompt_manager.render_prompt(prompt_template) - - logger.info(f"ReflectTracker LLM Prompt: {prompt}") - - response, _ = await judge_model.generate_response_async(prompt, temperature=0.1) - - logger.info(f"ReflectTracker LLM Response: {response}") - - # 解析 JSON 响应 - json_pattern = r"```json\s*(.*?)\s*```" - matches = re.findall(json_pattern, response, re.DOTALL) - if not matches: - matches = [response] - - json_obj = json.loads(repair_json(matches[0])) - judgment = json_obj.get("judgment") - - if judgment == "Approve": - self._update_expression(checked=True, rejected=False, modified_by="ai") - logger.info(f"Expression {expr.item_id} approved by operator.") - self._reset_tracker() - return True - - elif judgment == "Reject": - corrected_situation = json_obj.get("corrected_situation") - corrected_style = json_obj.get("corrected_style") - has_update = bool(corrected_situation or corrected_style) - - update_kwargs: dict[str, Any] = {"checked": True, "modified_by": "ai"} - if corrected_situation: - update_kwargs["situation"] = corrected_situation - if corrected_style: - update_kwargs["style"] = corrected_style - if not has_update: - update_kwargs["rejected"] = True - else: - update_kwargs["rejected"] = False - - self._update_expression(**update_kwargs) - - if has_update: - logger.info( - f"Expression {expr.item_id} rejected and updated. 
" - f"New situation: {corrected_situation}, New style: {corrected_style}" - ) - else: - logger.info( - f"Expression {expr.item_id} rejected but no correction provided, marked as rejected." - ) - self._reset_tracker() - return True - - elif judgment == "Ignore": - logger.info(f"ReflectTracker for expr {expr.item_id} judged as Ignore.") - return False - - except Exception as e: - logger.error(f"Error in ReflectTracker check: {e}") - return False - - return False - - def _update_expression(self, **kwargs: Any) -> None: - """更新表达并持久化到数据库""" - if not self.expression: - return - - # 更新内存中的表达对象 - for key, value in kwargs.items(): - if hasattr(self.expression, key): - setattr(self.expression, key, value) - - # 持久化到数据库 - try: - with get_db_session() as session: - db_expr = self.expression.to_db_instance() - session.merge(db_expr) - session.commit() - except Exception as e: - logger.error(f"Failed to persist expression update: {e}") \ No newline at end of file + # TODO: 完成追踪检查逻辑 diff --git a/src/bw_learner/expression_reflector.py b/src/bw_learner/expression_reflector.py index 9c45f22c..bc036077 100644 --- a/src/bw_learner/expression_reflector.py +++ b/src/bw_learner/expression_reflector.py @@ -67,6 +67,10 @@ class ExpressionReflector: logger.debug(f"{LOG_PREFIX} Operator ID 未配置,跳过") return False + if self.reflect_tracker.tracking: + logger.info(f"{LOG_PREFIX} Operator {operator_config} 已有活跃的 Tracker,跳过本次提问") + return False + if allow_reflect_list := global_config.expression.allow_reflect: # 转换配置项为session_id列表 allow_reflect_session_ids = [ @@ -88,9 +92,6 @@ class ExpressionReflector: ) return False - if self.reflect_tracker.tracking: - logger.info(f"{LOG_PREFIX} Operator {operator_config} 已有活跃的 Tracker,跳过本次提问") - return False return True async def ask_reflection(self, operator_config: "TargetItem") -> bool: diff --git a/src/bw_learner/expression_reflector_old.py b/src/bw_learner/expression_reflector_old.py new file mode 100644 index 00000000..d1902f55 --- /dev/null +++ 
b/src/bw_learner/expression_reflector_old.py @@ -0,0 +1,250 @@ +import random +import time +from typing import Optional, Dict + +from src.common.logger import get_logger +from src.common.database.database_model import Expression +from src.config.config import global_config +from src.chat.message_receive.chat_stream import get_chat_manager +from src.plugin_system.apis import send_api + +logger = get_logger("expression_reflector") + + +class ExpressionReflector: + """表达反思器,管理单个聊天流的表达反思提问""" + + def __init__(self, chat_id: str): + self.chat_id = chat_id + self.last_ask_time: float = 0.0 + + async def check_and_ask(self) -> bool: + """ + 检查是否需要提问表达反思,如果需要则提问 + + Returns: + bool: 是否执行了提问 + """ + try: + logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})") + + if not global_config.expression.expression_manual_reflect: + logger.debug("[Expression Reflection] 表达反思功能未启用,跳过") + return False + + operator_config = global_config.expression.manual_reflect_operator_id + if not operator_config: + logger.debug("[Expression Reflection] Operator ID 未配置,跳过") + return False + + # 检查是否在允许列表中 + allow_reflect = global_config.expression.allow_reflect + if allow_reflect: + # 将 allow_reflect 中的 platform:id:type 格式转换为 chat_id 列表 + allow_reflect_chat_ids = [] + for stream_config in allow_reflect: + parsed_chat_id = global_config.expression._parse_stream_config_to_chat_id(stream_config) + if parsed_chat_id: + allow_reflect_chat_ids.append(parsed_chat_id) + else: + logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}") + + if self.chat_id not in allow_reflect_chat_ids: + logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过") + return False + + # 检查上一次提问时间 + current_time = time.time() + time_since_last_ask = current_time - self.last_ask_time + + # 5-10分钟间隔,随机选择 + min_interval = 10 * 60 # 5分钟 + max_interval = 15 * 60 # 10分钟 + interval = random.uniform(min_interval, max_interval) + + logger.info( + f"[Expression Reflection] 
上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)" + ) + + if time_since_last_ask < interval: + remaining_time = interval - time_since_last_ask + logger.info( + f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过" + ) + return False + + # 检查是否已经有针对该 Operator 的 Tracker 在运行 + logger.info(f"[Expression Reflection] 检查 Operator {operator_config} 是否已有活跃的 Tracker") + if await _check_tracker_exists(operator_config): + logger.info(f"[Expression Reflection] Operator {operator_config} 已有活跃的 Tracker,跳过本次提问") + return False + + # 获取未检查的表达 + try: + logger.info("[Expression Reflection] 查询未检查且未拒绝的表达") + expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50) + + expr_list = list(expressions) + logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") + + if not expr_list: + logger.info("[Expression Reflection] 没有可用的表达,跳过") + return False + + target_expr: Expression = random.choice(expr_list) + logger.info( + f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}" + ) + + # 生成询问文本 + ask_text = _generate_ask_text(target_expr) + if not ask_text: + logger.warning("[Expression Reflection] 生成询问文本失败,跳过") + return False + + logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问") + # 发送给 Operator + await _send_to_operator(operator_config, ask_text, target_expr) + + # 更新上一次提问时间 + self.last_ask_time = current_time + logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}") + + return True + + except Exception as e: + logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") + import traceback + + logger.error(traceback.format_exc()) + return False + except Exception as e: + logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") + import traceback + + 
logger.error(traceback.format_exc()) + return False + + +class ExpressionReflectorManager: + """表达反思管理器,管理多个聊天流的表达反思实例""" + + def __init__(self): + self.reflectors: Dict[str, ExpressionReflector] = {} + + def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector: + """获取或创建指定聊天流的表达反思实例""" + if chat_id not in self.reflectors: + self.reflectors[chat_id] = ExpressionReflector(chat_id) + return self.reflectors[chat_id] + + +# 创建全局实例 +expression_reflector_manager = ExpressionReflectorManager() + + +async def _check_tracker_exists(operator_config: str) -> bool: + """检查指定 Operator 是否已有活跃的 Tracker""" + from src.bw_learner.reflect_tracker import reflect_tracker_manager + + chat_manager = get_chat_manager() + chat_stream = None + + # 尝试解析配置字符串 "platform:id:type" + parts = operator_config.split(":") + if len(parts) == 3: + platform = parts[0] + id_str = parts[1] + stream_type = parts[2] + + user_info = None + group_info = None + + from maim_message import UserInfo, GroupInfo + + if stream_type == "group": + group_info = GroupInfo(group_id=id_str, platform=platform) + user_info = UserInfo(user_id="system", user_nickname="System", platform=platform) + elif stream_type == "private": + user_info = UserInfo(user_id=id_str, platform=platform, user_nickname="Operator") + else: + return False + + if user_info: + try: + chat_stream = await chat_manager.get_or_create_stream(platform, user_info, group_info) + except Exception as e: + logger.error(f"Failed to get or create chat stream for checking tracker: {e}") + return False + else: + chat_stream = chat_manager.get_stream(operator_config) + + if not chat_stream: + return False + + return reflect_tracker_manager.get_tracker(chat_stream.stream_id) is not None + + +def _generate_ask_text(expr: Expression) -> Optional[str]: + try: + ask_text = ( + f"我正在学习新的表达方式,请帮我看看这个是否合适?\n\n" + f"**学习到的表达信息**\n" + f"- 情景 (Situation): {expr.situation}\n" + f"- 风格 (Style): {expr.style}\n" + ) + return ask_text + except Exception as e: + 
logger.error(f"Failed to generate ask text: {e}") + return None + + +async def _send_to_operator(operator_config: str, text: str, expr: Expression): + chat_manager = get_chat_manager() + chat_stream = None + + # 尝试解析配置字符串 "platform:id:type" + parts = operator_config.split(":") + if len(parts) == 3: + platform = parts[0] + id_str = parts[1] + stream_type = parts[2] + + user_info = None + group_info = None + + from maim_message import UserInfo, GroupInfo + + if stream_type == "group": + group_info = GroupInfo(group_id=id_str, platform=platform) + user_info = UserInfo(user_id="system", user_nickname="System", platform=platform) + elif stream_type == "private": + user_info = UserInfo(user_id=id_str, platform=platform, user_nickname="Operator") + else: + logger.warning(f"Unknown stream type in operator config: {stream_type}") + return + + if user_info: + try: + chat_stream = await chat_manager.get_or_create_stream(platform, user_info, group_info) + except Exception as e: + logger.error(f"Failed to get or create chat stream for operator {operator_config}: {e}") + return + else: + chat_stream = chat_manager.get_stream(operator_config) + + if not chat_stream: + logger.warning(f"Could not find or create chat stream for operator: {operator_config}") + return + + stream_id = chat_stream.stream_id + + # 注册 Tracker + from src.bw_learner.reflect_tracker import ReflectTracker, reflect_tracker_manager + + tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time()) + reflect_tracker_manager.add_tracker(stream_id, tracker) + + # 发送消息 + await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True) + logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}") diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 78c0948d..6944e1f0 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -9,8 +9,8 @@ from src.config.config 
import global_config, model_config from src.common.logger import get_logger from src.common.database.database_model import Expression from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.learner_utils import weighted_sample -from src.common.utils.utils_session import SessionUtils +from src.bw_learner.learner_utils_old import weighted_sample +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.common_utils import TempMethodsExpression logger = get_logger("expression_selector") diff --git a/src/bw_learner/expression_utils.py b/src/bw_learner/expression_utils.py new file mode 100644 index 00000000..88237e57 --- /dev/null +++ b/src/bw_learner/expression_utils.py @@ -0,0 +1,212 @@ +from json_repair import repair_json +from typing import Tuple, Optional, List + +import json +import re + +from src.config.config import model_config +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.prompt.prompt_manager import prompt_manager +from src.common.logger import get_logger + +logger = get_logger("expression_utils") + +# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 +judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check") + + +async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]: + """ + 执行单次LLM评估 + + Args: + situation: 情境 + style: 风格 + + Returns: + (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息 + """ + # 构建评估提示词 + # 基础评估标准 + base_criteria = [ + "表达方式或言语风格是否与使用条件或使用情景匹配", + "允许部分语法错误或口头化或缺省出现", + "表达方式不能太过特指,需要具有泛用性", + "一般不涉及具体的人名或名称", + ] + + if custom_criteria := global_config.expression.expression_auto_check_custom_criteria: + base_criteria.extend(custom_criteria) + + # 构建评估标准列表字符串 + criteria_list = "\n".join([f"{i + 1}. 
{criterion}" for i, criterion in enumerate(base_criteria)]) + + prompt_template = prompt_manager.get_prompt("expression_evaluation") + prompt_template.add_context("situation", situation) + prompt_template.add_context("style", style) + prompt_template.add_context("criteria_list", criteria_list) + + prompt = await prompt_manager.render_prompt(prompt_template) + + logger.info(f"正在评估表达方式: situation={situation}, style={style}") + + response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024) + + logger.debug(f"评估结果: {response}") + + try: + evaluation = json.loads(response) + except json.JSONDecodeError: + try: + response_repaired = repair_json(response) + evaluation = json.loads(response_repaired) + except Exception as e: + raise ValueError(f"无法解析LLM响应为JSON: {response}") from e + except Exception as e: + return False, f"评估表达方式时发生错误: {e}", str(e) + try: + suitable = evaluation.get("suitable", False) + reason = evaluation.get("reason", "未提供理由") + logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") + return suitable, reason, None + except Exception as e: + return False, f"评估结果格式错误: {e}", str(e) + + +def fix_chinese_quotes_in_json(text): + """使用状态机修复 JSON 字符串值中的中文引号""" + result = [] + i = 0 + in_string = False + escape_next = False + + while i < len(text): + char = text[i] + if escape_next: + # 当前字符是转义字符后的字符,直接添加 + result.append(char) + escape_next = False + i += 1 + continue + if char == "\\": + # 转义字符 + result.append(char) + escape_next = True + i += 1 + continue + if char == '"' and not escape_next: + # 遇到英文引号,切换字符串状态 + in_string = not in_string + result.append(char) + i += 1 + continue + if in_string and char in ["“", "”"]: + result.append('\\"') + else: + result.append(char) + i += 1 + + return "".join(result) + + +def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: + """ + 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 + + 期望的 JSON 结构: + [ + {"situation": "AAAAA", "style": "BBBBB", 
"source_id": "3"}, // 表达方式 + {"content": "词条", "source_id": "12"}, // 黑话 + ... + ] + + Returns: + Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: + 第一个列表是表达方式 (situation, style, source_id) + 第二个列表是黑话 (content, source_id) + """ + if not response: + return [], [] + + raw = response.strip() + + if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL): + raw = match[1].strip() + else: + # 去掉可能存在的通用 ``` 包裹 + raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) + raw = raw.strip() + + parsed = _try_parse(raw) + if parsed is None: + fixed = fix_chinese_quotes_in_json(raw) + parsed = _try_parse(fixed) + if parsed is None: + logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") + return [], [] + + if isinstance(parsed, dict): + parsed_list = [parsed] + elif isinstance(parsed, list): + parsed_list = parsed + else: + logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") + return [], [] + + expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) + jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) + + for item in parsed_list: + if not isinstance(item, dict): + continue + + # 检查是否是表达方式条目(有 situation 和 style) + situation = str(item.get("situation", "")).strip() + style = str(item.get("style", "")).strip() + source_id = str(item.get("source_id", "")).strip() + + if situation and style and source_id: + # 表达方式条目 + expressions.append((situation, style, source_id)) + continue + content = str(item.get("content", "")).strip() + if content and source_id: + jargon_entries.append((content, source_id)) + + return expressions, jargon_entries + + +def is_single_char_jargon(content: str) -> bool: + """ + 判断是否是单字黑话(单个汉字、英文或数字) + + Args: + content: 词条内容 + + Returns: + bool: 如果是单字黑话返回True,否则返回False + """ + if not content or len(content) != 1: + return False + + char = content[0] + # 判断是否是单个汉字、单个英文字母或单个数字 + return ( + "\u4e00" <= char <= "\u9fff" # 汉字 + or "a" <= char <= "z" # 小写字母 + 
or "A" <= char <= "Z" # 大写字母 + or "0" <= char <= "9" # 数字 + ) + + +def _try_parse(text): + try: + return json.loads(text) + except Exception: + try: + repaired = repair_json(text) + return json.loads(repaired) + except Exception: + return None diff --git a/src/bw_learner/jargon_explainer.py b/src/bw_learner/jargon_explainer_old.py similarity index 99% rename from src/bw_learner/jargon_explainer.py rename to src/bw_learner/jargon_explainer_old.py index a99ecd84..4d144b2c 100644 --- a/src/bw_learner/jargon_explainer.py +++ b/src/bw_learner/jargon_explainer_old.py @@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.jargon_miner import search_jargon -from src.bw_learner.learner_utils import ( +from src.bw_learner.jargon_miner_old import search_jargon +from src.bw_learner.learner_utils_old import ( is_bot_message, contains_bot_self_name, parse_chat_id_list, diff --git a/src/bw_learner/jargon_miner.py b/src/bw_learner/jargon_miner.py index e3f86bd1..9f4270ad 100644 --- a/src/bw_learner/jargon_miner.py +++ b/src/bw_learner/jargon_miner.py @@ -1,595 +1,401 @@ -import json -import asyncio -import random from collections import OrderedDict -from typing import List, Dict, Optional, Callable from json_repair import repair_json -from sqlalchemy import func as fn +from sqlmodel import select +from typing import List, Optional, Dict, Callable, TypedDict, Set + +import asyncio +import json +import random from src.common.logger import get_logger +from src.common.database.database import get_db_session from src.common.database.database_model import Jargon -from src.llm_models.utils_model import LLMRequest +from src.common.data_models.jargon_data_model import MaiJargon from src.config.config import model_config, global_config -from src.chat.message_receive.chat_manager import 
chat_manager as _chat_manager +from src.llm_models.utils_model import LLMRequest from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.learner_utils import ( - parse_chat_id_list, - chat_id_list_contains, - update_chat_id_list, -) +from .expression_utils import is_single_char_jargon logger = get_logger("jargon") - -def _is_single_char_jargon(content: str) -> bool: - """ - 判断是否是单字黑话(单个汉字、英文或数字) - - Args: - content: 词条内容 - - Returns: - bool: 如果是单字黑话返回True,否则返回False - """ - if not content or len(content) != 1: - return False - - char = content[0] - # 判断是否是单个汉字、单个英文字母或单个数字 - return ( - "\u4e00" <= char <= "\u9fff" # 汉字 - or "a" <= char <= "z" # 小写字母 - or "A" <= char <= "Z" # 大写字母 - or "0" <= char <= "9" # 数字 - ) +# TODO: 重构完LLM相关内容后,替换成新的模型调用方式 +llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract") +llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference") -def _should_infer_meaning(jargon_obj: Jargon) -> bool: - """ - 判断是否需要进行含义推断 - 在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断 - 并且count必须大于last_inference_count,避免重启后重复判定 - 如果is_complete为True,不再进行推断 - """ - # 如果已完成所有推断,不再推断 - if jargon_obj.is_complete: - return False +class JargonEntry(TypedDict): + content: str + raw_content: Set[str] - count = jargon_obj.count or 0 - last_inference = jargon_obj.last_inference_count or 0 - # 阈值列表:3,6, 10, 20, 40, 60, 100 - thresholds = [2, 4, 8, 12, 24, 60, 100] - - if count < thresholds[0]: - return False - - # 如果count没有超过上次判定值,不需要判定 - if count <= last_inference: - return False - - # 找到第一个大于last_inference的阈值 - next_threshold = None - for threshold in thresholds: - if threshold > last_inference: - next_threshold = threshold - break - - # 如果没有找到下一个阈值,说明已经超过100,不应该再推断 - if next_threshold is None: - return False - - # 检查count是否达到或超过这个阈值 - return count >= next_threshold +class JargonMeaningEntry(TypedDict): + content: str + meaning: str class JargonMiner: - def 
__init__(self, chat_id: str) -> None: - self.chat_id = chat_id + def __init__(self, session_id: str, session_name: str) -> None: + self.session_id = session_id + self.session_name = session_name - self.llm = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="jargon.extract", - ) - - self.llm_inference = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="jargon.inference", - ) - - # 初始化stream_name作为类属性,避免重复提取 - chat_manager = _chat_manager - stream_name = chat_manager.get_session_name(self.chat_id) - self.stream_name = stream_name or self.chat_id + # Cache 相关 self.cache_limit = 50 self.cache: OrderedDict[str, None] = OrderedDict() - # 黑话提取锁,防止并发执行 self._extraction_lock = asyncio.Lock() - def _add_to_cache(self, content: str) -> None: - """将提取到的黑话加入缓存,保持LRU语义""" - if not content: - return - - key = content.strip() - if not key: - return - - # 单字黑话(单个汉字、英文或数字)不记录到缓存 - if _is_single_char_jargon(key): - return - - if key in self.cache: - self.cache.move_to_end(key) - else: - self.cache[key] = None - if len(self.cache) > self.cache_limit: - self.cache.popitem(last=False) - def get_cached_jargons(self) -> List[str]: """获取缓存中的所有黑话列表""" return list(self.cache.keys()) - async def _infer_meaning_by_id(self, jargon_id: int) -> None: - """通过ID加载对象并推断""" - try: - jargon_obj = Jargon.get_by_id(jargon_id) - # 再次检查is_complete,因为可能在异步任务执行时已被标记为完成 - if jargon_obj.is_complete: - logger.debug(f"jargon {jargon_obj.content} 已完成所有推断,跳过") - return - await self.infer_meaning(jargon_obj) - except Exception as e: - logger.error(f"通过ID推断jargon失败: {e}") - - async def infer_meaning(self, jargon_obj: Jargon) -> None: + async def infer_meaning(self, jargon_obj: MaiJargon) -> None: """ 对jargon进行含义推断 """ - try: - content = jargon_obj.content - raw_content_str = jargon_obj.raw_content or "" - - # 解析raw_content列表 - raw_content_list = [] - if raw_content_str: - try: - raw_content_list = ( - json.loads(raw_content_str) if isinstance(raw_content_str, 
str) else raw_content_str - ) - if not isinstance(raw_content_list, list): - raw_content_list = [raw_content_list] if raw_content_list else [] - except (json.JSONDecodeError, TypeError): - raw_content_list = [raw_content_str] if raw_content_str else [] - - if not raw_content_list: - logger.warning(f"jargon {content} 没有raw_content,跳过推断") - return - - # 获取当前count和上一次的meaning - current_count = jargon_obj.count or 0 - previous_meaning = jargon_obj.meaning or "" - - # 当count为24, 60时,随机移除一半的raw_content项目 - if current_count in [24, 60] and len(raw_content_list) > 1: - # 计算要保留的数量(至少保留1个) - keep_count = max(1, len(raw_content_list) // 2) - raw_content_list = random.sample(raw_content_list, keep_count) - logger.info( - f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目" - ) - - # 步骤1: 基于raw_content和content推断 - raw_content_text = "\n".join(raw_content_list) - - # 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考 - previous_meaning_section = "" - previous_meaning_instruction = "" - if current_count in [24, 60, 100] and previous_meaning: - previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}" - previous_meaning_instruction = ( - "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" - ) - - prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context") - prompt1_template.add_context("bot_name", global_config.bot.nickname) - prompt1_template.add_context("content", str(content)) - prompt1_template.add_context("raw_content_list", raw_content_text) - prompt1_template.add_context("previous_meaning_section", previous_meaning_section) - prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction) - prompt1 = await prompt_manager.render_prompt(prompt1_template) - - response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3) - if not response1: - logger.warning(f"jargon {content} 推断1失败:无响应") - return - - # 解析推断1结果 - inference1 = None + content = jargon_obj.content + # 
解析raw_content列表 + raw_content_list = [] + if raw_content_str := jargon_obj.raw_content: try: - resp1 = response1.strip() - if resp1.startswith("{") and resp1.endswith("}"): - inference1 = json.loads(resp1) - else: - repaired = repair_json(resp1) - inference1 = json.loads(repaired) if isinstance(repaired, str) else repaired - if not isinstance(inference1, dict): - logger.warning(f"jargon {content} 推断1结果格式错误") - return - except Exception as e: - logger.error(f"jargon {content} 推断1解析失败: {e}") - return + raw_content_list = json.loads(raw_content_str) + if not isinstance(raw_content_list, list): + raw_content_list = [raw_content_list] if raw_content_list else [] + except (json.JSONDecodeError, TypeError): + raw_content_list = [raw_content_str] if raw_content_str else [] - # 检查推断1是否表示信息不足无法推断 - no_info = inference1.get("no_info", False) - meaning1 = inference1.get("meaning", "").strip() - if no_info or not meaning1: - logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新") - # 更新最后一次判定的count值,避免在同一阈值重复尝试 - jargon_obj.last_inference_count = jargon_obj.count or 0 - jargon_obj.save() - return + if not raw_content_list: + logger.warning(f"jargon {content} 没有raw_content,跳过推断") + return - # 步骤2: 仅基于content推断 - prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only") - prompt2_template.add_context("content", str(content)) - prompt2 = await prompt_manager.render_prompt(prompt2_template) + # 获取当前count和上一次的meaning + current_count = jargon_obj.count + previous_meaning = jargon_obj.meaning - response2, _ = await self.llm_inference.generate_response_async(prompt2, temperature=0.3) - if not response2: - logger.warning(f"jargon {content} 推断2失败:无响应") - return + # 步骤1: 基于raw_content和content推断 + raw_content_text = "\n".join(raw_content_list) - # 解析推断2结果 - inference2 = None - try: - resp2 = response2.strip() - if resp2.startswith("{") and resp2.endswith("}"): - inference2 = json.loads(resp2) - else: - repaired = repair_json(resp2) - inference2 = 
json.loads(repaired) if isinstance(repaired, str) else repaired - if not isinstance(inference2, dict): - logger.warning(f"jargon {content} 推断2结果格式错误") - return - except Exception as e: - logger.error(f"jargon {content} 推断2解析失败: {e}") - return - - # logger.info(f"jargon {content} 推断2提示词: {prompt2}") - # logger.info(f"jargon {content} 推断2结果: {response2}") - # logger.info(f"jargon {content} 推断1提示词: {prompt1}") - # logger.info(f"jargon {content} 推断1结果: {response1}") - - if global_config.debug.show_jargon_prompt: - logger.info(f"jargon {content} 推断2提示词: {prompt2}") - logger.info(f"jargon {content} 推断2结果: {response2}") - logger.info(f"jargon {content} 推断1提示词: {prompt1}") - logger.info(f"jargon {content} 推断1结果: {response1}") - else: - logger.debug(f"jargon {content} 推断2提示词: {prompt2}") - logger.debug(f"jargon {content} 推断2结果: {response2}") - logger.debug(f"jargon {content} 推断1提示词: {prompt1}") - logger.debug(f"jargon {content} 推断1结果: {response1}") - - # 步骤3: 比较两个推断结果 - prompt3_template = prompt_manager.get_prompt("jargon_compare_inference") - prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False)) - prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False)) - prompt3 = await prompt_manager.render_prompt(prompt3_template) - - if global_config.debug.show_jargon_prompt: - logger.info(f"jargon {content} 比较提示词: {prompt3}") - - response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3) - if not response3: - logger.warning(f"jargon {content} 比较失败:无响应") - return - - # 解析比较结果 - comparison = None - try: - resp3 = response3.strip() - if resp3.startswith("{") and resp3.endswith("}"): - comparison = json.loads(resp3) - else: - repaired = repair_json(resp3) - comparison = json.loads(repaired) if isinstance(repaired, str) else repaired - if not isinstance(comparison, dict): - logger.warning(f"jargon {content} 比较结果格式错误") - return - except Exception as e: - logger.error(f"jargon {content} 比较解析失败: 
{e}") - return - - # 判断是否为黑话 - is_similar = comparison.get("is_similar", False) - is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话 - - # 更新数据库记录 - jargon_obj.is_jargon = is_jargon - if is_jargon: - # 是黑话,使用推断1的结果(基于上下文,更准确) - jargon_obj.meaning = inference1.get("meaning", "") - else: - # 不是黑话,清空含义,不再存储任何内容 - jargon_obj.meaning = "" - - # 更新最后一次判定的count值,避免重启后重复判定 - jargon_obj.last_inference_count = jargon_obj.count or 0 - - # 如果count>=100,标记为完成,不再进行推断 - if (jargon_obj.count or 0) >= 100: - jargon_obj.is_complete = True - - jargon_obj.save() - logger.debug( - f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}" + # 当count为24, 60时,随机移除一半的raw_content项目 + if current_count in [24, 60] and len(raw_content_list) > 1: + # 计算要保留的数量(至少保留1个) + keep_count = max(1, len(raw_content_list) // 2) + raw_content_list = random.sample(raw_content_list, keep_count) + logger.info( + f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目" ) - # 固定输出推断结果,格式化为可读形式 - if is_jargon: - # 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx - meaning = jargon_obj.meaning or "无详细说明" - is_global = jargon_obj.is_global - if is_global: - logger.info(f"[黑话]{content}的含义是 {meaning}") - else: - logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}") - else: - # 不是黑话,输出格式:[聊天名]xxx 不是黑话 - logger.info(f"[{self.stream_name}]{content} 不是黑话") + # 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考 + previous_meaning_section = "" + previous_meaning_instruction = "" + if current_count in [24, 60, 100] and previous_meaning: + previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}" + previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" + prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context") + prompt1_template.add_context("bot_name", global_config.bot.nickname) + prompt1_template.add_context("content", str(content)) + 
prompt1_template.add_context("raw_content_list", raw_content_text) + prompt1_template.add_context("previous_meaning_section", previous_meaning_section) + prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction) + prompt1 = await prompt_manager.render_prompt(prompt1_template) + + llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3) + if not llm_response_1: + logger.warning(f"jargon {content} 推断1失败:无响应") + return + + # 解析推断1结果 + inference1 = self._parse_result(llm_response_1) + if not inference1: + logger.warning(f"jargon {content} 推断1解析失败") + return + + no_info = inference1.get("no_info", False) + meaning1: str = inference1.get("meaning", "").strip() + if no_info or not meaning1: + logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新") + # 更新最后一次判定的count值,避免在同一阈值重复尝试 + jargon_obj.last_inference_count = jargon_obj.count or 0 + + try: + self._modify_jargon_entry(jargon_obj) + except Exception as e: + logger.error(f"jargon {content} 推断1更新last_inference_count失败: {e}") + return + + # 步骤2: 基于content-only进行推断 + prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only") + prompt2_template.add_context("content", content) + prompt2 = await prompt_manager.render_prompt(prompt2_template) + + llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3) + if not llm_response_2: + logger.warning(f"jargon {content} 推断2失败:无响应") + return + + # 解析推断2结果 + inference2 = self._parse_result(llm_response_2) + if not inference2: + logger.warning(f"jargon {content} 推断2解析失败") + return + + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon {content} 推断1提示词: {prompt1}") + logger.info(f"jargon {content} 推断2提示词: {prompt2}") + + # 步骤3: 比较两个推断结果 + prompt3_template = prompt_manager.get_prompt("jargon_compare_inference") + prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False)) + prompt3_template.add_context("inference2", 
json.dumps(inference2, ensure_ascii=False)) + prompt3 = await prompt_manager.render_prompt(prompt3_template) + + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon {content} 比较提示词: {prompt3}") + + llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3) + if not llm_response_3: + logger.warning(f"jargon {content} 比较失败:无响应") + return + + comparison_result = self._parse_result(llm_response_3) + if not comparison_result: + logger.warning(f"jargon {content} 比较解析失败") + return + + is_similar = comparison_result.get("is_similar", False) + is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话 + + # 更新数据库记录 + jargon_obj.is_jargon = is_jargon + jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else "" + # 更新最后一次判定的count值,避免重启后重复判定 + jargon_obj.last_inference_count = jargon_obj.count or 0 + + # 如果count>=100,标记为完成,不再进行推断 + if (jargon_obj.count or 0) >= 100: + jargon_obj.is_complete = True + + try: + self._modify_jargon_entry(jargon_obj) except Exception as e: - logger.error(f"jargon推断失败: {e}") - import traceback + logger.error(f"jargon {content} 推断结果更新失败: {e}") + logger.debug( + f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}" + ) - traceback.print_exc() + # 固定输出推断结果,格式化为可读形式 + if is_jargon: + # 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx + meaning = jargon_obj.meaning or "无详细说明" + is_global = jargon_obj.is_global # 是否为全局的 + if is_global: + logger.info(f"[黑话]{content}的含义是 {meaning}") + else: + logger.info(f"[{self.session_name}]{content}的含义是 {meaning}") + else: + # 不是黑话,输出格式:[聊天名]xxx 不是黑话 + logger.info(f"[{self.session_name}]{content} 不是黑话") async def process_extracted_entries( - self, entries: List[Dict[str, List[str]]], person_name_filter: Optional[Callable[[str], bool]] = None - ) -> None: + self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] + ): """ 处理已提取的黑话条目(从 
expression_learner 路由过来的) Args: - entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]} + entries: 黑话条目列表 person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称 """ if not entries: return + merged_entries: Dict[str, JargonEntry] = {} + for entry in entries: + content = entry["content"].strip() - try: - # 去重并合并raw_content(按 content 聚合) - merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict() - for entry in entries: - content_key = entry["content"] + if person_name_filter and person_name_filter(content): + logger.info(f"条目 '{content}' 包含人物名称,已过滤") + continue + raw_list = entry["raw_content"] or set() + if content in merged_entries: + merged_entries[content]["raw_content"].update(raw_list) + else: + merged_entries[content] = {"content": content, "raw_content": set(raw_list)} - # 检查是否包含人物名称 - # logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}") - # logger.info(f"person_name_filter: {person_name_filter}") - if person_name_filter and person_name_filter(content_key): - logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}") - continue + uniq_entries: List[JargonEntry] = list(merged_entries.values()) - raw_list = entry.get("raw_content", []) or [] - if content_key in merged_entries: - merged_entries[content_key]["raw_content"].extend(raw_list) + saved = 0 + updated = 0 + for entry in uniq_entries: + content = entry["content"] + raw_content_set = entry["raw_content"] + try: + with get_db_session() as session: + jargon_items = session.exec(select(Jargon).filter_by(content=content)).all() + except Exception as e: + logger.error(f"查询黑话 '{content}' 失败: {e}") + continue + # 找匹配项 + matched_jargon: Optional[Jargon] = None + for item in jargon_items: + if global_config.expression.all_global_jargon: + # 开启all_global:所有content匹配的记录都可以 + matched_jargon = item + break else: - merged_entries[content_key] = { - "content": content_key, - "raw_content": list(raw_list), - } - - uniq_entries = [] - for merged_entry in merged_entries.values(): - 
raw_content_list = merged_entry["raw_content"] - if raw_content_list: - merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list)) - uniq_entries.append(merged_entry) - - saved = 0 - updated = 0 - for entry in uniq_entries: - content = entry["content"] - raw_content_list = entry["raw_content"] # 已经是列表 - - try: - # 查询所有content匹配的记录 - query = Jargon.select().where(Jargon.content == content) - - # 查找匹配的记录 - matched_obj = None - for obj in query: - if global_config.expression.all_global_jargon: - # 开启all_global:所有content匹配的记录都可以 - matched_obj = obj - break - else: - # 关闭all_global:需要检查chat_id列表是否包含目标chat_id - chat_id_list = parse_chat_id_list(obj.chat_id) - if chat_id_list_contains(chat_id_list, self.chat_id): - matched_obj = obj - break - - if matched_obj: - obj = matched_obj + # 检查列表是否包含目标session_id + if item.session_id_dict: try: - obj.count = (obj.count or 0) + 1 - except Exception: - obj.count = 1 - - # 合并raw_content列表:读取现有列表,追加新值,去重 - existing_raw_content = [] - if obj.raw_content: - try: - existing_raw_content = ( - json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content - ) - if not isinstance(existing_raw_content, list): - existing_raw_content = [existing_raw_content] if existing_raw_content else [] - except (json.JSONDecodeError, TypeError): - existing_raw_content = [obj.raw_content] if obj.raw_content else [] - - # 合并并去重 - merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list)) - obj.raw_content = json.dumps(merged_list, ensure_ascii=False) - - # 更新chat_id列表:增加当前chat_id的计数 - chat_id_list = parse_chat_id_list(obj.chat_id) - updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1) - obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False) - - # 开启all_global时,确保记录标记为is_global=True - if global_config.expression.all_global_jargon: - obj.is_global = True - # 关闭all_global时,保持原有is_global不变(不修改) - - obj.save() - - # 检查是否需要推断(达到阈值且超过上次判定值) - if _should_infer_meaning(obj): - # 
异步触发推断,不阻塞主流程 - # 重新加载对象以确保数据最新 - jargon_id = obj.id - asyncio.create_task(self._infer_meaning_by_id(jargon_id)) - - updated += 1 - else: - # 没找到匹配记录,创建新记录 - if global_config.expression.all_global_jargon: - # 开启all_global:新记录默认为is_global=True - is_global_new = True - else: - # 关闭all_global:新记录is_global=False - is_global_new = False - - # 使用新格式创建chat_id列表:[[chat_id, count]] - chat_id_list = [[self.chat_id, 1]] - chat_id_json = json.dumps(chat_id_list, ensure_ascii=False) - - Jargon.create( - content=content, - raw_content=json.dumps(raw_content_list, ensure_ascii=False), - chat_id=chat_id_json, - is_global=is_global_new, - count=1, - ) - saved += 1 + session_id_dict = json.loads(item.session_id_dict) + if self.session_id in session_id_dict: + matched_jargon = item + break + except Exception as e: + logger.error(f"解析Jargon id={item.id} session_id_list失败: {e}") + continue + if matched_jargon: + # 已存在记录,更新count和raw_content + self._update_jargon(matched_jargon, raw_content_set) + if self._should_infer_meaning(matched_jargon): + asyncio.create_task(self._infer_meaning_by_id(matched_jargon.id)) # type: ignore + updated += 1 + else: + # 没找到匹配记录,创建新记录 + is_global_new = global_config.expression.all_global_jargon + session_dict_str = json.dumps({self.session_id: 1}) + new_jargon = Jargon( + content=content, + raw_content=json.dumps(list(raw_content_set), ensure_ascii=False), + session_id_dict=session_dict_str, + is_global=is_global_new, + count=1, + meaning="", + ) + try: + with get_db_session() as session: + session.add(new_jargon) except Exception as e: - logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}") + logger.error(f"保存新黑话 '{content}' 失败: {e}") continue finally: self._add_to_cache(content) + # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) + if uniq_entries: + # 收集所有提取的jargon内容 + jargon_list = [entry["content"] for entry in uniq_entries] + jargon_str = ",".join(jargon_list) + logger.info(f"[{self.session_name}]疑似黑话: {jargon_str}") - # 
固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) - if uniq_entries: - # 收集所有提取的jargon内容 - jargon_list = [entry["content"] for entry in uniq_entries] - jargon_str = ",".join(jargon_list) + if saved or updated: + logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,session_id={self.session_id}") - # 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色) - logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}") + def _add_to_cache(self, content: str): + """将黑话内容添加到缓存,并维护缓存大小""" + content = content.strip() + if is_single_char_jargon(content): + return + if content in self.cache: + # 已存在,移动到末尾表示最近使用 + self.cache.move_to_end(content) + else: + # 新内容,添加到缓存 + self.cache[content] = None + # 如果超过限制,移除最旧的项 + if len(self.cache) > self.cache_limit: + removed_content, _ = self.cache.popitem(last=False) + logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}") - if saved or updated: - logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}") + def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]): + db_jargon.count += 1 + existing_raw_content: List[str] = [] + if db_jargon.raw_content: + try: + existing_raw_content = json.loads(db_jargon.raw_content) + except Exception: + existing_raw_content = [] + + # 合并去重 + merged_list = list(set(existing_raw_content).union(raw_content_set)) + db_jargon.raw_content = json.dumps(merged_list, ensure_ascii=False) + session_id_dict: Dict[str, int] = json.loads(db_jargon.session_id_dict) + session_id_dict[self.session_id] = session_id_dict.get(self.session_id, 0) + 1 + db_jargon.session_id_dict = json.dumps(session_id_dict) + + # 开启all_global时,确保记录标记为is_global=True + if global_config.expression.all_global_jargon: + db_jargon.is_global = True + + try: + with get_db_session() as session: + session.add(db_jargon) except Exception as e: - logger.error(f"处理已提取的黑话条目失败: {e}") + logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}") + def _parse_result(self, response: str) -> Optional[Dict[str, str]]: + try: + result = json.loads(response.strip()) + except 
Exception:
+            try:
+                repaired = repair_json(response.strip())
+                result = json.loads(repaired)
+            except Exception as e2:
+                logger.error(f"推断结果解析失败: {e2}")
+                return None
+        if not isinstance(result, dict):
+            logger.warning("推断结果格式错误")
+            return None
+        return result
 
-class JargonMinerManager:
-    def __init__(self) -> None:
-        self._miners: dict[str, JargonMiner] = {}
+    def _modify_jargon_entry(self, jargon_obj: MaiJargon) -> None:
+        with get_db_session() as session:
+            if not jargon_obj.item_id:
+                raise ValueError("jargon_obj must have item_id to update")
+            statement = select(Jargon).filter_by(id=jargon_obj.item_id).limit(1)
+            if db_record := session.exec(statement).first():
+                db_record.is_jargon = jargon_obj.is_jargon
+                db_record.meaning = jargon_obj.meaning
+                db_record.last_inference_count = jargon_obj.last_inference_count
+                db_record.is_complete = jargon_obj.is_complete
+                session.add(db_record)
 
-    def get_miner(self, chat_id: str) -> JargonMiner:
-        if chat_id not in self._miners:
-            self._miners[chat_id] = JargonMiner(chat_id)
-        return self._miners[chat_id]
+    def _should_infer_meaning(self, jargon_obj: Jargon) -> bool:
+        """
+        判断是否需要进行含义推断
+        在 count 达到 2, 4, 8, 12, 24, 60, 100 时进行推断
+        并且count必须大于last_inference_count,避免重启后重复判定
+        如果is_complete为True,不再进行推断
+        """
+        # 如果已完成所有推断,不再推断
+        if jargon_obj.is_complete:
+            return False
 
+        count = jargon_obj.count or 0
+        last_inference = jargon_obj.last_inference_count or 0
 
-miner_manager = JargonMinerManager()
+        # 阈值列表:2, 4, 8, 12, 24, 60, 100
+        thresholds = [2, 4, 8, 12, 24, 60, 100]
+        if count < thresholds[0]:
+            return False
+        # 如果count没有超过上次判定值,不需要判定
+        if count <= last_inference:
+            return False
 
-def search_jargon(
-    keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
-) -> List[Dict[str, str]]:
-    """
-    搜索jargon,支持大小写不敏感和模糊搜索
+        next_threshold = next(
+            (threshold for threshold in thresholds if threshold > last_inference),
+            None,
+        )
+        # 如果没有找到下一个阈值,说明已经超过100,不应该再推断
+        return 
False if next_threshold is None else count >= next_threshold - Args: - keyword: 搜索关键词 - chat_id: 可选的聊天ID - - 如果开启了all_global:此参数被忽略,查询所有is_global=True的记录 - - 如果关闭了all_global:如果提供则优先搜索该聊天或global的jargon - limit: 返回结果数量限制,默认10 - case_sensitive: 是否大小写敏感,默认False(不敏感) - fuzzy: 是否模糊搜索,默认True(使用LIKE匹配) - - Returns: - List[Dict[str, str]]: 包含content, meaning的字典列表 - """ - if not keyword or not keyword.strip(): - return [] - - keyword = keyword.strip() - - # 构建查询(选择所有需要的字段,以便后续过滤) - query = Jargon.select() - - # 构建搜索条件 - if case_sensitive: - # 大小写敏感 - if fuzzy: - # 模糊搜索 - search_condition = Jargon.content.contains(keyword) - else: - # 精确匹配 - search_condition = Jargon.content == keyword - else: - # 大小写不敏感 - if fuzzy: - # 模糊搜索(使用LOWER函数) - search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) - else: - # 精确匹配(使用LOWER函数) - search_condition = fn.LOWER(Jargon.content) == keyword.lower() - - query = query.where(search_condition) - - # 根据all_global配置决定查询逻辑 - if global_config.expression.all_global_jargon: - # 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id) - query = query.where(Jargon.is_global) - # 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式 - - # 注意:meaning的过滤移到Python层面,因为我们需要先过滤chat_id - - # 按count降序排序,优先返回出现频率高的 - query = query.order_by(Jargon.count.desc()) - - # 限制结果数量(先多取一些,因为后面可能过滤) - query = query.limit(limit * 2) - - # 执行查询并返回结果,过滤chat_id - results = [] - for jargon in query: - # 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id - if chat_id and not global_config.expression.all_global_jargon: - chat_id_list = parse_chat_id_list(jargon.chat_id) - # 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含 - if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id): - continue - - # 只返回有meaning的记录 - if not jargon.meaning or jargon.meaning.strip() == "": - continue - - results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""}) - - # 达到限制数量后停止 - if len(results) >= limit: - break - - return results + async 
def _infer_meaning_by_id(self, jargon_id: int): + jargon_obj: Optional[MaiJargon] = None + try: + with get_db_session() as session: + statement = select(Jargon).filter_by(id=jargon_id).limit(1) + if db_record := session.exec(statement).first(): + jargon_obj = MaiJargon.from_db_instance(db_record) + except Exception as e: + logger.error(f"查询Jargon id={jargon_id}失败: {e}") + return + if jargon_obj: + await self.infer_meaning(jargon_obj) diff --git a/src/bw_learner/jargon_miner_old.py b/src/bw_learner/jargon_miner_old.py new file mode 100644 index 00000000..d6495291 --- /dev/null +++ b/src/bw_learner/jargon_miner_old.py @@ -0,0 +1,589 @@ +import json +import asyncio +import random +from collections import OrderedDict +from typing import List, Dict, Optional, Callable +from json_repair import repair_json +from sqlalchemy import func as fn + +from src.common.logger import get_logger +from src.common.database.database_model import Jargon +from src.llm_models.utils_model import LLMRequest +from src.config.config import model_config, global_config +from src.chat.message_receive.chat_stream import get_chat_manager +from src.prompt.prompt_manager import prompt_manager +from src.bw_learner.learner_utils_old import ( + parse_chat_id_list, + chat_id_list_contains, + update_chat_id_list, +) + + +logger = get_logger("jargon") + + +def _is_single_char_jargon(content: str) -> bool: + """ + 判断是否是单字黑话(单个汉字、英文或数字) + + Args: + content: 词条内容 + + Returns: + bool: 如果是单字黑话返回True,否则返回False + """ + if not content or len(content) != 1: + return False + + char = content[0] + # 判断是否是单个汉字、单个英文字母或单个数字 + return ( + "\u4e00" <= char <= "\u9fff" # 汉字 + or "a" <= char <= "z" # 小写字母 + or "A" <= char <= "Z" # 大写字母 + or "0" <= char <= "9" # 数字 + ) + + +def _should_infer_meaning(jargon_obj: Jargon) -> bool: + """ + 判断是否需要进行含义推断 + 在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断 + 并且count必须大于last_inference_count,避免重启后重复判定 + 如果is_complete为True,不再进行推断 + """ + # 如果已完成所有推断,不再推断 + if jargon_obj.is_complete: + return 
False + + count = jargon_obj.count or 0 + last_inference = jargon_obj.last_inference_count or 0 + + # 阈值列表:3,6, 10, 20, 40, 60, 100 + thresholds = [2, 4, 8, 12, 24, 60, 100] + + if count < thresholds[0]: + return False + + # 如果count没有超过上次判定值,不需要判定 + if count <= last_inference: + return False + + # 找到第一个大于last_inference的阈值 + next_threshold = None + for threshold in thresholds: + if threshold > last_inference: + next_threshold = threshold + break + + # 如果没有找到下一个阈值,说明已经超过100,不应该再推断 + if next_threshold is None: + return False + + # 检查count是否达到或超过这个阈值 + return count >= next_threshold + + +class JargonMiner: + def __init__(self, chat_id: str) -> None: + self.chat_id = chat_id + + self.llm = LLMRequest( + model_set=model_config.model_task_config.utils, + request_type="jargon.extract", + ) + + self.llm_inference = LLMRequest( + model_set=model_config.model_task_config.utils, + request_type="jargon.inference", + ) + + # 初始化stream_name作为类属性,避免重复提取 + chat_manager = get_chat_manager() + stream_name = chat_manager.get_stream_name(self.chat_id) + self.stream_name = stream_name if stream_name else self.chat_id + self.cache_limit = 50 + self.cache: OrderedDict[str, None] = OrderedDict() + + # 黑话提取锁,防止并发执行 + self._extraction_lock = asyncio.Lock() + + def _add_to_cache(self, content: str) -> None: + """将提取到的黑话加入缓存,保持LRU语义""" + if not content: + return + + key = content.strip() + if not key: + return + + # 单字黑话(单个汉字、英文或数字)不记录到缓存 + if _is_single_char_jargon(key): + return + + if key in self.cache: + self.cache.move_to_end(key) + else: + self.cache[key] = None + if len(self.cache) > self.cache_limit: + self.cache.popitem(last=False) + + def get_cached_jargons(self) -> List[str]: + """获取缓存中的所有黑话列表""" + return list(self.cache.keys()) + + async def _infer_meaning_by_id(self, jargon_id: int) -> None: + """通过ID加载对象并推断""" + try: + jargon_obj = Jargon.get_by_id(jargon_id) + # 再次检查is_complete,因为可能在异步任务执行时已被标记为完成 + if jargon_obj.is_complete: + logger.debug(f"jargon {jargon_obj.content} 
已完成所有推断,跳过") + return + await self.infer_meaning(jargon_obj) + except Exception as e: + logger.error(f"通过ID推断jargon失败: {e}") + + async def infer_meaning(self, jargon_obj: Jargon) -> None: + """ + 对jargon进行含义推断 + """ + try: + content = jargon_obj.content + raw_content_str = jargon_obj.raw_content or "" + + # 解析raw_content列表 + raw_content_list = [] + if raw_content_str: + try: + raw_content_list = ( + json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str + ) + if not isinstance(raw_content_list, list): + raw_content_list = [raw_content_list] if raw_content_list else [] + except (json.JSONDecodeError, TypeError): + raw_content_list = [raw_content_str] if raw_content_str else [] + + if not raw_content_list: + logger.warning(f"jargon {content} 没有raw_content,跳过推断") + return + + # 获取当前count和上一次的meaning + current_count = jargon_obj.count or 0 + previous_meaning = jargon_obj.meaning or "" + + # 当count为24, 60时,随机移除一半的raw_content项目 + if current_count in [24, 60] and len(raw_content_list) > 1: + # 计算要保留的数量(至少保留1个) + keep_count = max(1, len(raw_content_list) // 2) + raw_content_list = random.sample(raw_content_list, keep_count) + logger.info( + f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目" + ) + + # 步骤1: 基于raw_content和content推断 + raw_content_text = "\n".join(raw_content_list) + + # 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考 + previous_meaning_section = "" + previous_meaning_instruction = "" + if current_count in [24, 60, 100] and previous_meaning: + previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}" + previous_meaning_instruction = ( + "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" + ) + + prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context") + prompt1_template.add_context("bot_name", global_config.bot.nickname) + prompt1_template.add_context("content", str(content)) + prompt1_template.add_context("raw_content_list", raw_content_text) + 
prompt1_template.add_context("previous_meaning_section", previous_meaning_section) + prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction) + prompt1 = await prompt_manager.render_prompt(prompt1_template) + + response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3) + if not response1: + logger.warning(f"jargon {content} 推断1失败:无响应") + return + + # 解析推断1结果 + inference1 = None + try: + resp1 = response1.strip() + if resp1.startswith("{") and resp1.endswith("}"): + inference1 = json.loads(resp1) + else: + repaired = repair_json(resp1) + inference1 = json.loads(repaired) if isinstance(repaired, str) else repaired + if not isinstance(inference1, dict): + logger.warning(f"jargon {content} 推断1结果格式错误") + return + except Exception as e: + logger.error(f"jargon {content} 推断1解析失败: {e}") + return + + # 检查推断1是否表示信息不足无法推断 + no_info = inference1.get("no_info", False) + meaning1 = inference1.get("meaning", "").strip() + if no_info or not meaning1: + logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新") + # 更新最后一次判定的count值,避免在同一阈值重复尝试 + jargon_obj.last_inference_count = jargon_obj.count or 0 + jargon_obj.save() + return + + # 步骤2: 仅基于content推断 + prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only") + prompt2_template.add_context("content", str(content)) + prompt2 = await prompt_manager.render_prompt(prompt2_template) + + response2, _ = await self.llm_inference.generate_response_async(prompt2, temperature=0.3) + if not response2: + logger.warning(f"jargon {content} 推断2失败:无响应") + return + + # 解析推断2结果 + inference2 = None + try: + resp2 = response2.strip() + if resp2.startswith("{") and resp2.endswith("}"): + inference2 = json.loads(resp2) + else: + repaired = repair_json(resp2) + inference2 = json.loads(repaired) if isinstance(repaired, str) else repaired + if not isinstance(inference2, dict): + logger.warning(f"jargon {content} 推断2结果格式错误") + return + except Exception as e: + 
logger.error(f"jargon {content} 推断2解析失败: {e}") + return + + # logger.info(f"jargon {content} 推断2提示词: {prompt2}") + # logger.info(f"jargon {content} 推断2结果: {response2}") + # logger.info(f"jargon {content} 推断1提示词: {prompt1}") + # logger.info(f"jargon {content} 推断1结果: {response1}") + + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon {content} 推断2提示词: {prompt2}") + logger.info(f"jargon {content} 推断2结果: {response2}") + logger.info(f"jargon {content} 推断1提示词: {prompt1}") + logger.info(f"jargon {content} 推断1结果: {response1}") + else: + logger.debug(f"jargon {content} 推断2提示词: {prompt2}") + logger.debug(f"jargon {content} 推断2结果: {response2}") + logger.debug(f"jargon {content} 推断1提示词: {prompt1}") + logger.debug(f"jargon {content} 推断1结果: {response1}") + + # 步骤3: 比较两个推断结果 + prompt3_template = prompt_manager.get_prompt("jargon_compare_inference") + prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False)) + prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False)) + prompt3 = await prompt_manager.render_prompt(prompt3_template) + + if global_config.debug.show_jargon_prompt: + logger.info(f"jargon {content} 比较提示词: {prompt3}") + + response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3) + if not response3: + logger.warning(f"jargon {content} 比较失败:无响应") + return + + # 解析比较结果 + comparison = None + try: + resp3 = response3.strip() + if resp3.startswith("{") and resp3.endswith("}"): + comparison = json.loads(resp3) + else: + repaired = repair_json(resp3) + comparison = json.loads(repaired) if isinstance(repaired, str) else repaired + if not isinstance(comparison, dict): + logger.warning(f"jargon {content} 比较结果格式错误") + return + except Exception as e: + logger.error(f"jargon {content} 比较解析失败: {e}") + return + + # 判断是否为黑话 + is_similar = comparison.get("is_similar", False) + is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话 + + # 更新数据库记录 + jargon_obj.is_jargon = is_jargon + 
jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else "" + # 更新最后一次判定的count值,避免重启后重复判定 + jargon_obj.last_inference_count = jargon_obj.count or 0 + + # 如果count>=100,标记为完成,不再进行推断 + if (jargon_obj.count or 0) >= 100: + jargon_obj.is_complete = True + + jargon_obj.save() + logger.debug( + f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}" + ) + + # 固定输出推断结果,格式化为可读形式 + if is_jargon: + # 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx + meaning = jargon_obj.meaning or "无详细说明" + is_global = jargon_obj.is_global + if is_global: + logger.info(f"[黑话]{content}的含义是 {meaning}") + else: + logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}") + else: + # 不是黑话,输出格式:[聊天名]xxx 不是黑话 + logger.info(f"[{self.stream_name}]{content} 不是黑话") + + except Exception as e: + logger.error(f"jargon推断失败: {e}") + import traceback + + traceback.print_exc() + + async def process_extracted_entries( + self, entries: List[Dict[str, List[str]]], person_name_filter: Optional[Callable[[str], bool]] = None + ) -> None: + """ + 处理已提取的黑话条目(从 expression_learner 路由过来的) + + Args: + entries: 黑话条目列表,每个元素格式为 {"content": "...", "raw_content": [...]} + person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称 + """ + if not entries: + return + + try: + # 去重并合并raw_content(按 content 聚合) + merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict() + for entry in entries: + content_key = entry["content"] + + # 检查是否包含人物名称 + # logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}") + # logger.info(f"person_name_filter: {person_name_filter}") + if person_name_filter and person_name_filter(content_key): + logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}") + continue + + raw_list = entry.get("raw_content", []) or [] + if content_key in merged_entries: + merged_entries[content_key]["raw_content"].extend(raw_list) + else: + merged_entries[content_key] = { + "content": 
content_key, + "raw_content": list(raw_list), + } + + uniq_entries = [] + for merged_entry in merged_entries.values(): + raw_content_list = merged_entry["raw_content"] + if raw_content_list: + merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list)) + uniq_entries.append(merged_entry) + + saved = 0 + updated = 0 + for entry in uniq_entries: + content = entry["content"] + raw_content_list = entry["raw_content"] # 已经是列表 + + try: + # 查询所有content匹配的记录 + query = Jargon.select().where(Jargon.content == content) + + # 查找匹配的记录 + matched_obj = None + for obj in query: + if global_config.expression.all_global_jargon: + # 开启all_global:所有content匹配的记录都可以 + matched_obj = obj + break + else: + # 关闭all_global:需要检查chat_id列表是否包含目标chat_id + chat_id_list = parse_chat_id_list(obj.chat_id) + if chat_id_list_contains(chat_id_list, self.chat_id): + matched_obj = obj + break + + if matched_obj: + obj = matched_obj + try: + obj.count = (obj.count or 0) + 1 + except Exception: + obj.count = 1 + + # 合并raw_content列表:读取现有列表,追加新值,去重 + existing_raw_content = [] + if obj.raw_content: + try: + existing_raw_content = ( + json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content + ) + if not isinstance(existing_raw_content, list): + existing_raw_content = [existing_raw_content] if existing_raw_content else [] + except (json.JSONDecodeError, TypeError): + existing_raw_content = [obj.raw_content] if obj.raw_content else [] + + # 合并并去重 + merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list)) + obj.raw_content = json.dumps(merged_list, ensure_ascii=False) + + # 更新chat_id列表:增加当前chat_id的计数 + chat_id_list = parse_chat_id_list(obj.chat_id) + updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1) + obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False) + + # 开启all_global时,确保记录标记为is_global=True + if global_config.expression.all_global_jargon: + obj.is_global = True + # 关闭all_global时,保持原有is_global不变(不修改) + + 
obj.save() + + # 检查是否需要推断(达到阈值且超过上次判定值) + if _should_infer_meaning(obj): + # 异步触发推断,不阻塞主流程 + # 重新加载对象以确保数据最新 + jargon_id = obj.id + asyncio.create_task(self._infer_meaning_by_id(jargon_id)) + + updated += 1 + else: + # 没找到匹配记录,创建新记录 + if global_config.expression.all_global_jargon: + # 开启all_global:新记录默认为is_global=True + is_global_new = True + else: + # 关闭all_global:新记录is_global=False + is_global_new = False + + # 使用新格式创建chat_id列表:[[chat_id, count]] + chat_id_list = [[self.chat_id, 1]] + chat_id_json = json.dumps(chat_id_list, ensure_ascii=False) + + Jargon.create( + content=content, + raw_content=json.dumps(raw_content_list, ensure_ascii=False), + chat_id=chat_id_json, + is_global=is_global_new, + count=1, + ) + saved += 1 + except Exception as e: + logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}") + continue + finally: + self._add_to_cache(content) + + # 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出) + if uniq_entries: + # 收集所有提取的jargon内容 + jargon_list = [entry["content"] for entry in uniq_entries] + jargon_str = ",".join(jargon_list) + + # 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色) + logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}") + + if saved or updated: + logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}") + except Exception as e: + logger.error(f"处理已提取的黑话条目失败: {e}") + + +class JargonMinerManager: + def __init__(self) -> None: + self._miners: dict[str, JargonMiner] = {} + + def get_miner(self, chat_id: str) -> JargonMiner: + if chat_id not in self._miners: + self._miners[chat_id] = JargonMiner(chat_id) + return self._miners[chat_id] + + +miner_manager = JargonMinerManager() + + +def search_jargon( + keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True +) -> List[Dict[str, str]]: + """ + 搜索jargon,支持大小写不敏感和模糊搜索 + + Args: + keyword: 搜索关键词 + chat_id: 可选的聊天ID + - 如果开启了all_global:此参数被忽略,查询所有is_global=True的记录 + - 如果关闭了all_global:如果提供则优先搜索该聊天或global的jargon + 
limit: 返回结果数量限制,默认10 + case_sensitive: 是否大小写敏感,默认False(不敏感) + fuzzy: 是否模糊搜索,默认True(使用LIKE匹配) + + Returns: + List[Dict[str, str]]: 包含content, meaning的字典列表 + """ + if not keyword or not keyword.strip(): + return [] + + keyword = keyword.strip() + + # 构建查询(选择所有需要的字段,以便后续过滤) + query = Jargon.select() + + # 构建搜索条件 + if case_sensitive: + # 大小写敏感 + if fuzzy: + # 模糊搜索 + search_condition = Jargon.content.contains(keyword) + else: + # 精确匹配 + search_condition = Jargon.content == keyword + else: + # 大小写不敏感 + if fuzzy: + # 模糊搜索(使用LOWER函数) + search_condition = fn.LOWER(Jargon.content).contains(keyword.lower()) + else: + # 精确匹配(使用LOWER函数) + search_condition = fn.LOWER(Jargon.content) == keyword.lower() + + query = query.where(search_condition) + + # 根据all_global配置决定查询逻辑 + if global_config.expression.all_global_jargon: + # 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id) + query = query.where(Jargon.is_global) + # 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式 + + # 注意:meaning的过滤移到Python层面,因为我们需要先过滤chat_id + + # 按count降序排序,优先返回出现频率高的 + query = query.order_by(Jargon.count.desc()) + + # 限制结果数量(先多取一些,因为后面可能过滤) + query = query.limit(limit * 2) + + # 执行查询并返回结果,过滤chat_id + results = [] + for jargon in query: + # 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id + if chat_id and not global_config.expression.all_global_jargon: + chat_id_list = parse_chat_id_list(jargon.chat_id) + # 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含 + if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id): + continue + + # 只返回有meaning的记录 + if not jargon.meaning or jargon.meaning.strip() == "": + continue + + results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""}) + + # 达到限制数量后停止 + if len(results) >= limit: + break + + return results diff --git a/src/bw_learner/learner_utils.py b/src/bw_learner/learner_utils.py index ce3ea379..80849f02 100644 --- a/src/bw_learner/learner_utils.py +++ b/src/bw_learner/learner_utils.py @@ -1,355 +1,48 @@ +from 
json_repair import repair_json +from typing import List, Tuple + import re -import difflib -import random import json -from typing import Optional, List, Dict, Any, Tuple from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.utils.chat_message_builder import ( - build_readable_messages, -) -from src.chat.utils.utils import parse_platform_accounts -from json_repair import repair_json - logger = get_logger("learner_utils") -def filter_message_content(content: Optional[str]) -> str: - """ - 过滤消息内容,移除回复、@、图片等格式 +def fix_chinese_quotes_in_json(text): + """使用状态机修复 JSON 字符串值中的中文引号""" + result = [] + i = 0 + in_string = False + escape_next = False - Args: - content: 原始消息内容 - - Returns: - str: 过滤后的内容 - """ - if not content: - return "" - - # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 - content = re.sub(r"\[回复.*?\],说:\s*", "", content) - # 移除@<...>格式的内容 - content = re.sub(r"@<[^>]*>", "", content) - # 移除[picid:...]格式的图片ID - content = re.sub(r"\[picid:[^\]]*\]", "", content) - # 移除[表情包:...]格式的内容 - content = re.sub(r"\[表情包:[^\]]*\]", "", content) - - return content.strip() - - -def calculate_similarity(text1: str, text2: str) -> float: - """ - 计算两个文本的相似度,返回0-1之间的值 - 使用SequenceMatcher计算相似度 - - Args: - text1: 第一个文本 - text2: 第二个文本 - - Returns: - float: 相似度值,范围0-1 - """ - return difflib.SequenceMatcher(None, text1, text2).ratio() - - -def calculate_style_similarity(style1: str, style2: str) -> float: - """ - 计算两个 style 的相似度,返回0-1之间的值 - 在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py) - - Args: - style1: 第一个 style - style2: 第二个 style - - Returns: - float: 相似度值,范围0-1 - """ - if not style1 or not style2: - return 0.0 - - # 移除"使用"和"句式"这两个词 - def remove_ignored_words(text: str) -> str: - """移除需要忽略的词""" - text = text.replace("使用", "") - text = text.replace("句式", "") - return text.strip() - - cleaned_style1 = remove_ignored_words(style1) - cleaned_style2 = remove_ignored_words(style2) - - # 如果清理后文本为空,返回0 - if not cleaned_style1 or not 
cleaned_style2: - return 0.0 - - return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio() - - -def _compute_weights(population: List[Dict]) -> List[float]: - """ - 根据表达的count计算权重,范围限定在1~5之间。 - count越高,权重越高,但最多为基础权重的5倍。 - """ - if not population: - return [] - - counts = [] - for item in population: - count = item.get("count", 1) - try: - count_value = float(count) - except (TypeError, ValueError): - count_value = 1.0 - counts.append(max(count_value, 0.0)) - - min_count = min(counts) - max_count = max(counts) - - if max_count == min_count: - weights = [1.0 for _ in counts] - else: - weights = [] - for count_value in counts: - # 线性映射到[1,5]区间 - normalized = (count_value - min_count) / (max_count - min_count) - weights.append(1.0 + normalized * 4.0) # 1~5 - - return weights - - -def weighted_sample(population: List[Dict], k: int) -> List[Dict]: - """ - 随机抽样函数 - - Args: - population: 总体数据列表 - k: 需要抽取的数量 - - Returns: - List[Dict]: 抽取的数据列表 - """ - if not population or k <= 0: - return [] - - if len(population) <= k: - return population.copy() - - selected: List[Dict] = [] - population_copy = population.copy() - - for _ in range(min(k, len(population_copy))): - weights = _compute_weights(population_copy) - total_weight = sum(weights) - if total_weight <= 0: - # 回退到均匀随机 - idx = random.randint(0, len(population_copy) - 1) - selected.append(population_copy.pop(idx)) + while i < len(text): + char = text[i] + if escape_next: + # 当前字符是转义字符后的字符,直接添加 + result.append(char) + escape_next = False + i += 1 continue - - threshold = random.uniform(0, total_weight) - cumulative = 0.0 - for idx, weight in enumerate(weights): - cumulative += weight - if threshold <= cumulative: - selected.append(population_copy.pop(idx)) - break - - return selected - - -def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: - """ - 解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表) - - Args: - chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式) - - Returns: - List[List[Any]]: 格式为 [[chat_id, count], ...] 
的列表 - """ - if not chat_id_value: - return [] - - # 如果是字符串,尝试解析为JSON - if isinstance(chat_id_value, str): - # 尝试解析JSON - try: - parsed = json.loads(chat_id_value) - if isinstance(parsed, list): - # 新格式:已经是列表 - return parsed - elif isinstance(parsed, str): - # 解析后还是字符串,说明是旧格式 - return [[parsed, 1]] - else: - # 其他类型,当作旧格式处理 - return [[str(chat_id_value), 1]] - except (json.JSONDecodeError, TypeError): - # 解析失败,当作旧格式(纯字符串) - return [[str(chat_id_value), 1]] - elif isinstance(chat_id_value, list): - # 已经是列表格式 - return chat_id_value - else: - # 其他类型,转换为旧格式 - return [[str(chat_id_value), 1]] - - -def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]: - """ - 更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目 - - Args: - chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...] - target_chat_id: 要更新或添加的chat_id - increment: 增加的计数,默认为1 - - Returns: - List[List[Any]]: 更新后的chat_id列表 - """ - item = _find_chat_id_item(chat_id_list, target_chat_id) - if item is not None: - # 找到匹配的chat_id,增加计数 - if len(item) >= 2: - item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment + if char == "\\": + # 转义字符 + result.append(char) + escape_next = True + i += 1 + continue + if char == '"' and not escape_next: + # 遇到英文引号,切换字符串状态 + in_string = not in_string + result.append(char) + i += 1 + continue + if in_string and char in ["“", "”"]: + result.append('\\"') else: - item.append(increment) - else: - # 未找到,添加新条目 - chat_id_list.append([target_chat_id, increment]) + result.append(char) + i += 1 - return chat_id_list - - -def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]: - """ - 在chat_id列表中查找匹配的项(辅助函数) - - Args: - chat_id_list: chat_id列表,格式为 [[chat_id, count], ...] 
- target_chat_id: 要查找的chat_id - - Returns: - 如果找到则返回匹配的项,否则返回None - """ - for item in chat_id_list: - if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id): - return item - return None - - -def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool: - """ - 检查chat_id列表中是否包含指定的chat_id - - Args: - chat_id_list: chat_id列表,格式为 [[chat_id, count], ...] - target_chat_id: 要查找的chat_id - - Returns: - bool: 如果包含则返回True - """ - return _find_chat_id_item(chat_id_list, target_chat_id) is not None - - -def contains_bot_self_name(content: str) -> bool: - """ - 判断词条是否包含机器人的昵称或别名 - """ - if not content: - return False - - bot_config = getattr(global_config, "bot", None) - if not bot_config: - return False - - target = content.strip().lower() - nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower() - alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []] - - candidates = [name for name in [nickname, *alias_names] if name] - - return any(name in target for name in candidates) - - -def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: - """ - 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出 - """ - if not messages or center_index < 0 or center_index >= len(messages): - return None - - context_start = max(0, center_index - 3) - context_end = min(len(messages), center_index + 1 + 3) - context_messages = messages[context_start:context_end] - - if not context_messages: - return None - - try: - paragraph = build_readable_messages( - messages=context_messages, - replace_bot_name=True, - timestamp_mode="relative", - read_mark=0.0, - truncate=False, - show_actions=False, - show_pic=True, - message_id_list=None, - remove_emoji_stickers=False, - pic_single=True, - ) - except Exception as e: - logger.warning(f"构建上下文段落失败: {e}") - return None - - paragraph = paragraph.strip() - return paragraph or None - - -def is_bot_message(msg: Any) -> 
bool: - """判断消息是否来自机器人自身""" - if msg is None: - return False - - bot_config = getattr(global_config, "bot", None) - if not bot_config: - return False - - platform = ( - str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "") - .strip() - .lower() - ) - user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() - - if not platform or not user_id: - return False - - platform_accounts = {} - try: - platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or []) - except Exception: - platform_accounts = {} - - bot_accounts: Dict[str, str] = {} - qq_account = str(getattr(bot_config, "qq_account", "") or "").strip() - if qq_account: - bot_accounts["qq"] = qq_account - - telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip() - if telegram_account: - bot_accounts["telegram"] = telegram_account - - for plat, account in platform_accounts.items(): - if account and plat not in bot_accounts: - bot_accounts[plat] = account - - bot_account = bot_accounts.get(platform) - return bool(bot_account and user_id == bot_account) + return "".join(result) def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: @@ -373,11 +66,8 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]] raw = response.strip() - # 尝试提取 ```json 代码块 - json_block_pattern = r"```json\s*(.*?)\s*```" - match = re.search(json_block_pattern, raw, re.DOTALL) - if match: - raw = match.group(1).strip() + if match := re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL): + raw = match[1].strip() else: # 去掉可能存在的通用 ``` 包裹 raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) @@ -394,62 +84,11 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]] parsed = json.loads(raw) else: repaired = repair_json(raw) - if isinstance(repaired, str): - parsed = json.loads(repaired) - 
else: - parsed = repaired + parsed = json.loads(repaired) if isinstance(repaired, str) else repaired except Exception as parse_error: # 如果解析失败,尝试修复中文引号问题 # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 try: - - def fix_chinese_quotes_in_json(text): - """使用状态机修复 JSON 字符串值中的中文引号""" - result = [] - i = 0 - in_string = False - escape_next = False - - while i < len(text): - char = text[i] - - if escape_next: - # 当前字符是转义字符后的字符,直接添加 - result.append(char) - escape_next = False - i += 1 - continue - - if char == "\\": - # 转义字符 - result.append(char) - escape_next = True - i += 1 - continue - - if char == '"' and not escape_next: - # 遇到英文引号,切换字符串状态 - in_string = not in_string - result.append(char) - i += 1 - continue - - if in_string: - # 在字符串值内部,将中文引号替换为转义的英文引号 - if char == '"': # 中文左引号 U+201C - result.append('\\"') - elif char == '"': # 中文右引号 U+201D - result.append('\\"') - else: - result.append(char) - else: - # 不在字符串内,直接添加 - result.append(char) - - i += 1 - - return "".join(result) - fixed_raw = fix_chinese_quotes_in_json(raw) # 再次尝试解析 @@ -457,10 +96,7 @@ def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]] parsed = json.loads(fixed_raw) else: repaired = repair_json(fixed_raw) - if isinstance(repaired, str): - parsed = json.loads(repaired) - else: - parsed = repaired + parsed = json.loads(repaired) if isinstance(repaired, str) else repaired except Exception as fix_error: logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") diff --git a/src/bw_learner/learner_utils_old.py b/src/bw_learner/learner_utils_old.py new file mode 100644 index 00000000..ce3ea379 --- /dev/null +++ b/src/bw_learner/learner_utils_old.py @@ -0,0 +1,498 @@ +import re +import difflib +import random +import json +from typing import Optional, List, Dict, Any, Tuple + +from src.common.logger import get_logger +from src.config.config import global_config +from 
src.chat.utils.chat_message_builder import ( + build_readable_messages, +) +from src.chat.utils.utils import parse_platform_accounts +from json_repair import repair_json + + +logger = get_logger("learner_utils") + + +def filter_message_content(content: Optional[str]) -> str: + """ + 过滤消息内容,移除回复、@、图片等格式 + + Args: + content: 原始消息内容 + + Returns: + str: 过滤后的内容 + """ + if not content: + return "" + + # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 + content = re.sub(r"\[回复.*?\],说:\s*", "", content) + # 移除@<...>格式的内容 + content = re.sub(r"@<[^>]*>", "", content) + # 移除[picid:...]格式的图片ID + content = re.sub(r"\[picid:[^\]]*\]", "", content) + # 移除[表情包:...]格式的内容 + content = re.sub(r"\[表情包:[^\]]*\]", "", content) + + return content.strip() + + +def calculate_similarity(text1: str, text2: str) -> float: + """ + 计算两个文本的相似度,返回0-1之间的值 + 使用SequenceMatcher计算相似度 + + Args: + text1: 第一个文本 + text2: 第二个文本 + + Returns: + float: 相似度值,范围0-1 + """ + return difflib.SequenceMatcher(None, text1, text2).ratio() + + +def calculate_style_similarity(style1: str, style2: str) -> float: + """ + 计算两个 style 的相似度,返回0-1之间的值 + 在计算前会移除"使用"和"句式"这两个词(参考 expression_similarity_analysis.py) + + Args: + style1: 第一个 style + style2: 第二个 style + + Returns: + float: 相似度值,范围0-1 + """ + if not style1 or not style2: + return 0.0 + + # 移除"使用"和"句式"这两个词 + def remove_ignored_words(text: str) -> str: + """移除需要忽略的词""" + text = text.replace("使用", "") + text = text.replace("句式", "") + return text.strip() + + cleaned_style1 = remove_ignored_words(style1) + cleaned_style2 = remove_ignored_words(style2) + + # 如果清理后文本为空,返回0 + if not cleaned_style1 or not cleaned_style2: + return 0.0 + + return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio() + + +def _compute_weights(population: List[Dict]) -> List[float]: + """ + 根据表达的count计算权重,范围限定在1~5之间。 + count越高,权重越高,但最多为基础权重的5倍。 + """ + if not population: + return [] + + counts = [] + for item in population: + count = item.get("count", 1) + try: + count_value = float(count) + except 
(TypeError, ValueError): + count_value = 1.0 + counts.append(max(count_value, 0.0)) + + min_count = min(counts) + max_count = max(counts) + + if max_count == min_count: + weights = [1.0 for _ in counts] + else: + weights = [] + for count_value in counts: + # 线性映射到[1,5]区间 + normalized = (count_value - min_count) / (max_count - min_count) + weights.append(1.0 + normalized * 4.0) # 1~5 + + return weights + + +def weighted_sample(population: List[Dict], k: int) -> List[Dict]: + """ + 随机抽样函数 + + Args: + population: 总体数据列表 + k: 需要抽取的数量 + + Returns: + List[Dict]: 抽取的数据列表 + """ + if not population or k <= 0: + return [] + + if len(population) <= k: + return population.copy() + + selected: List[Dict] = [] + population_copy = population.copy() + + for _ in range(min(k, len(population_copy))): + weights = _compute_weights(population_copy) + total_weight = sum(weights) + if total_weight <= 0: + # 回退到均匀随机 + idx = random.randint(0, len(population_copy) - 1) + selected.append(population_copy.pop(idx)) + continue + + threshold = random.uniform(0, total_weight) + cumulative = 0.0 + for idx, weight in enumerate(weights): + cumulative += weight + if threshold <= cumulative: + selected.append(population_copy.pop(idx)) + break + + return selected + + +def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: + """ + 解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表) + + Args: + chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式) + + Returns: + List[List[Any]]: 格式为 [[chat_id, count], ...] 
的列表 + """ + if not chat_id_value: + return [] + + # 如果是字符串,尝试解析为JSON + if isinstance(chat_id_value, str): + # 尝试解析JSON + try: + parsed = json.loads(chat_id_value) + if isinstance(parsed, list): + # 新格式:已经是列表 + return parsed + elif isinstance(parsed, str): + # 解析后还是字符串,说明是旧格式 + return [[parsed, 1]] + else: + # 其他类型,当作旧格式处理 + return [[str(chat_id_value), 1]] + except (json.JSONDecodeError, TypeError): + # 解析失败,当作旧格式(纯字符串) + return [[str(chat_id_value), 1]] + elif isinstance(chat_id_value, list): + # 已经是列表格式 + return chat_id_value + else: + # 其他类型,转换为旧格式 + return [[str(chat_id_value), 1]] + + +def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]: + """ + 更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目 + + Args: + chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...] + target_chat_id: 要更新或添加的chat_id + increment: 增加的计数,默认为1 + + Returns: + List[List[Any]]: 更新后的chat_id列表 + """ + item = _find_chat_id_item(chat_id_list, target_chat_id) + if item is not None: + # 找到匹配的chat_id,增加计数 + if len(item) >= 2: + item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment + else: + item.append(increment) + else: + # 未找到,添加新条目 + chat_id_list.append([target_chat_id, increment]) + + return chat_id_list + + +def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]: + """ + 在chat_id列表中查找匹配的项(辅助函数) + + Args: + chat_id_list: chat_id列表,格式为 [[chat_id, count], ...] + target_chat_id: 要查找的chat_id + + Returns: + 如果找到则返回匹配的项,否则返回None + """ + for item in chat_id_list: + if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id): + return item + return None + + +def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool: + """ + 检查chat_id列表中是否包含指定的chat_id + + Args: + chat_id_list: chat_id列表,格式为 [[chat_id, count], ...] 
+ target_chat_id: 要查找的chat_id + + Returns: + bool: 如果包含则返回True + """ + return _find_chat_id_item(chat_id_list, target_chat_id) is not None + + +def contains_bot_self_name(content: str) -> bool: + """ + 判断词条是否包含机器人的昵称或别名 + """ + if not content: + return False + + bot_config = getattr(global_config, "bot", None) + if not bot_config: + return False + + target = content.strip().lower() + nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower() + alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []] + + candidates = [name for name in [nickname, *alias_names] if name] + + return any(name in target for name in candidates) + + +def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: + """ + 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出 + """ + if not messages or center_index < 0 or center_index >= len(messages): + return None + + context_start = max(0, center_index - 3) + context_end = min(len(messages), center_index + 1 + 3) + context_messages = messages[context_start:context_end] + + if not context_messages: + return None + + try: + paragraph = build_readable_messages( + messages=context_messages, + replace_bot_name=True, + timestamp_mode="relative", + read_mark=0.0, + truncate=False, + show_actions=False, + show_pic=True, + message_id_list=None, + remove_emoji_stickers=False, + pic_single=True, + ) + except Exception as e: + logger.warning(f"构建上下文段落失败: {e}") + return None + + paragraph = paragraph.strip() + return paragraph or None + + +def is_bot_message(msg: Any) -> bool: + """判断消息是否来自机器人自身""" + if msg is None: + return False + + bot_config = getattr(global_config, "bot", None) + if not bot_config: + return False + + platform = ( + str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "") + .strip() + .lower() + ) + user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or 
"").strip() + + if not platform or not user_id: + return False + + platform_accounts = {} + try: + platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or []) + except Exception: + platform_accounts = {} + + bot_accounts: Dict[str, str] = {} + qq_account = str(getattr(bot_config, "qq_account", "") or "").strip() + if qq_account: + bot_accounts["qq"] = qq_account + + telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip() + if telegram_account: + bot_accounts["telegram"] = telegram_account + + for plat, account in platform_accounts.items(): + if account and plat not in bot_accounts: + bot_accounts[plat] = account + + bot_account = bot_accounts.get(platform) + return bool(bot_account and user_id == bot_account) + + +def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: + """ + 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 + + 期望的 JSON 结构: + [ + {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式 + {"content": "词条", "source_id": "12"}, // 黑话 + ... 
+ ] + + Returns: + Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: + 第一个列表是表达方式 (situation, style, source_id) + 第二个列表是黑话 (content, source_id) + """ + if not response: + return [], [] + + raw = response.strip() + + # 尝试提取 ```json 代码块 + json_block_pattern = r"```json\s*(.*?)\s*```" + match = re.search(json_block_pattern, raw, re.DOTALL) + if match: + raw = match.group(1).strip() + else: + # 去掉可能存在的通用 ``` 包裹 + raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) + raw = raw.strip() + + parsed = None + expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) + jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) + + try: + # 优先尝试直接解析 + if raw.startswith("[") and raw.endswith("]"): + parsed = json.loads(raw) + else: + repaired = repair_json(raw) + if isinstance(repaired, str): + parsed = json.loads(repaired) + else: + parsed = repaired + except Exception as parse_error: + # 如果解析失败,尝试修复中文引号问题 + # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 + try: + + def fix_chinese_quotes_in_json(text): + """使用状态机修复 JSON 字符串值中的中文引号""" + result = [] + i = 0 + in_string = False + escape_next = False + + while i < len(text): + char = text[i] + + if escape_next: + # 当前字符是转义字符后的字符,直接添加 + result.append(char) + escape_next = False + i += 1 + continue + + if char == "\\": + # 转义字符 + result.append(char) + escape_next = True + i += 1 + continue + + if char == '"' and not escape_next: + # 遇到英文引号,切换字符串状态 + in_string = not in_string + result.append(char) + i += 1 + continue + + if in_string: + # 在字符串值内部,将中文引号替换为转义的英文引号 + if char == '"': # 中文左引号 U+201C + result.append('\\"') + elif char == '"': # 中文右引号 U+201D + result.append('\\"') + else: + result.append(char) + else: + # 不在字符串内,直接添加 + result.append(char) + + i += 1 + + return "".join(result) + + fixed_raw = fix_chinese_quotes_in_json(raw) + + # 再次尝试解析 + if fixed_raw.startswith("[") and fixed_raw.endswith("]"): + parsed = json.loads(fixed_raw) + else: + 
repaired = repair_json(fixed_raw) + if isinstance(repaired, str): + parsed = json.loads(repaired) + else: + parsed = repaired + except Exception as fix_error: + logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") + logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") + logger.error(f"解析表达风格 JSON 失败,原始响应:{response}") + logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") + return [], [] + + if isinstance(parsed, dict): + parsed_list = [parsed] + elif isinstance(parsed, list): + parsed_list = parsed + else: + logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") + return [], [] + + for item in parsed_list: + if not isinstance(item, dict): + continue + + # 检查是否是表达方式条目(有 situation 和 style) + situation = str(item.get("situation", "")).strip() + style = str(item.get("style", "")).strip() + source_id = str(item.get("source_id", "")).strip() + + if situation and style and source_id: + # 表达方式条目 + expressions.append((situation, style, source_id)) + elif item.get("content"): + # 黑话条目(有 content 字段) + content = str(item.get("content", "")).strip() + source_id = str(item.get("source_id", "")).strip() + if content and source_id: + jargon_entries.append((content, source_id)) + + return expressions, jargon_entries diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder_old.py similarity index 97% rename from src/bw_learner/message_recorder.py rename to src/bw_learner/message_recorder_old.py index bdf13fed..a5d90cc9 100644 --- a/src/bw_learner/message_recorder.py +++ b/src/bw_learner/message_recorder_old.py @@ -5,8 +5,8 @@ from src.common.logger import get_logger from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.common_utils import TempMethodsExpression -from src.bw_learner.expression_learner import expression_learner_manager -from 
src.bw_learner.jargon_miner import miner_manager +from src.bw_learner.expression_learner_old import expression_learner_manager +from src.bw_learner.jargon_miner_old import miner_manager logger = get_logger("bw_learner") diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index befa1675..ce329d87 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -16,8 +16,8 @@ from src.chat.brain_chat.brain_planner import BrainPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail -from src.bw_learner.expression_learner import expression_learner_manager -from src.bw_learner.message_recorder import extract_and_distribute_messages +from src.bw_learner.expression_learner_old import expression_learner_manager +from src.bw_learner.message_recorder_old import extract_and_distribute_messages from src.person_info.person_info import Person from src.core.types import ActionInfo, EventType from src.core.event_bus import event_bus @@ -63,7 +63,7 @@ class BrainChatting: 用于在特定聊天流中生成回复。 """ - def __init__(self, chat_id: str): + def __init__(self, session_id: str): """ BrainChatting 初始化函数 @@ -73,8 +73,8 @@ class BrainChatting: performance_version: 性能记录版本号,用于区分不同启动版本 """ # 基础属性 - self.stream_id: str = chat_id # 聊天流ID - self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore + self.stream_id: str = session_id # 聊天流ID + self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore if not self.chat_stream: raise ValueError(f"无法找到聊天流: {self.stream_id}") self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]" @@ -269,7 +269,7 @@ class BrainChatting: # Expression Reflection Check # 检查是否需要提问表达反思 # ------------------------------------------------------------------------- - from 
src.bw_learner.expression_reflector import expression_reflector_manager + from src.bw_learner.expression_reflector_old import expression_reflector_manager reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id) asyncio.create_task(reflector.check_and_ask()) diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index ace6b8b3..9f1c1da1 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -10,6 +10,7 @@ from src.common.logger import get_logger from src.common.utils.utils_session import SessionUtils from src.config.config import global_config from src.chat.message_receive.chat_manager import chat_manager +from src.bw_learner.expression_reflector import ExpressionReflector if TYPE_CHECKING: from src.chat.message_receive.message import SessionMessage @@ -52,6 +53,9 @@ class HeartFChatting: # Asyncio Event 用于控制循环的开始和结束 self._cycle_event = asyncio.Event() + # 反思器 + self.reflector = ExpressionReflector(session_id) + async def start(self): """启动 HeartFChatting 的主循环""" # 先检查是否已经启动运行 @@ -160,7 +164,12 @@ class HeartFChatting: async def _judge_and_response(self, mentioned_message: Optional["SessionMessage"] = None): """判定和生成回复""" - # TODO: 在expression和reflector重构完成后完成这里的逻辑 + await self.reflector.check_and_ask() + if self.reflector.reflect_tracker.tracking and await self.reflector.reflect_tracker.trigger_tracker(): + logger.info(f"{self.log_prefix} 追踪检查已解决,结束追踪器") + self.reflector.reflect_tracker.reset_tracker() # 结束当前追踪器 + + # TODO: 完成反思器之后的逻辑 def _handle_loop_completion(self, task: asyncio.Task): """当 _hfc_func 任务完成时执行的回调。""" diff --git a/src/chat/heart_flow/heartFC_chat_old.py b/src/chat/heart_flow/heartFC_chat_old.py new file mode 100644 index 00000000..2842758f --- /dev/null +++ b/src/chat/heart_flow/heartFC_chat_old.py @@ -0,0 +1,814 @@ +import asyncio +import time +import traceback +import random +from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING +from 
rich.traceback import install + +from src.config.config import global_config +from src.common.logger import get_logger +from src.common.data_models.info_data_model import ActionPlannerInfo +from src.common.data_models.message_data_model import ReplyContentType +from src.chat.message_receive.chat_manager import chat_manager, BotChatSession +from src.chat.utils.prompt_builder import global_prompt_manager +from src.chat.utils.timer_calculator import Timer +from src.chat.planner_actions.planner import ActionPlanner +from src.chat.planner_actions.action_modifier import ActionModifier +from src.chat.planner_actions.action_manager import ActionManager +from src.chat.heart_flow.hfc_utils import CycleDetail +from src.bw_learner.expression_learner_old import expression_learner_manager +from src.chat.heart_flow.frequency_control import frequency_control_manager +from src.bw_learner.reflect_tracker import reflect_tracker_manager +from src.bw_learner.expression_reflector_old import expression_reflector_manager +from src.bw_learner.message_recorder_old import extract_and_distribute_messages +from src.person_info.person_info import Person +from src.plugin_system.base.component_types import EventType, ActionInfo +from src.plugin_system.core import events_manager +from src.plugin_system.apis import generator_api, send_api, message_api, database_api +from src.chat.utils.chat_message_builder import ( + build_readable_messages_with_id, + get_raw_msg_before_timestamp_with_chat, +) +from src.chat.utils.utils import record_replyer_action_temp +from src.memory_system.chat_history_summarizer import ChatHistorySummarizer + +if TYPE_CHECKING: + from src.common.data_models.database_data_model import DatabaseMessages + from src.common.data_models.message_data_model import ReplySetModel + + +ERROR_LOOP_INFO = { + "loop_plan_info": { + "action_result": { + "action_type": "error", + "action_data": {}, + "reasoning": "循环处理失败", + }, + }, + "loop_action_info": { + "action_taken": False, + 
"reply_text": "", + "command": "", + "taken_time": time.time(), + }, +} + + +install(extra_lines=3) + +# 注释:原来的动作修改超时常量已移除,因为改为顺序执行 + +logger = get_logger("hfc") # Logger Name Changed + + +class HeartFChatting: + """ + 管理一个连续的Focus Chat循环 + 用于在特定聊天流中生成回复。 + 其生命周期现在由其关联的 SubHeartflow 的 FOCUSED 状态控制。 + """ + + def __init__(self, session_id: str): + """ + HeartFChatting 初始化函数 + + 参数: + session_id: 聊天会话唯一标识符(如session_id) + on_stop_focus_chat: 当收到stop_focus_chat命令时调用的回调函数 + performance_version: 性能记录版本号,用于区分不同启动版本 + """ + # 基础属性 + self.session_id: str = session_id # 聊天会话ID + session = chat_manager.get_session_by_session_id(session_id) + if not session: + raise ValueError(f"未找到 session_id={session_id} 的聊天会话") + self.chat_session: BotChatSession = session + self.log_prefix = f"[{chat_manager.get_session_name(self.session_id) or self.session_id}]" + + self.expression_learner = expression_learner_manager.get_expression_learner(self.session_id) + + self.action_manager = ActionManager() + self.action_planner = ActionPlanner(chat_id=self.session_id, action_manager=self.action_manager) + self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.session_id) + + # 循环控制内部状态 + self.running: bool = False + self._loop_task: Optional[asyncio.Task] = None # 主循环任务 + + # 添加循环信息管理相关的属性 + self.history_loop: List[CycleDetail] = [] + self._cycle_counter = 0 + self._current_cycle_detail: CycleDetail = None # type: ignore + + self.last_read_time = time.time() - 2 + + self.is_mute = False + + self.last_active_time = time.time() # 记录上一次非noreply时间 + + self.question_probability_multiplier = 1 + self.questioned = False + + # 跟踪连续 no_reply 次数,用于动态调整阈值 + self.consecutive_no_reply_count = 0 + + # 聊天内容概括器 + self.chat_history_summarizer = ChatHistorySummarizer(session_id=self.session_id) + + async def start(self): + """检查是否需要启动主循环,如果未激活则启动。""" + + # 如果循环已经激活,直接返回 + if self.running: + logger.debug(f"{self.log_prefix} HeartFChatting 已激活,无需重复启动") + return + + try: + # 
标记为活动状态,防止重复启动 + self.running = True + + self._loop_task = asyncio.create_task(self._main_chat_loop()) + self._loop_task.add_done_callback(self._handle_loop_completion) + + # 启动聊天内容概括器的后台定期检查循环 + await self.chat_history_summarizer.start() + + logger.info(f"{self.log_prefix} HeartFChatting 启动完成") + + except Exception as e: + # 启动失败时重置状态 + self.running = False + self._loop_task = None + logger.error(f"{self.log_prefix} HeartFChatting 启动失败: {e}") + raise + + def _handle_loop_completion(self, task: asyncio.Task): + """当 _hfc_loop 任务完成时执行的回调。""" + try: + if exception := task.exception(): + logger.error(f"{self.log_prefix} HeartFChatting: 脱离了聊天(异常): {exception}") + logger.error(traceback.format_exc()) # Log full traceback for exceptions + else: + logger.info(f"{self.log_prefix} HeartFChatting: 脱离了聊天 (外部停止)") + except asyncio.CancelledError: + logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天") + + def start_cycle(self) -> Tuple[Dict[str, float], str]: + self._cycle_counter += 1 + self._current_cycle_detail = CycleDetail(self._cycle_counter) + self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}" + cycle_timers = {} + return cycle_timers, self._current_cycle_detail.thinking_id + + def end_cycle(self, loop_info, cycle_timers): + self._current_cycle_detail.set_loop_info(loop_info) + self.history_loop.append(self._current_cycle_detail) + self._current_cycle_detail.timers = cycle_timers + self._current_cycle_detail.end_time = time.time() + + def print_cycle_info(self, cycle_timers): + # 记录循环信息和计时器结果 + timer_strings = [] + for name, elapsed in cycle_timers.items(): + if elapsed < 0.1: + # 不显示小于0.1秒的计时器 + continue + formatted_time = f"{elapsed:.2f}秒" + timer_strings.append(f"{name}: {formatted_time}") + + logger.info( + f"{self.log_prefix} 第{self._current_cycle_detail.cycle_id}次思考," + f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒;" # type: ignore + + (f"详情: {'; '.join(timer_strings)}" if timer_strings else "") 
+ ) + + async def _loopbody(self): + recent_messages_list = message_api.get_messages_by_time_in_chat( + chat_id=self.session_id, + start_time=self.last_read_time, + end_time=time.time(), + limit=20, + limit_mode="latest", + filter_mai=True, + filter_command=False, + filter_intercept_message_level=0, + ) + + # 根据连续 no_reply 次数动态调整阈值 + # 3次 no_reply 时,阈值调高到 1.5(50%概率为1,50%概率为2) + # 5次 no_reply 时,提高到 2(大于等于两条消息的阈值) + if self.consecutive_no_reply_count >= 5: + threshold = 2 + elif self.consecutive_no_reply_count >= 3: + # 1.5 的含义:50%概率为1,50%概率为2 + threshold = 2 if random.random() < 0.5 else 1 + else: + threshold = 1 + + if len(recent_messages_list) >= threshold: + # for message in recent_messages_list: + # print(message.processed_plain_text) + + self.last_read_time = time.time() + + # !此处使at或者提及必定回复 + mentioned_message = None + for message in recent_messages_list: + if (message.is_mentioned or message.is_at) and global_config.chat.mentioned_bot_reply: + mentioned_message = message + + # logger.info(f"{self.log_prefix} 当前talk_value: {TempMethods.get_talk_value(self.stream_id)}") + + # *控制频率用 + if mentioned_message: + await self._observe(recent_messages_list=recent_messages_list, force_reply_message=mentioned_message) + elif ( + random.random() + < TempMethodsHFC.get_talk_value(self.session_id) + * frequency_control_manager.get_or_create_frequency_control(self.session_id).get_talk_frequency_adjust() + ): + await self._observe(recent_messages_list=recent_messages_list) + else: + # 没有提到,继续保持沉默,等待5秒防止频繁触发 + await asyncio.sleep(10) + return True + else: + await asyncio.sleep(0.2) + return True + return True + + async def _send_and_store_reply( + self, + response_set: "ReplySetModel", + action_message: "DatabaseMessages", + cycle_timers: Dict[str, float], + thinking_id, + actions, + selected_expressions: Optional[List[int]] = None, + quote_message: Optional[bool] = None, + ) -> Tuple[Dict[str, Any], str, Dict[str, float]]: + with Timer("回复发送", cycle_timers): + reply_text = 
await self._send_response( + reply_set=response_set, + message_data=action_message, + selected_expressions=selected_expressions, + quote_message=quote_message, + ) + + # 获取 platform,如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值 + platform = action_message.chat_info.platform + if platform is None: + platform = getattr(self.chat_stream, "platform", "unknown") + + person = Person(platform=platform, user_id=action_message.user_info.user_id) + person_name = person.person_name + action_prompt_display = f"你对{person_name}进行了回复:{reply_text}" + + await database_api.store_action_info( + chat_stream=self.chat_stream, + action_build_into_prompt=False, + action_prompt_display=action_prompt_display, + action_done=True, + thinking_id=thinking_id, + action_data={"reply_text": reply_text}, + action_name="reply", + ) + + # 构建循环信息 + loop_info: Dict[str, Any] = { + "loop_plan_info": { + "action_result": actions, + }, + "loop_action_info": { + "action_taken": True, + "reply_text": reply_text, + "command": "", + "taken_time": time.time(), + }, + } + + return loop_info, reply_text, cycle_timers + + async def _observe( + self, # interest_value: float = 0.0, + recent_messages_list: Optional[List["DatabaseMessages"]] = None, + force_reply_message: Optional["DatabaseMessages"] = None, + ) -> bool: # sourcery skip: merge-else-if-into-elif, remove-redundant-if + if recent_messages_list is None: + recent_messages_list = [] + _reply_text = "" # 初始化reply_text变量,避免UnboundLocalError + + # ------------------------------------------------------------------------- + # ReflectTracker Check + # 在每次回复前检查一次上下文,看是否有反思问题得到了解答 + # ------------------------------------------------------------------------- + + reflector = expression_reflector_manager.get_or_create_reflector(self.session_id) + await reflector.check_and_ask() + tracker = reflect_tracker_manager.get_tracker(self.session_id) + if tracker: + resolved = await tracker.trigger_tracker() + if resolved: + reflect_tracker_manager.remove_tracker(self.session_id) 
+ logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.") + + start_time = time.time() + async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): + # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner + # 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息 + asyncio.create_task(extract_and_distribute_messages(self.session_id)) + + # 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容 + # asyncio.create_task(check_and_make_question(self.stream_id)) + # 添加聊天内容概括任务 - 累积、打包和压缩聊天记录 + # 注意:后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理 + # asyncio.create_task(self.chat_history_summarizer.process()) + + cycle_timers, thinking_id = self.start_cycle() + logger.info( + f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {TempMethodsHFC.get_talk_value(self.session_id)})" + ) + + # 第一步:动作检查 + available_actions: Dict[str, ActionInfo] = {} + try: + await self.action_modifier.modify_actions() + available_actions = self.action_manager.get_using_actions() + except Exception as e: + logger.error(f"{self.log_prefix} 动作修改失败: {e}") + + # 执行planner + is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info() + + message_list_before_now = get_raw_msg_before_timestamp_with_chat( + chat_id=self.session_id, + timestamp=time.time(), + limit=int(global_config.chat.max_context_size * 0.6), + filter_intercept_message_level=1, + ) + chat_content_block, message_id_list = build_readable_messages_with_id( + messages=message_list_before_now, + timestamp_mode="normal_no_YMD", + read_mark=self.action_planner.last_obs_time_mark, + truncate=True, + show_actions=True, + ) + + prompt_info = await self.action_planner.build_planner_prompt( + is_group_chat=is_group_chat, + chat_target_info=chat_target_info, + current_available_actions=available_actions, + chat_content_block=chat_content_block, + message_id_list=message_id_list, + ) + continue_flag, modified_message = await events_manager.handle_mai_events( + EventType.ON_PLAN, None, prompt_info[0], None, 
self.chat_stream.stream_id + ) + if not continue_flag: + return False + if modified_message and modified_message._modify_flags.modify_llm_prompt: + prompt_info = (modified_message.llm_prompt, prompt_info[1]) + + with Timer("规划器", cycle_timers): + action_to_use_info = await self.action_planner.plan( + loop_start_time=self.last_read_time, + available_actions=available_actions, + force_reply_message=force_reply_message, + ) + + logger.info( + f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}" + ) + + # 3. 并行执行所有动作 + action_tasks = [ + asyncio.create_task( + self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers) + ) + for action in action_to_use_info + ] + + # 并行执行所有任务 + results = await asyncio.gather(*action_tasks, return_exceptions=True) + + # 处理执行结果 + reply_loop_info = None + reply_text_from_reply = "" + action_success = False + action_reply_text = "" + + excute_result_str = "" + for result in results: + excute_result_str += f"{result['action_type']} 执行结果:{result['result']}\n" + + if isinstance(result, BaseException): + logger.error(f"{self.log_prefix} 动作执行异常: {result}") + continue + + if result["action_type"] != "reply": + action_success = result["success"] + action_reply_text = result["result"] + elif result["action_type"] == "reply": + if result["success"]: + reply_loop_info = result["loop_info"] + reply_text_from_reply = result["result"] + else: + logger.warning(f"{self.log_prefix} 回复动作执行失败") + + self.action_planner.add_plan_excute_log(result=excute_result_str) + + # 构建最终的循环信息 + if reply_loop_info: + # 如果有回复信息,使用回复的loop_info作为基础 + loop_info = reply_loop_info + # 更新动作执行信息 + loop_info["loop_action_info"].update( + { + "action_taken": action_success, + "taken_time": time.time(), + } + ) + _reply_text = reply_text_from_reply + else: + # 没有回复信息,构建纯动作的loop_info + loop_info = { + "loop_plan_info": { + "action_result": action_to_use_info, + }, + "loop_action_info": { + 
"action_taken": action_success, + "reply_text": action_reply_text, + "taken_time": time.time(), + }, + } + _reply_text = action_reply_text + + self.end_cycle(loop_info, cycle_timers) + self.print_cycle_info(cycle_timers) + + end_time = time.time() + if end_time - start_time < global_config.chat.planner_smooth: + wait_time = global_config.chat.planner_smooth - (end_time - start_time) + await asyncio.sleep(wait_time) + else: + await asyncio.sleep(0.1) + return True + + # async def _main_chat_loop(self): + # """主循环,持续进行计划并可能回复消息,直到被外部取消。""" + # try: + # while self.running: + # # 主循环 + # success = await self._loopbody() + # await asyncio.sleep(0.1) + # if not success: + # break + # except asyncio.CancelledError: + # # 设置了关闭标志位后被取消是正常流程 + # logger.info(f"{self.log_prefix} 麦麦已关闭聊天") + # except Exception: + # logger.error(f"{self.log_prefix} 麦麦聊天意外错误,将于3s后尝试重新启动") + # print(traceback.format_exc()) + # await asyncio.sleep(3) + # self._loop_task = asyncio.create_task(self._main_chat_loop()) + # logger.error(f"{self.log_prefix} 结束了当前聊天循环") + + async def _handle_action( + self, + action: str, + action_reasoning: str, + action_data: dict, + cycle_timers: Dict[str, float], + thinking_id: str, + action_message: Optional["DatabaseMessages"] = None, + ) -> tuple[bool, str, str]: + """ + 处理规划动作,使用动作工厂创建相应的动作处理器 + + 参数: + action: 动作类型 + action_reasoning: 决策理由 + action_data: 动作数据,包含不同动作需要的参数 + cycle_timers: 计时器字典 + thinking_id: 思考ID + action_message: 消息数据 + 返回: + tuple[bool, str, str]: (是否执行了动作, 思考消息ID, 命令) + """ + try: + # 使用工厂创建动作处理器实例 + try: + action_handler = self.action_manager.create_action( + action_name=action, + action_data=action_data, + cycle_timers=cycle_timers, + thinking_id=thinking_id, + chat_stream=self.chat_stream, + log_prefix=self.log_prefix, + action_reasoning=action_reasoning, + action_message=action_message, + ) + except Exception as e: + logger.error(f"{self.log_prefix} 创建动作处理器时出错: {e}") + traceback.print_exc() + return False, "" + + # 处理动作并获取结果(固定记录一次动作信息) + 
result = await action_handler.execute() + success, action_text = result + + return success, action_text + + except Exception as e: + logger.error(f"{self.log_prefix} 处理{action}时出错: {e}") + traceback.print_exc() + return False, "" + + async def _send_response( + self, + reply_set: "ReplySetModel", + message_data: "DatabaseMessages", + selected_expressions: Optional[List[int]] = None, + quote_message: Optional[bool] = None, + ) -> str: + # 根据 llm_quote 配置决定是否使用 quote_message 参数 + if global_config.chat.llm_quote: + # 如果配置为 true,使用 llm_quote 参数决定是否引用回复 + if quote_message is None: + logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用") + need_reply = False + else: + need_reply = quote_message + if need_reply: + logger.info(f"{self.log_prefix} LLM 决定使用引用回复") + else: + # 如果配置为 false,使用原来的模式 + new_message_count = message_api.count_new_messages( + chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time() + ) + need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90 + if need_reply: + logger.info( + f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复,或者上次回复时间超过90秒" + ) + + reply_text = "" + first_replied = False + for reply_content in reply_set.reply_data: + if reply_content.content_type != ReplyContentType.TEXT: + continue + data: str = reply_content.content # type: ignore + if not first_replied: + await send_api.text_to_stream( + text=data, + stream_id=self.chat_stream.stream_id, + reply_message=message_data, + set_reply=need_reply, + typing=False, + selected_expressions=selected_expressions, + ) + first_replied = True + else: + await send_api.text_to_stream( + text=data, + stream_id=self.chat_stream.stream_id, + reply_message=message_data, + set_reply=False, + typing=True, + selected_expressions=selected_expressions, + ) + reply_text += data + + return reply_text + + async def _execute_action( + self, + action_planner_info: ActionPlannerInfo, + chosen_action_plan_infos: 
List[ActionPlannerInfo], + thinking_id: str, + available_actions: Dict[str, ActionInfo], + cycle_timers: Dict[str, float], + ): + """执行单个动作的通用函数""" + try: + with Timer(f"动作{action_planner_info.action_type}", cycle_timers): + # 直接当场执行no_reply逻辑 + if action_planner_info.action_type == "no_reply": + # 直接处理no_reply逻辑,不再通过动作系统 + reason = action_planner_info.reasoning or "选择不回复" + # logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}") + + # 增加连续 no_reply 计数 + self.consecutive_no_reply_count += 1 + + await database_api.store_action_info( + chat_stream=self.chat_stream, + action_build_into_prompt=False, + action_prompt_display=reason, + action_done=True, + thinking_id=thinking_id, + action_data={}, + action_name="no_reply", + action_reasoning=reason, + ) + + return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""} + + elif action_planner_info.action_type == "reply": + # 直接当场执行reply逻辑 + self.questioned = False + # 刷新主动发言状态 + # 重置连续 no_reply 计数 + self.consecutive_no_reply_count = 0 + + reason = action_planner_info.reasoning or "" + # 根据 think_mode 配置决定 think_level 的值 + think_mode = global_config.chat.think_mode + if think_mode == "default": + think_level = 0 + elif think_mode == "deep": + think_level = 1 + elif think_mode == "dynamic": + # dynamic 模式:从 planner 返回的 action_data 中获取 + think_level = action_planner_info.action_data.get("think_level", 1) + else: + # 默认使用 default 模式 + think_level = 0 + # 使用 action_reasoning(planner 的整体思考理由)作为 reply_reason + planner_reasoning = action_planner_info.action_reasoning or reason + + record_replyer_action_temp( + chat_id=self.session_id, + reason=reason, + think_level=think_level, + ) + + await database_api.store_action_info( + chat_stream=self.chat_stream, + action_build_into_prompt=False, + action_prompt_display=reason, + action_done=True, + thinking_id=thinking_id, + action_data={}, + action_name="reply", + action_reasoning=reason, + ) + + # 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用) + unknown_words = 
None + quote_message = None + if isinstance(action_planner_info.action_data, dict): + uw = action_planner_info.action_data.get("unknown_words") + if isinstance(uw, list): + cleaned_uw: List[str] = [] + for item in uw: + if isinstance(item, str): + s = item.strip() + if s: + cleaned_uw.append(s) + if cleaned_uw: + unknown_words = cleaned_uw + + # 从 Planner 的 action_data 中提取 quote_message 参数 + qm = action_planner_info.action_data.get("quote") + if qm is not None: + # 支持多种格式:true/false, "true"/"false", 1/0 + if isinstance(qm, bool): + quote_message = qm + elif isinstance(qm, str): + quote_message = qm.lower() in ("true", "1", "yes") + elif isinstance(qm, (int, float)): + quote_message = bool(qm) + + logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}") + + success, llm_response = await generator_api.generate_reply( + chat_stream=self.chat_stream, + reply_message=action_planner_info.action_message, + available_actions=available_actions, + chosen_actions=chosen_action_plan_infos, + reply_reason=planner_reasoning, + unknown_words=unknown_words, + enable_tool=global_config.tool.enable_tool, + request_type="replyer", + from_plugin=False, + reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()), + think_level=think_level, + ) + + if not success or not llm_response or not llm_response.reply_set: + if action_planner_info.action_message: + logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败") + else: + logger.info("回复生成失败") + return {"action_type": "reply", "success": False, "result": "回复生成失败", "loop_info": None} + + response_set = llm_response.reply_set + selected_expressions = llm_response.selected_expressions + loop_info, reply_text, _ = await self._send_and_store_reply( + response_set=response_set, + action_message=action_planner_info.action_message, # type: ignore + cycle_timers=cycle_timers, + thinking_id=thinking_id, + actions=chosen_action_plan_infos, + selected_expressions=selected_expressions, + 
quote_message=quote_message, + ) + self.last_active_time = time.time() + return { + "action_type": "reply", + "success": True, + "result": f"你使用reply动作,对' {action_planner_info.action_message.processed_plain_text} '这句话进行了回复,回复内容为: '{reply_text}'", + "loop_info": loop_info, + } + + else: + # 执行普通动作 + with Timer("动作执行", cycle_timers): + success, result = await self._handle_action( + action=action_planner_info.action_type, + action_reasoning=action_planner_info.action_reasoning or "", + action_data=action_planner_info.action_data or {}, + cycle_timers=cycle_timers, + thinking_id=thinking_id, + action_message=action_planner_info.action_message, + ) + + self.last_active_time = time.time() + return { + "action_type": action_planner_info.action_type, + "success": success, + "result": result, + } + + except Exception as e: + logger.error(f"{self.log_prefix} 执行动作时出错: {e}") + logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}") + return { + "action_type": action_planner_info.action_type, + "success": False, + "result": "", + "loop_info": None, + "error": str(e), + } + + +class TempMethodsHFC: + @staticmethod + def get_talk_value(chat_id: Optional[str]) -> float: + result = global_config.chat.talk_value or 0.0000001 + if not global_config.chat.enable_talk_value_rules or not global_config.chat.talk_value_rules: + return result + import time + + local_time = time.localtime() + now_min = local_time.tm_hour * 60 + local_time.tm_min + # 先处理特定规则 + if chat_id: + for rule in global_config.chat.talk_value_rules: + if not rule.platform and not rule.item_id: + continue # 一起留空表示全局,跳过 + is_group = rule.rule_type == "group" + from src.chat.message_receive.chat_stream import get_chat_manager + + stream_id = get_chat_manager().get_stream_id(rule.platform, str(rule.item_id), is_group) + if stream_id != chat_id: + continue + parsed_range = TempMethodsHFC._parse_range(rule.time) + if not parsed_range: + continue + start_min, end_min = parsed_range + in_range: bool = False + if 
start_min <= end_min: + in_range = start_min <= now_min <= end_min + else: + in_range = now_min >= start_min or now_min <= end_min + if in_range: + return rule.value or 0.0 + # 再处理全局规则 + for rule in global_config.chat.talk_value_rules: + if rule.platform or rule.item_id: + continue # 有指定表示特定,跳过 + parsed_range = TempMethodsHFC._parse_range(rule.time) + if not parsed_range: + continue + start_min, end_min = parsed_range + in_range: bool = False + if start_min <= end_min: + in_range = start_min <= now_min <= end_min + else: + in_range = now_min >= start_min or now_min <= end_min + if in_range: + return rule.value or 0.0000001 + return result + + @staticmethod + def _parse_range(range_str: str) -> Optional[tuple[int, int]]: + """解析 "HH:MM-HH:MM" 到 (start_min, end_min)。""" + try: + start_str, end_str = [s.strip() for s in range_str.split("-")] + sh, sm = [int(x) for x in start_str.split(":")] + eh, em = [int(x) for x in end_str.split(":")] + return sh * 60 + sm, eh * 60 + em + except Exception: + return None diff --git a/src/chat/heart_flow/hfc_utils.py b/src/chat/heart_flow/hfc_utils.py index 36820d1c..20843bf6 100644 --- a/src/chat/heart_flow/hfc_utils.py +++ b/src/chat/heart_flow/hfc_utils.py @@ -12,13 +12,14 @@ from src.common.message_repository import count_messages logger = get_logger(__name__) + @dataclass -class CyclePlanInfo: - ... - +class CyclePlanInfo: ... + + @dataclass -class CycleActionInfo: - ... +class CycleActionInfo: ... 
+ class CycleDetail: """循环信息记录类""" diff --git a/src/chat/message_receive/message_old.py b/src/chat/message_receive/message_old.py new file mode 100644 index 00000000..b9f44d5c --- /dev/null +++ b/src/chat/message_receive/message_old.py @@ -0,0 +1,561 @@ +import time +import asyncio +import urllib3 + +from abc import abstractmethod +from dataclasses import dataclass +from rich.traceback import install +from typing import Optional, Any, List +from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.utils.utils_image import get_image_manager +from src.common.utils.utils_voice import get_voice_text +from .chat_stream import ChatStream + +install(extra_lines=3) + +logger = get_logger("chat_message") + +# 禁用SSL警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +# VLM 处理并发限制(避免同时处理太多图片导致卡死) +_vlm_semaphore = asyncio.Semaphore(3) + +# 这个类是消息数据类,用于存储和管理消息数据。 +# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 +# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。 + + +@dataclass +class Message(MessageBase): + chat_stream: "ChatStream" = None # type: ignore + reply: Optional["Message"] = None + processed_plain_text: str = "" + + def __init__( + self, + message_id: str, + chat_stream: "ChatStream", + user_info: UserInfo, + message_segment: Optional[Seg] = None, + timestamp: Optional[float] = None, + reply: Optional["MessageRecv"] = None, + processed_plain_text: str = "", + ): + # 使用传入的时间戳或当前时间 + current_timestamp = timestamp if timestamp is not None else round(time.time(), 3) + # 构造基础消息信息 + message_info = BaseMessageInfo( + platform=chat_stream.platform, + message_id=message_id, + time=current_timestamp, + group_info=chat_stream.group_info, + user_info=user_info, + ) + + # 调用父类初始化 + super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore + + self.chat_stream = chat_stream + # 文本处理相关属性 
+ self.processed_plain_text = processed_plain_text + + # 回复消息 + self.reply = reply + + # async def _process_message_segments(self, segment: Seg) -> str: + # # sourcery skip: remove-unnecessary-else, swap-if-else-branches + # """递归处理消息段,转换为文字描述 + + # Args: + # segment: 要处理的消息段 + + # Returns: + # str: 处理后的文本 + # """ + # if segment.type == "seglist": + # # 处理消息段列表 - 使用并行处理提升性能 + # tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore + # results = await asyncio.gather(*tasks, return_exceptions=True) + # segments_text = [] + # for result in results: + # if isinstance(result, Exception): + # logger.error(f"处理消息段时出错: {result}") + # continue + # if result: + # segments_text.append(result) + # return " ".join(segments_text) + # elif segment.type == "forward": + # # 处理转发消息 - 使用并行处理 + # async def process_forward_node(node_dict): + # message = MessageBase.from_dict(node_dict) # type: ignore + # processed_text = await self._process_message_segments(message.message_segment) + # if processed_text: + # return f"{global_config.bot.nickname}: {processed_text}" + # return None + + # tasks = [process_forward_node(node_dict) for node_dict in segment.data] + # results = await asyncio.gather(*tasks, return_exceptions=True) + # segments_text = [] + # for result in results: + # if isinstance(result, Exception): + # logger.error(f"处理转发节点时出错: {result}") + # continue + # if result: + # segments_text.append(result) + # return "[合并消息]: " + "\n-- ".join(segments_text) + # else: + # # 处理单个消息段 + # return await self._process_single_segment(segment) # type: ignore + + # @abstractmethod + # async def _process_single_segment(self, segment) -> str: + # pass + + +@dataclass +class MessageRecv(Message): + """接收消息类,用于处理从MessageCQ序列化的消息""" + + def __init__(self, message_dict: dict[str, Any]): + """从MessageCQ的字典初始化 + + Args: + message_dict: MessageCQ序列化后的字典 + """ + self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) + self.message_segment = 
Seg.from_dict(message_dict.get("message_segment", {})) + self.raw_message = message_dict.get("raw_message") + self.processed_plain_text = message_dict.get("processed_plain_text", "") + self.is_emoji = False + self.has_emoji = False + self.is_picid = False + self.has_picid = False + self.is_voice = False + self.is_mentioned = None + self.is_at = False + self.reply_probability_boost = 0.0 + self.is_notify = False + + self.is_command = False + self.intercept_message_level = 0 + + self.priority_mode = "interest" + self.priority_info = None + self.interest_value: float = None # type: ignore + + self.key_words = [] + self.key_words_lite = [] + + # 兼容适配器通过 additional_config 传入的 @ 标记 + try: + msg_info_dict = message_dict.get("message_info", {}) + add_cfg = msg_info_dict.get("additional_config") or {} + if isinstance(add_cfg, dict) and add_cfg.get("at_bot"): + # 标记为被提及,提高后续回复优先级 + self.is_mentioned = True # type: ignore + except Exception: + pass + + def update_chat_stream(self, chat_stream: "ChatStream"): + self.chat_stream = chat_stream + + # async def process(self) -> None: + # """处理消息内容,生成纯文本和详细文本 + + # 这个方法必须在创建实例后显式调用,因为它包含异步操作。 + # """ + # # print(f"self.message_segment: {self.message_segment}") + # self.processed_plain_text = await self._process_message_segments(self.message_segment) + + # async def _process_single_segment(self, segment: Seg) -> str: + # """处理单个消息段 + + # Args: + # segment: 消息段 + + # Returns: + # str: 处理后的文本 + # """ + # try: + # if segment.type == "text": + # self.is_picid = False + # self.is_emoji = False + # return segment.data # type: ignore + # elif segment.type == "image": + # # 如果是base64图片数据 + # if isinstance(segment.data, str): + # self.has_picid = True + # self.is_picid = True + # self.is_emoji = False + # image_manager = get_image_manager() + # # 使用 semaphore 限制 VLM 并发,避免同时处理太多图片 + # async with _vlm_semaphore: + # _, processed_text = await image_manager.process_image(segment.data) + # return processed_text + # return "[发了一张图片,网卡了加载不出来]" + # 
elif segment.type == "emoji": + # self.has_emoji = True + # self.is_emoji = True + # self.is_picid = False + # self.is_voice = False + # if isinstance(segment.data, str): + # # 使用 semaphore 限制 VLM 并发 + # async with _vlm_semaphore: + # return await get_image_manager().get_emoji_description(segment.data) + # return "[发了一个表情包,网卡了加载不出来]" + # elif segment.type == "voice": + # self.is_picid = False + # self.is_emoji = False + # self.is_voice = True + # if isinstance(segment.data, str): + # return await get_voice_text(segment.data) + # return "[发了一段语音,网卡了加载不出来]" + # elif segment.type == "mention_bot": + # self.is_picid = False + # self.is_emoji = False + # self.is_voice = False + # self.is_mentioned = float(segment.data) # type: ignore + # return "" + # elif segment.type == "priority_info": + # self.is_picid = False + # self.is_emoji = False + # self.is_voice = False + # if isinstance(segment.data, dict): + # # 处理优先级信息 + # self.priority_mode = "priority" + # self.priority_info = segment.data + # """ + # { + # 'message_type': 'vip', # vip or normal + # 'message_priority': 1.0, # 优先级,大为优先,float + # } + # """ + # return "" + # elif segment.type == "video_card": + # # 处理视频卡片消息 + # self.is_picid = False + # self.is_emoji = False + # self.is_voice = False + # if isinstance(segment.data, dict): + # file_name = segment.data.get("file", "未知视频") + # file_size = segment.data.get("file_size", "") + # url = segment.data.get("url", "") + # text = f"[视频: {file_name}" + # if file_size: + # text += f", 大小: {file_size}字节" + # text += "]" + # if url: + # text += f" 链接: {url}" + # return text + # return "[视频]" + # elif segment.type == "music_card": + # # 处理音乐卡片消息 + # self.is_picid = False + # self.is_emoji = False + # self.is_voice = False + # if isinstance(segment.data, dict): + # title = segment.data.get("title", "未知歌曲") + # singer = segment.data.get("singer", "") + # tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐" + # jump_url = segment.data.get("jump_url", "") + # music_url = 
segment.data.get("music_url", "") + # text = f"[音乐: {title}" + # if singer: + # text += f" - {singer}" + # if tag: + # text += f" ({tag})" + # text += "]" + # if jump_url: + # text += f" 跳转链接: {jump_url}" + # if music_url: + # text += f" 音乐链接: {music_url}" + # return text + # return "[音乐]" + # elif segment.type == "miniapp_card": + # # 处理小程序分享卡片(如B站视频分享) + # self.is_picid = False + # self.is_emoji = False + # self.is_voice = False + # if isinstance(segment.data, dict): + # title = segment.data.get("title", "") # 小程序名称 + # desc = segment.data.get("desc", "") # 内容描述 + # source_url = segment.data.get("source_url", "") # 原始链接 + # url = segment.data.get("url", "") # 小程序链接 + # text = "[小程序分享" + # if title: + # text += f" - {title}" + # text += "]" + # if desc: + # text += f" {desc}" + # if source_url: + # text += f" 链接: {source_url}" + # elif url: + # text += f" 链接: {url}" + # return text + # return "[小程序分享]" + # else: + # return "" + # except Exception as e: + # logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") + # return f"[处理失败的{segment.type}消息]" + + +@dataclass +class MessageProcessBase(Message): + """消息处理基类,用于处理中和发送中的消息""" + + def __init__( + self, + message_id: str, + chat_stream: "ChatStream", + bot_user_info: UserInfo, + message_segment: Optional[Seg] = None, + reply: Optional["MessageRecv"] = None, + thinking_start_time: float = 0, + timestamp: Optional[float] = None, + ): + # 调用父类初始化,传递时间戳 + super().__init__( + message_id=message_id, + timestamp=timestamp, + chat_stream=chat_stream, + user_info=bot_user_info, + message_segment=message_segment, + reply=reply, + ) + + # 处理状态相关属性 + self.thinking_start_time = thinking_start_time + self.thinking_time = 0 + + # def update_thinking_time(self) -> float: + # """更新思考时间""" + # self.thinking_time = round(time.time() - self.thinking_start_time, 2) + # return self.thinking_time + + # async def _process_single_segment(self, segment: Seg) -> str: + # """处理单个消息段 + + # Args: + # segment: 要处理的消息段 + + # 
Returns: + # str: 处理后的文本 + # """ + # try: + # if segment.type == "text": + # return segment.data # type: ignore + # elif segment.type == "image": + # # 如果是base64图片数据 + # if isinstance(segment.data, str): + # return await get_image_manager().get_image_description(segment.data) + # return "[图片,网卡了加载不出来]" + # elif segment.type == "emoji": + # if isinstance(segment.data, str): + # return await get_image_manager().get_emoji_tag(segment.data) + # return "[表情,网卡了加载不出来]" + # elif segment.type == "voice": + # if isinstance(segment.data, str): + # return await get_voice_text(segment.data) + # return "[发了一段语音,网卡了加载不出来]" + # elif segment.type == "at": + # return f"[@{segment.data}]" + # elif segment.type == "reply": + # if self.reply and hasattr(self.reply, "processed_plain_text"): + # # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") + # # print(f"reply: {self.reply}") + # return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore + # return "" + # else: + # return f"[{segment.type}:{str(segment.data)}]" + # except Exception as e: + # logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") + # return f"[处理失败的{segment.type}消息]" + + # def _generate_detailed_text(self) -> str: + # """生成详细文本,包含时间和用户信息""" + # # time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) + # timestamp = self.message_info.time + # user_info = self.message_info.user_info + + # name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore + # return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n" + + +@dataclass +class MessageSending(MessageProcessBase): + """发送状态的消息类""" + + def __init__( + self, + message_id: str, + chat_stream: "ChatStream", + bot_user_info: UserInfo, + sender_info: UserInfo | None, # 用来记录发送者信息 + message_segment: Seg, + display_message: 
str = "", + reply: Optional["MessageRecv"] = None, + is_head: bool = False, + is_emoji: bool = False, + thinking_start_time: float = 0, + apply_set_reply_logic: bool = False, + reply_to: Optional[str] = None, + selected_expressions: Optional[List[int]] = None, + ): + # 调用父类初始化 + super().__init__( + message_id=message_id, + chat_stream=chat_stream, + bot_user_info=bot_user_info, + message_segment=message_segment, + reply=reply, + thinking_start_time=thinking_start_time, + ) + + # 发送状态特有属性 + self.sender_info = sender_info + self.reply_to_message_id = reply.message_info.message_id if reply else None + self.is_head = is_head + self.is_emoji = is_emoji + self.apply_set_reply_logic = apply_set_reply_logic + + self.reply_to = reply_to + + # 用于显示发送内容与显示不一致的情况 + self.display_message = display_message + + self.interest_value = 0.0 + + self.selected_expressions = selected_expressions + + def build_reply(self): + """设置回复消息""" + if self.reply: + self.reply_to_message_id = self.reply.message_info.message_id + self.message_segment = Seg( + type="seglist", + data=[ + Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore + self.message_segment, + ], + ) + + async def process(self) -> None: + """处理消息内容,生成纯文本和详细文本""" + if self.message_segment: + self.processed_plain_text = await self._process_message_segments(self.message_segment) + + # def to_dict(self): + # ret = super().to_dict() + # ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict() + # return ret + + # def is_private_message(self) -> bool: + # """判断是否为私聊消息""" + # return self.message_info.group_info is None or self.message_info.group_info.group_id is None + + +# @dataclass +# class MessageSet: +# """消息集合类,可以存储多个发送消息""" + +# def __init__(self, chat_stream: "ChatStream", message_id: str): +# self.chat_stream = chat_stream +# self.message_id = message_id +# self.messages: list[MessageSending] = [] +# self.time = round(time.time(), 3) # 保留3位小数 + +# def add_message(self, message: 
MessageSending) -> None: +# """添加消息到集合""" +# if not isinstance(message, MessageSending): +# raise TypeError("MessageSet只能添加MessageSending类型的消息") +# self.messages.append(message) +# self.messages.sort(key=lambda x: x.message_info.time) # type: ignore + +# def get_message_by_index(self, index: int) -> Optional[MessageSending]: +# """通过索引获取消息""" +# return self.messages[index] if 0 <= index < len(self.messages) else None + +# def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: +# """获取最接近指定时间的消息""" +# if not self.messages: +# return None + +# left, right = 0, len(self.messages) - 1 +# while left < right: +# mid = (left + right) // 2 +# if self.messages[mid].message_info.time < target_time: # type: ignore +# left = mid + 1 +# else: +# right = mid + +# return self.messages[left] + +# def clear_messages(self) -> None: +# """清空所有消息""" +# self.messages.clear() + +# def remove_message(self, message: MessageSending) -> bool: +# """移除指定消息""" +# if message in self.messages: +# self.messages.remove(message) +# return True +# return False + +# def __str__(self) -> str: +# return f"MessageSet(id={self.message_id}, count={len(self.messages)})" + +# def __len__(self) -> int: +# return len(self.messages) + + +# def message_recv_from_dict(message_dict: dict) -> MessageRecv: +# return MessageRecv(message_dict) + + +# def message_from_db_dict(db_dict: dict) -> MessageRecv: +# """从数据库字典创建MessageRecv实例""" +# # 转换扁平的数据库字典为嵌套结构 +# message_info_dict = { +# "platform": db_dict.get("chat_info_platform"), +# "message_id": db_dict.get("message_id"), +# "time": db_dict.get("time"), +# "group_info": { +# "platform": db_dict.get("chat_info_group_platform"), +# "group_id": db_dict.get("chat_info_group_id"), +# "group_name": db_dict.get("chat_info_group_name"), +# }, +# "user_info": { +# "platform": db_dict.get("user_platform"), +# "user_id": db_dict.get("user_id"), +# "user_nickname": db_dict.get("user_nickname"), +# "user_cardname": db_dict.get("user_cardname"), +# }, +# 
} + +# processed_text = db_dict.get("processed_plain_text", "") + +# # 构建 MessageRecv 需要的字典 +# recv_dict = { +# "message_info": message_info_dict, +# "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段 +# "raw_message": None, # 数据库中未存储原始消息 +# "processed_plain_text": processed_text, +# } + +# # 创建 MessageRecv 实例 +# msg = MessageRecv(recv_dict) + +# # 从数据库字典中填充其他可选字段 +# msg.interest_value = db_dict.get("interest_value", 0.0) +# msg.is_mentioned = db_dict.get("is_mentioned") +# msg.priority_mode = db_dict.get("priority_mode", "interest") +# msg.priority_info = db_dict.get("priority_info") +# msg.is_emoji = db_dict.get("is_emoji", False) +# msg.is_picid = db_dict.get("is_picid", False) + +# return msg diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index de980d9c..9346259c 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -7,7 +7,7 @@ from maim_message import Seg from src.common.message_server.api import get_global_api from src.common.logger import get_logger from src.common.database.database import get_db_session -from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.message_old import MessageSending from src.chat.utils.utils import truncate_message from src.chat.utils.utils import calculate_typing_time diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 1889f144..441e8bb8 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -12,11 +12,8 @@ from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from maim_message import Seg - -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo -from 
src.chat.message_receive.message import MessageSending -from src.chat.message_receive.chat_manager import BotChatSession +from src.chat.message_receive.message_old import UserInfo, Seg, MessageRecv, MessageSending +from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self @@ -27,16 +24,16 @@ from src.chat.utils.chat_message_builder import ( replace_user_references, ) from src.bw_learner.expression_selector import expression_selector -from src.services.message_service import translate_pid_to_description +from src.plugin_system.apis.message_api import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person -from src.core.types import ActionInfo, EventType -from src.services import llm_service as llm_api +from src.plugin_system.base.component_types import ActionInfo, EventType +from src.plugin_system.apis import llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt -from src.bw_learner.jargon_explainer import explain_jargon_in_context, retrieve_concepts_with_jargon +from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon from src.chat.utils.common_utils import TempMethodsExpression init_memory_retrieval_sys() @@ -48,17 +45,17 @@ logger = get_logger("replyer") class DefaultReplyer: def __init__( self, - chat_stream: BotChatSession, + chat_stream: ChatStream, request_type: str = "replyer", ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = 
get_chat_type_and_target_info(self.chat_stream.session_id) + self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.heart_fc_sender = UniversalMessageSender() - from src.chat.tool_executor import ToolExecutor + from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 - self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) + self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) async def generate_reply_with_context( self, @@ -135,7 +132,7 @@ class DefaultReplyer: if log_reply: try: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.session_id, + chat_id=self.chat_stream.stream_id, prompt="", output=None, processed_output=None, @@ -149,13 +146,11 @@ class DefaultReplyer: except Exception: logger.exception("记录reply日志失败") return False, llm_response - from src.core.event_bus import event_bus - from src.chat.event_helpers import build_event_message + from src.plugin_system.core.events_manager import events_manager if not from_plugin: - _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id) - continue_flag, modified_message = await event_bus.emit( - EventType.POST_LLM, _event_msg + continue_flag, modified_message = await events_manager.handle_mai_events( + EventType.POST_LLM, None, prompt, None, stream_id=stream_id ) if not continue_flag: raise UserWarning("插件于请求前中断了内容生成") @@ -207,7 +202,7 @@ class DefaultReplyer: try: if log_reply: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.session_id, + chat_id=self.chat_stream.stream_id, prompt=prompt, output=content, processed_output=None, @@ -219,9 +214,8 @@ class DefaultReplyer: ) except Exception: logger.exception("记录reply日志失败") - _event_msg = build_event_message(EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id) - continue_flag, modified_message = await event_bus.emit( - 
EventType.AFTER_LLM, _event_msg + continue_flag, modified_message = await events_manager.handle_mai_events( + EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id ) if not from_plugin and not continue_flag: raise UserWarning("插件于请求后取消了内容生成") @@ -265,7 +259,7 @@ class DefaultReplyer: if log_reply: try: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.session_id, + chat_id=self.chat_stream.stream_id, prompt=prompt or "", output=None, processed_output=None, @@ -359,14 +353,14 @@ class DefaultReplyer: str: 表达习惯信息字符串 """ # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) if not use_expression: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( - self.chat_stream.session_id, + self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, @@ -708,11 +702,10 @@ class DefaultReplyer: # 判断是否为群聊 is_group = stream_type == "group" - from src.common.utils.utils_session import SessionUtils + # 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑 + from src.chat.message_receive.chat_stream import get_chat_manager - chat_id = SessionUtils.calculate_session_id( - platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None - ) + chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) return chat_id, prompt_content except (ValueError, IndexError): @@ -785,7 +778,7 @@ class DefaultReplyer: if available_actions is None: available_actions = {} chat_stream = self.chat_stream - chat_id = chat_stream.session_id + chat_id = chat_stream.stream_id _is_group_chat = bool(chat_stream.group_info) platform = chat_stream.platform @@ -1012,7 +1005,7 @@ class DefaultReplyer: reply_to: str, ) -> str: # sourcery skip: 
merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream - chat_id = chat_stream.session_id + chat_id = chat_stream.stream_id sender, target = self._parse_reply_target(reply_to) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) @@ -1112,27 +1105,29 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: Optional[MaiMessage] = None, + anchor_message: Optional[MessageRecv] = None, ) -> MessageSending: """构建单个发送消息""" bot_user_info = UserInfo( user_id=str(global_config.bot.qq_account), user_nickname=global_config.bot.nickname, + platform=self.chat_stream.platform, ) + # await anchor_message.process() sender_info = anchor_message.message_info.user_info if anchor_message else None return MessageSending( - message_id=message_id, - session=self.chat_stream, + message_id=message_id, # 使用片段的唯一ID + chat_stream=self.chat_stream, bot_user_info=bot_user_info, sender_info=sender_info, message_segment=message_segment, - reply=anchor_message, + reply=anchor_message, # 回复原始锚点 is_head=reply_to, is_emoji=is_emoji, - thinking_start_time=thinking_start_time, + thinking_start_time=thinking_start_time, # 传递原始思考开始时间 display_message=display_message, ) diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index ee6f98bc..735e0afe 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -15,7 +15,7 @@ from src.llm_models.utils_model import LLMRequest from maim_message import Seg from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo -from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.message_old import MessageSending from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer @@ -35,7 +35,7 @@ from 
src.person_info.person_info import Person, is_person_known from src.core.types import ActionInfo, EventType from src.services import llm_service as llm_api from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt -from src.bw_learner.jargon_explainer import explain_jargon_in_context +from src.bw_learner.jargon_explainer_old import explain_jargon_in_context init_memory_retrieval_sys() diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 5bd35873..63e6bdce 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,212 +1,211 @@ import json -from typing import Optional, Any, Dict -from dataclasses import dataclass, field +from dataclasses import dataclass from . import BaseDataModel -@dataclass -class DatabaseUserInfo(BaseDataModel): - platform: str = field(default_factory=str) - user_id: str = field(default_factory=str) - user_nickname: str = field(default_factory=str) - user_cardname: Optional[str] = None +# @dataclass +# class DatabaseUserInfo(BaseDataModel): +# platform: str = field(default_factory=str) +# user_id: str = field(default_factory=str) +# user_nickname: str = field(default_factory=str) +# user_cardname: Optional[str] = None - # def __post_init__(self): - # assert isinstance(self.platform, str), "platform must be a string" - # assert isinstance(self.user_id, str), "user_id must be a string" - # assert isinstance(self.user_nickname, str), "user_nickname must be a string" - # assert isinstance(self.user_cardname, str) or self.user_cardname is None, ( - # "user_cardname must be a string or None" - # ) +# # def __post_init__(self): +# # assert isinstance(self.platform, str), "platform must be a string" +# # assert isinstance(self.user_id, str), "user_id must be a string" +# # assert isinstance(self.user_nickname, str), "user_nickname must be a string" +# # assert isinstance(self.user_cardname, 
str) or self.user_cardname is None, ( +# # "user_cardname must be a string or None" +# # ) -@dataclass -class DatabaseGroupInfo(BaseDataModel): - group_id: str = field(default_factory=str) - group_name: str = field(default_factory=str) - group_platform: Optional[str] = None +# @dataclass +# class DatabaseGroupInfo(BaseDataModel): +# group_id: str = field(default_factory=str) +# group_name: str = field(default_factory=str) +# group_platform: Optional[str] = None - # def __post_init__(self): - # assert isinstance(self.group_id, str), "group_id must be a string" - # assert isinstance(self.group_name, str), "group_name must be a string" - # assert isinstance(self.group_platform, str) or self.group_platform is None, ( - # "group_platform must be a string or None" - # ) +# # def __post_init__(self): +# # assert isinstance(self.group_id, str), "group_id must be a string" +# # assert isinstance(self.group_name, str), "group_name must be a string" +# # assert isinstance(self.group_platform, str) or self.group_platform is None, ( +# # "group_platform must be a string or None" +# # ) -@dataclass -class DatabaseChatInfo(BaseDataModel): - stream_id: str = field(default_factory=str) - platform: str = field(default_factory=str) - create_time: float = field(default_factory=float) - last_active_time: float = field(default_factory=float) - user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) - group_info: Optional[DatabaseGroupInfo] = None +# @dataclass +# class DatabaseChatInfo(BaseDataModel): +# stream_id: str = field(default_factory=str) +# platform: str = field(default_factory=str) +# create_time: float = field(default_factory=float) +# last_active_time: float = field(default_factory=float) +# user_info: DatabaseUserInfo = field(default_factory=DatabaseUserInfo) +# group_info: Optional[DatabaseGroupInfo] = None - # def __post_init__(self): - # assert isinstance(self.stream_id, str), "stream_id must be a string" - # assert isinstance(self.platform, str), 
"platform must be a string" - # assert isinstance(self.create_time, float), "create_time must be a float" - # assert isinstance(self.last_active_time, float), "last_active_time must be a float" - # assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance" - # assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, ( - # "group_info must be a DatabaseGroupInfo instance or None" - # ) +# # def __post_init__(self): +# # assert isinstance(self.stream_id, str), "stream_id must be a string" +# # assert isinstance(self.platform, str), "platform must be a string" +# # assert isinstance(self.create_time, float), "create_time must be a float" +# # assert isinstance(self.last_active_time, float), "last_active_time must be a float" +# # assert isinstance(self.user_info, DatabaseUserInfo), "user_info must be a DatabaseUserInfo instance" +# # assert isinstance(self.group_info, DatabaseGroupInfo) or self.group_info is None, ( +# # "group_info must be a DatabaseGroupInfo instance or None" +# # ) -@dataclass(init=False) -class DatabaseMessages(BaseDataModel): - def __init__( - self, - message_id: str = "", - time: float = 0.0, - chat_id: str = "", - reply_to: Optional[str] = None, - interest_value: Optional[float] = None, - key_words: Optional[str] = None, - key_words_lite: Optional[str] = None, - is_mentioned: Optional[bool] = None, - is_at: Optional[bool] = None, - reply_probability_boost: Optional[float] = None, - processed_plain_text: Optional[str] = None, - display_message: Optional[str] = None, - priority_mode: Optional[str] = None, - priority_info: Optional[str] = None, - additional_config: Optional[str] = None, - is_emoji: bool = False, - is_picid: bool = False, - is_command: bool = False, - intercept_message_level: int = 0, - is_notify: bool = False, - selected_expressions: Optional[str] = None, - user_id: str = "", - user_nickname: str = "", - user_cardname: Optional[str] = None, - user_platform: str = 
"", - chat_info_group_id: Optional[str] = None, - chat_info_group_name: Optional[str] = None, - chat_info_group_platform: Optional[str] = None, - chat_info_user_id: str = "", - chat_info_user_nickname: str = "", - chat_info_user_cardname: Optional[str] = None, - chat_info_user_platform: str = "", - chat_info_stream_id: str = "", - chat_info_platform: str = "", - chat_info_create_time: float = 0.0, - chat_info_last_active_time: float = 0.0, - **kwargs: Any, - ): - self.message_id = message_id - self.time = time - self.chat_id = chat_id - self.reply_to = reply_to - self.interest_value = interest_value +# @dataclass(init=False) +# class DatabaseMessages(BaseDataModel): +# def __init__( +# self, +# message_id: str = "", +# time: float = 0.0, +# chat_id: str = "", +# reply_to: Optional[str] = None, +# interest_value: Optional[float] = None, +# key_words: Optional[str] = None, +# key_words_lite: Optional[str] = None, +# is_mentioned: Optional[bool] = None, +# is_at: Optional[bool] = None, +# reply_probability_boost: Optional[float] = None, +# processed_plain_text: Optional[str] = None, +# display_message: Optional[str] = None, +# priority_mode: Optional[str] = None, +# priority_info: Optional[str] = None, +# additional_config: Optional[str] = None, +# is_emoji: bool = False, +# is_picid: bool = False, +# is_command: bool = False, +# intercept_message_level: int = 0, +# is_notify: bool = False, +# selected_expressions: Optional[str] = None, +# user_id: str = "", +# user_nickname: str = "", +# user_cardname: Optional[str] = None, +# user_platform: str = "", +# chat_info_group_id: Optional[str] = None, +# chat_info_group_name: Optional[str] = None, +# chat_info_group_platform: Optional[str] = None, +# chat_info_user_id: str = "", +# chat_info_user_nickname: str = "", +# chat_info_user_cardname: Optional[str] = None, +# chat_info_user_platform: str = "", +# chat_info_stream_id: str = "", +# chat_info_platform: str = "", +# chat_info_create_time: float = 0.0, +# 
chat_info_last_active_time: float = 0.0, +# **kwargs: Any, +# ): +# self.message_id = message_id +# self.time = time +# self.chat_id = chat_id +# self.reply_to = reply_to +# self.interest_value = interest_value - self.key_words = key_words - self.key_words_lite = key_words_lite - self.is_mentioned = is_mentioned +# self.key_words = key_words +# self.key_words_lite = key_words_lite +# self.is_mentioned = is_mentioned - self.is_at = is_at - self.reply_probability_boost = reply_probability_boost +# self.is_at = is_at +# self.reply_probability_boost = reply_probability_boost - self.processed_plain_text = processed_plain_text - self.display_message = display_message +# self.processed_plain_text = processed_plain_text +# self.display_message = display_message - self.priority_mode = priority_mode - self.priority_info = priority_info +# self.priority_mode = priority_mode +# self.priority_info = priority_info - self.additional_config = additional_config - self.is_emoji = is_emoji - self.is_picid = is_picid - self.is_command = is_command - self.intercept_message_level = intercept_message_level - self.is_notify = is_notify +# self.additional_config = additional_config +# self.is_emoji = is_emoji +# self.is_picid = is_picid +# self.is_command = is_command +# self.intercept_message_level = intercept_message_level +# self.is_notify = is_notify - self.selected_expressions = selected_expressions +# self.selected_expressions = selected_expressions - self.group_info: Optional[DatabaseGroupInfo] = None - self.user_info = DatabaseUserInfo( - user_id=user_id, - user_nickname=user_nickname, - user_cardname=user_cardname, - platform=user_platform, - ) - if chat_info_group_id and chat_info_group_name: - self.group_info = DatabaseGroupInfo( - group_id=chat_info_group_id, - group_name=chat_info_group_name, - group_platform=chat_info_group_platform, - ) +# self.group_info: Optional[DatabaseGroupInfo] = None +# self.user_info = DatabaseUserInfo( +# user_id=user_id, +# 
user_nickname=user_nickname, +# user_cardname=user_cardname, +# platform=user_platform, +# ) +# if chat_info_group_id and chat_info_group_name: +# self.group_info = DatabaseGroupInfo( +# group_id=chat_info_group_id, +# group_name=chat_info_group_name, +# group_platform=chat_info_group_platform, +# ) - self.chat_info = DatabaseChatInfo( - stream_id=chat_info_stream_id, - platform=chat_info_platform, - create_time=chat_info_create_time, - last_active_time=chat_info_last_active_time, - user_info=DatabaseUserInfo( - user_id=chat_info_user_id, - user_nickname=chat_info_user_nickname, - user_cardname=chat_info_user_cardname, - platform=chat_info_user_platform, - ), - group_info=self.group_info, - ) +# self.chat_info = DatabaseChatInfo( +# stream_id=chat_info_stream_id, +# platform=chat_info_platform, +# create_time=chat_info_create_time, +# last_active_time=chat_info_last_active_time, +# user_info=DatabaseUserInfo( +# user_id=chat_info_user_id, +# user_nickname=chat_info_user_nickname, +# user_cardname=chat_info_user_cardname, +# platform=chat_info_user_platform, +# ), +# group_info=self.group_info, +# ) - if kwargs: - for key, value in kwargs.items(): - setattr(self, key, value) +# if kwargs: +# for key, value in kwargs.items(): +# setattr(self, key, value) - # def __post_init__(self): - # assert isinstance(self.message_id, str), "message_id must be a string" - # assert isinstance(self.time, float), "time must be a float" - # assert isinstance(self.chat_id, str), "chat_id must be a string" - # assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None" - # assert isinstance(self.interest_value, float) or self.interest_value is None, ( - # "interest_value must be a float or None" - # ) - def flatten(self) -> Dict[str, Any]: - """ - 将消息数据模型转换为字典格式,便于存储或传输 - """ - return { - "message_id": self.message_id, - "time": self.time, - "chat_id": self.chat_id, - "reply_to": self.reply_to, - "interest_value": self.interest_value, - 
"key_words": self.key_words, - "key_words_lite": self.key_words_lite, - "is_mentioned": self.is_mentioned, - "is_at": self.is_at, - "reply_probability_boost": self.reply_probability_boost, - "processed_plain_text": self.processed_plain_text, - "display_message": self.display_message, - "priority_mode": self.priority_mode, - "priority_info": self.priority_info, - "additional_config": self.additional_config, - "is_emoji": self.is_emoji, - "is_picid": self.is_picid, - "is_command": self.is_command, - "intercept_message_level": self.intercept_message_level, - "is_notify": self.is_notify, - "selected_expressions": self.selected_expressions, - "user_id": self.user_info.user_id, - "user_nickname": self.user_info.user_nickname, - "user_cardname": self.user_info.user_cardname, - "user_platform": self.user_info.platform, - "chat_info_group_id": self.group_info.group_id if self.group_info else None, - "chat_info_group_name": self.group_info.group_name if self.group_info else None, - "chat_info_group_platform": self.group_info.group_platform if self.group_info else None, - "chat_info_stream_id": self.chat_info.stream_id, - "chat_info_platform": self.chat_info.platform, - "chat_info_create_time": self.chat_info.create_time, - "chat_info_last_active_time": self.chat_info.last_active_time, - "chat_info_user_platform": self.chat_info.user_info.platform, - "chat_info_user_id": self.chat_info.user_info.user_id, - "chat_info_user_nickname": self.chat_info.user_info.user_nickname, - "chat_info_user_cardname": self.chat_info.user_info.user_cardname, - } +# # def __post_init__(self): +# # assert isinstance(self.message_id, str), "message_id must be a string" +# # assert isinstance(self.time, float), "time must be a float" +# # assert isinstance(self.chat_id, str), "chat_id must be a string" +# # assert isinstance(self.reply_to, str) or self.reply_to is None, "reply_to must be a string or None" +# # assert isinstance(self.interest_value, float) or self.interest_value is None, ( +# # 
"interest_value must be a float or None" +# # ) +# def flatten(self) -> Dict[str, Any]: +# """ +# 将消息数据模型转换为字典格式,便于存储或传输 +# """ +# return { +# "message_id": self.message_id, +# "time": self.time, +# "chat_id": self.chat_id, +# "reply_to": self.reply_to, +# "interest_value": self.interest_value, +# "key_words": self.key_words, +# "key_words_lite": self.key_words_lite, +# "is_mentioned": self.is_mentioned, +# "is_at": self.is_at, +# "reply_probability_boost": self.reply_probability_boost, +# "processed_plain_text": self.processed_plain_text, +# "display_message": self.display_message, +# "priority_mode": self.priority_mode, +# "priority_info": self.priority_info, +# "additional_config": self.additional_config, +# "is_emoji": self.is_emoji, +# "is_picid": self.is_picid, +# "is_command": self.is_command, +# "intercept_message_level": self.intercept_message_level, +# "is_notify": self.is_notify, +# "selected_expressions": self.selected_expressions, +# "user_id": self.user_info.user_id, +# "user_nickname": self.user_info.user_nickname, +# "user_cardname": self.user_info.user_cardname, +# "user_platform": self.user_info.platform, +# "chat_info_group_id": self.group_info.group_id if self.group_info else None, +# "chat_info_group_name": self.group_info.group_name if self.group_info else None, +# "chat_info_group_platform": self.group_info.group_platform if self.group_info else None, +# "chat_info_stream_id": self.chat_info.stream_id, +# "chat_info_platform": self.chat_info.platform, +# "chat_info_create_time": self.chat_info.create_time, +# "chat_info_last_active_time": self.chat_info.last_active_time, +# "chat_info_user_platform": self.chat_info.user_info.platform, +# "chat_info_user_id": self.chat_info.user_info.user_id, +# "chat_info_user_nickname": self.chat_info.user_info.user_nickname, +# "chat_info_user_cardname": self.chat_info.user_info.user_cardname, +# } @dataclass(init=False) diff --git a/src/common/data_models/expression_data_model.py 
b/src/common/data_models/expression_data_model.py index 192cbc58..7b4ad9ef 100644 --- a/src/common/data_models/expression_data_model.py +++ b/src/common/data_models/expression_data_model.py @@ -11,7 +11,6 @@ from . import BaseDatabaseDataModel class MaiExpression(BaseDatabaseDataModel[Expression]): def __init__( self, - item_id: int, situation: str, style: str, # context: str, @@ -20,6 +19,7 @@ class MaiExpression(BaseDatabaseDataModel[Expression]): count: int, last_active_time: datetime, create_time: datetime, + item_id: Optional[int] = None, session_id: Optional[str] = None, checked: bool = False, rejected: bool = False, @@ -55,7 +55,7 @@ class MaiExpression(BaseDatabaseDataModel[Expression]): if not isinstance(item, str): raise ValueError(f"Content item must be a string, got {type(item)}") return cls( - item_id=db_record.id, # type: ignore + item_id=db_record.id, situation=db_record.situation, style=db_record.style, # context=db_record.context, @@ -74,7 +74,6 @@ class MaiExpression(BaseDatabaseDataModel[Expression]): if not isinstance(item, str): raise ValueError(f"Content item must be a string, got {type(item)}") return Expression( - id=self.item_id, situation=self.situation, style=self.style, # context=self.context, diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index 13e53eb3..ce0781e1 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -7,13 +7,13 @@ if TYPE_CHECKING: from src.core.types import ActionInfo -@dataclass -class TargetPersonInfo(BaseDataModel): - platform: str = field(default_factory=str) - user_id: str = field(default_factory=str) - user_nickname: str = field(default_factory=str) - person_id: Optional[str] = None - person_name: Optional[str] = None +# @dataclass +# class TargetPersonInfo(BaseDataModel): +# platform: str = field(default_factory=str) +# user_id: str = field(default_factory=str) +# user_nickname: str = 
field(default_factory=str) +# person_id: Optional[str] = None +# person_name: Optional[str] = None @dataclass diff --git a/src/common/data_models/jargon_data_model.py b/src/common/data_models/jargon_data_model.py index ef4deeff..b0fff82e 100644 --- a/src/common/data_models/jargon_data_model.py +++ b/src/common/data_models/jargon_data_model.py @@ -1,9 +1,14 @@ -from typing import Optional +from typing import Optional, Dict + +import json from src.common.database.database_model import Jargon +from src.common.logger import get_logger from . import BaseDatabaseDataModel +logger = get_logger("jargon_data_model") + class MaiJargon(BaseDatabaseDataModel[Jargon]): """Jargon 数据模型,与数据库模型 Jargon 互转。""" @@ -12,28 +17,37 @@ class MaiJargon(BaseDatabaseDataModel[Jargon]): self, content: str, meaning: str, + item_id: Optional[int] = None, raw_content: Optional[str] = None, - session_id: Optional[str] = None, + session_id_list: Optional[Dict[str, int]] = None, count: int = 0, is_jargon: Optional[bool] = True, is_complete: bool = False, + is_global: bool = False, + last_inference_count: int = 0, inference_with_context: Optional[str] = None, inference_with_content_only: Optional[str] = None, ): + self.item_id = item_id + """自增主键ID""" self.content = content """黑话内容""" self.raw_content = raw_content """原始内容,未处理的黑话内容""" self.meaning = meaning """黑话含义""" - self.session_id = session_id - """会话ID,区分是否为全局黑话""" + self.session_id_list = session_id_list or {} + """会话ID字典,区分是否为全局黑话,格式为{"session_id": session_count, ...},如果为空表示全局黑话""" self.count = count """使用次数""" self.is_jargon = is_jargon """是否为黑话,False表示为白话""" self.is_complete = is_complete """是否为已经完成全部推断(count > 100后不再推断)""" + self.is_global = is_global + """是否为全局黑话(独立于session_id_dict)""" + self.last_inference_count = last_inference_count + """上一次进行推断时的count值,用于判断是否需要重新推断""" self.inference_with_context = inference_with_context """带上下文的推断结果,JSON格式""" self.inference_with_content_only = inference_with_content_only @@ -42,28 +56,40 @@ class 
MaiJargon(BaseDatabaseDataModel[Jargon]): @classmethod def from_db_instance(cls, db_record: Jargon) -> "MaiJargon": """从数据库模型创建 MaiJargon 实例。""" + json_list: Dict[str, int] = {} + try: + # 解析存储的字符串为字典 + json_list = json.loads(db_record.session_id_dict) + except Exception as e: + logger.error(f"Error parsing session_id_list: {e}") return cls( + item_id=db_record.id, content=db_record.content, meaning=db_record.meaning, raw_content=db_record.raw_content, - session_id=db_record.session_id, + session_id_list=json_list, count=db_record.count, is_jargon=db_record.is_jargon, is_complete=db_record.is_complete, + is_global=db_record.is_global, + last_inference_count=db_record.last_inference_count, inference_with_context=db_record.inference_with_context, inference_with_content_only=db_record.inference_with_content_only, ) def to_db_instance(self) -> Jargon: """将 MaiJargon 转换为数据库模型 Jargon。""" + dumped_session_id_list = json.dumps(self.session_id_list) return Jargon( content=self.content, raw_content=self.raw_content, meaning=self.meaning, - session_id=self.session_id, + session_id_dict=dumped_session_id_list, count=self.count, is_jargon=self.is_jargon, is_complete=self.is_complete, + is_global=self.is_global, + last_inference_count=self.last_inference_count, inference_with_context=self.inference_with_context, inference_with_content_only=self.inference_with_content_only, ) diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py index 63a30dc0..3f007643 100644 --- a/src/common/data_models/message_data_model.py +++ b/src/common/data_models/message_data_model.py @@ -1,79 +1,79 @@ -from dataclasses import dataclass -from enum import Enum -from typing import Any, Iterable, List, Optional, Tuple, Union +# from dataclasses import dataclass +# from enum import Enum +# from typing import Any, Iterable, List, Optional, Tuple, Union -from . import BaseDataModel +# from . 
import BaseDataModel -class ReplyContentType(Enum): - TEXT = "text" - IMAGE = "image" - EMOJI = "emoji" - COMMAND = "command" - VOICE = "voice" - HYBRID = "hybrid" - FORWARD = "forward" +# class ReplyContentType(Enum): +# TEXT = "text" +# IMAGE = "image" +# EMOJI = "emoji" +# COMMAND = "command" +# VOICE = "voice" +# HYBRID = "hybrid" +# FORWARD = "forward" - def __str__(self) -> str: - return self.value +# def __str__(self) -> str: +# return self.value -@dataclass -class ReplyContent: - content_type: ReplyContentType | str - content: Any +# @dataclass +# class ReplyContent: +# content_type: ReplyContentType | str +# content: Any -@dataclass -class ForwardNode: - user_id: Optional[str] = None - user_nickname: Optional[str] = None - content: Union[str, List[ReplyContent], None] = None +# @dataclass +# class ForwardNode: +# user_id: Optional[str] = None +# user_nickname: Optional[str] = None +# content: Union[str, List[ReplyContent], None] = None - @classmethod - def construct_as_id_reference(cls, message_id: str) -> "ForwardNode": - return cls(content=message_id) +# @classmethod +# def construct_as_id_reference(cls, message_id: str) -> "ForwardNode": +# return cls(content=message_id) - @classmethod - def construct_as_created_node( - cls, - user_id: str, - user_nickname: str, - content: List[ReplyContent], - ) -> "ForwardNode": - return cls(user_id=user_id, user_nickname=user_nickname, content=content) +# @classmethod +# def construct_as_created_node( +# cls, +# user_id: str, +# user_nickname: str, +# content: List[ReplyContent], +# ) -> "ForwardNode": +# return cls(user_id=user_id, user_nickname=user_nickname, content=content) -class ReplySetModel(BaseDataModel): - def __init__(self) -> None: - self.reply_data: List[ReplyContent] = [] +# class ReplySetModel(BaseDataModel): +# def __init__(self) -> None: +# self.reply_data: List[ReplyContent] = [] - def __len__(self) -> int: - return len(self.reply_data) +# def __len__(self) -> int: +# return len(self.reply_data) - 
def add_text_content(self, text: str) -> None: - self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text)) +# def add_text_content(self, text: str) -> None: +# self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text)) - def add_voice_content(self, voice_base64: str) -> None: - self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64)) +# def add_voice_content(self, voice_base64: str) -> None: +# self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64)) - def add_hybrid_content_by_raw(self, message_tuple_list: Iterable[Tuple[ReplyContentType | str, str]]) -> None: - hybrid_contents: List[ReplyContent] = [] - for content_type, content in message_tuple_list: - hybrid_contents.append( - ReplyContent(content_type=self._normalize_content_type(content_type), content=content) - ) - self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_contents)) +# def add_hybrid_content_by_raw(self, message_tuple_list: Iterable[Tuple[ReplyContentType | str, str]]) -> None: +# hybrid_contents: List[ReplyContent] = [] +# for content_type, content in message_tuple_list: +# hybrid_contents.append( +# ReplyContent(content_type=self._normalize_content_type(content_type), content=content) +# ) +# self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_contents)) - def add_forward_content(self, forward_nodes: List[ForwardNode]) -> None: - self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_nodes)) +# def add_forward_content(self, forward_nodes: List[ForwardNode]) -> None: +# self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_nodes)) - @staticmethod - def _normalize_content_type(content_type: ReplyContentType | str) -> ReplyContentType | str: - if isinstance(content_type, ReplyContentType): - return content_type - if 
isinstance(content_type, str): - for item in ReplyContentType: - if item.value == content_type: - return item - return content_type +# @staticmethod +# def _normalize_content_type(content_type: ReplyContentType | str) -> ReplyContentType | str: +# if isinstance(content_type, ReplyContentType): +# return content_type +# if isinstance(content_type, str): +# for item in ReplyContentType: +# if item.value == content_type: +# return item +# return content_type diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 1e12604d..a0993a77 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -201,14 +201,16 @@ class Jargon(SQLModel, table=True): id: Optional[int] = Field(default=None, primary_key=True) # 自增主键 content: str = Field(index=True, max_length=255, primary_key=True) # 黑话内容 - raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容 + raw_content: Optional[str] = Field(default=None, nullable=True) # 原始内容,未处理的黑话内容,为List[str] meaning: str # 黑话含义 - session_id: Optional[str] = Field(default=None, max_length=255, nullable=True) # 会话ID,区分是否为全局黑话 + session_id_dict: str = Field(default=r"{}") # 会话ID列表,格式为{"session_id": session_count, ...} count: int = Field(default=0) # 使用次数 is_jargon: Optional[bool] = Field(default=True) # 是否为黑话,False表示为白话 is_complete: bool = Field(default=False) # 是否为已经完成全部推断(count > 100后不再推断) + is_global: bool = Field(default=False) # 是否为全局黑话(独立于session_id_dict) + last_inference_count: int = Field(default=0) # 上一次进行推断时的count值,用于判断是否需要重新推断 inference_with_context: Optional[str] = Field(default=None, nullable=True) # 带上下文的推断结果,JSON格式 inference_with_content_only: Optional[str] = Field(default=None, nullable=True) # 只基于词条的推断结果,JSON格式 diff --git a/src/dream/tools/search_jargon_tool.py b/src/dream/tools/search_jargon_tool.py index 139536ac..8da8fe77 100644 --- a/src/dream/tools/search_jargon_tool.py +++ b/src/dream/tools/search_jargon_tool.py @@ -4,7 +4,7 
@@ from src.common.logger import get_logger from src.common.database.database_model import Jargon from src.config.config import global_config from src.chat.utils.utils import parse_keywords_string -from src.bw_learner.learner_utils import parse_chat_id_list, chat_id_list_contains +from src.bw_learner.learner_utils_old import parse_chat_id_list, chat_id_list_contains logger = get_logger("dream_agent") diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index 7dabdd98..dceb339a 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -59,15 +59,15 @@ class TopicCacheItem: class ChatHistorySummarizer: """聊天内容概括器""" - def __init__(self, chat_id: str, check_interval: int = 60): + def __init__(self, session_id: str, check_interval: int = 60): """ 初始化聊天内容概括器 Args: - chat_id: 聊天ID + session_id: 会话ID check_interval: 定期检查间隔(秒),默认60秒 """ - self.chat_id = chat_id + self.session_id = session_id self._chat_display_name = self._get_chat_display_name() self.log_prefix = f"[{self._chat_display_name}]" @@ -83,7 +83,7 @@ class ChatHistorySummarizer: # 话题缓存:topic_str -> TopicCacheItem # 在内存中维护,并通过本地文件实时持久化 self.topic_cache: Dict[str, TopicCacheItem] = {} - self._safe_chat_id = self._sanitize_chat_id(self.chat_id) + self._safe_chat_id = self._sanitize_chat_id(self.session_id) self._topic_cache_file = HIPPO_CACHE_DIR / f"{self._safe_chat_id}.json" # 注意:批次加载需要异步查询消息,所以在 start() 中调用 @@ -104,14 +104,14 @@ class ChatHistorySummarizer: if chat_name: return chat_name # 如果获取失败,使用简化的chat_id显示 - if len(self.chat_id) > 20: - return f"{self.chat_id[:8]}..." - return self.chat_id + if len(self.session_id) > 20: + return f"{self.session_id[:8]}..." + return self.session_id except Exception: # 如果获取失败,使用简化的chat_id显示 - if len(self.chat_id) > 20: - return f"{self.chat_id[:8]}..." - return self.chat_id + if len(self.session_id) > 20: + return f"{self.session_id[:8]}..." 
+ return self.session_id def _sanitize_chat_id(self, chat_id: str) -> str: """用于生成可作为文件名的 chat_id""" @@ -163,7 +163,7 @@ class ChatHistorySummarizer: # 根据时间范围重新查询消息 messages = message_api.get_messages_by_time_in_chat( - chat_id=self.chat_id, + chat_id=self.session_id, start_time=start_time, end_time=end_time, limit=0, @@ -193,7 +193,7 @@ class ChatHistorySummarizer: HIPPO_CACHE_DIR.mkdir(parents=True, exist_ok=True) data = { - "chat_id": self.chat_id, + "chat_id": self.session_id, "last_topic_check_time": self.last_topic_check_time, "topics": { topic: { @@ -230,7 +230,7 @@ class ChatHistorySummarizer: try: # 获取从上次检查时间到当前时间的新消息 new_messages = message_api.get_messages_by_time_in_chat( - chat_id=self.chat_id, + chat_id=self.session_id, start_time=self.last_check_time, end_time=current_time, limit=0, @@ -917,7 +917,7 @@ class ChatHistorySummarizer: # 准备数据 data = { - "chat_id": self.chat_id, + "chat_id": self.session_id, "start_time": start_time, "end_time": end_time, "original_text": original_text, diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index a14a9c67..745acb28 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -14,7 +14,7 @@ from src.common.database.database_model import ThinkingQuestion from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon +from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon logger = get_logger("memory_retrieval") diff --git a/src/memory_system/retrieval_tools/query_words.py b/src/memory_system/retrieval_tools/query_words.py index 9bdf8ba2..66fb3c46 100644 --- a/src/memory_system/retrieval_tools/query_words.py +++ b/src/memory_system/retrieval_tools/query_words.py @@ 
-4,7 +4,7 @@ """ from src.common.logger import get_logger -from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon +from src.bw_learner.jargon_explainer_old import retrieve_concepts_with_jargon from .tool_registry import register_memory_retrieval_tool logger = get_logger("memory_retrieval_tools")