From 958d6e04eed5a0ff113c0f46813646f8492c4574 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 11 Oct 2025 02:03:03 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E8=A1=A8=E8=BE=BE=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=EF=BC=8C=E7=8E=B0=E5=9C=A8=E4=BC=9A=E8=AE=AD?= =?UTF-8?q?=E7=BB=83=E6=9C=B4=E7=B4=A0=E8=B4=9D=E5=8F=B6=E6=96=AF=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E6=9D=A5=E9=A2=84=E6=B5=8B=E4=BD=BF=E7=94=A8=E4=BB=80?= =?UTF-8?q?=E4=B9=88=E8=A1=A8=E8=BE=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- src/chat/brain_chat/brain_chat.py | 2 +- src/chat/express/expression_selector.py | 316 --------- src/chat/heart_flow/heartFC_chat.py | 7 +- src/chat/replyer/group_generator.py | 6 +- src/chat/replyer/private_generator.py | 6 +- src/common/database/database_model.py | 3 +- src/config/official_configs.py | 4 +- src/{chat => }/express/expression_learner.py | 226 ++++--- src/express/expression_selector.py | 520 +++++++++++++++ src/express/expressor_model/model.py | 131 ++++ src/express/expressor_model/online_nb.py | 60 ++ src/express/expressor_model/tokenizer.py | 28 + src/express/style_learner.py | 628 ++++++++++++++++++ template/bot_config_template.toml | 6 +- test_expression_selector_prediction.py | 152 +++++ ..._expression_style_situation_integration.py | 188 ++++++ test_style_learner_db.py | 391 +++++++++++ view_pkl.py | 76 +++ view_tokens.py | 63 ++ 20 files changed, 2372 insertions(+), 443 deletions(-) delete mode 100644 src/chat/express/expression_selector.py rename src/{chat => }/express/expression_learner.py (75%) create mode 100644 src/express/expression_selector.py create mode 100644 src/express/expressor_model/model.py create mode 100644 src/express/expressor_model/online_nb.py create mode 100644 src/express/expressor_model/tokenizer.py create mode 100644 src/express/style_learner.py create mode 100644 test_expression_selector_prediction.py create mode 100644 test_expression_style_situation_integration.py create mode 100644 test_style_learner_db.py create mode 100644 view_pkl.py create mode 100644 view_tokens.py diff --git a/.gitignore b/.gitignore index 8d07a009..e43df6d2 100644 --- a/.gitignore +++ b/.gitignore @@ -323,7 +323,7 @@ run_pet.bat !/plugins/emoji_manage_plugin !/plugins/take_picture_plugin !/plugins/deep_think -!/plugins/MaiFrequencyControl +!/plugins/BetterFrequency !/plugins/__init__.py config.toml diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 1f5e1767..b5d2cec7 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -16,7 +16,7 @@ from src.chat.brain_chat.brain_planner import BrainPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail -from src.chat.express.expression_learner import expression_learner_manager +from src.express.expression_learner import expression_learner_manager from src.person_info.person_info import Person from src.plugin_system.base.component_types import EventType, ActionInfo from src.plugin_system.core import events_manager diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py deleted file mode 100644 index dbb40f6a..00000000 --- a/src/chat/express/expression_selector.py +++ /dev/null @@ -1,316 +0,0 @@ -import json -import time -import random -import hashlib - -from typing import List, Dict, Optional, Any, Tuple -from json_repair import repair_json - -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config, model_config -from src.common.logger import get_logger -from src.common.database.database_model import Expression -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager - -logger = get_logger("expression_selector") - - -def init_prompt(): - expression_evaluation_prompt = """ -以下是正在进行的聊天内容: -{chat_observe_info} - -你的名字是{bot_name}{target_message} - -以下是可选的表达情境: -{all_situations} - -请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。 -考虑因素包括: -1. 聊天的情绪氛围(轻松、严肃、幽默等) -2. 话题类型(日常、技术、游戏、情感等) -3. 情境与当前语境的匹配度 -{target_message_extra_block} - -请以JSON格式输出,只需要输出选中的情境编号: -例如: -{{ - "selected_situations": [2, 3, 5, 7, 19] -}} - -请严格按照JSON格式输出,不要包含其他内容: -""" - Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") - - -def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]: - """按权重随机抽样""" - if not population or not weights or k <= 0: - return [] - - if len(population) <= k: - return population.copy() - - # 使用累积权重的方法进行加权抽样 - selected = [] - population_copy = population.copy() - weights_copy = weights.copy() - - for _ in range(k): - if not population_copy: - break - - # 选择一个元素 - chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0] - selected.append(population_copy.pop(chosen_idx)) - weights_copy.pop(chosen_idx) - - return selected - - -class ExpressionSelector: - def __init__(self): - self.llm_model = LLMRequest( - model_set=model_config.model_task_config.utils_small, request_type="expression.selector" - ) - - def can_use_expression_for_chat(self, chat_id: str) -> bool: - """ - 检查指定聊天流是否允许使用表达 - - Args: - chat_id: 聊天流ID - - Returns: - bool: 是否允许使用表达 - """ - try: - use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id) - return use_expression - except Exception as e: - logger.error(f"检查表达使用权限失败: {e}") - return False - - @staticmethod - def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: - """解析'platform:id:type'为chat_id(与get_stream_id一致)""" - try: - parts = stream_config_str.split(":") - if len(parts) != 3: - return None - platform = parts[0] - id_str = parts[1] - stream_type = parts[2] - is_group = stream_type == "group" - if is_group: - components = [platform, str(id_str)] - else: - components = [platform, str(id_str), "private"] - key = "_".join(components) - return hashlib.md5(key.encode()).hexdigest() - except Exception: - return None - - def get_related_chat_ids(self, chat_id: str) -> List[str]: - """根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)""" - groups = global_config.expression.expression_groups - - # 检查是否存在全局共享组(包含"*"的组) - global_group_exists = any("*" in group for group in groups) - - if global_group_exists: - # 如果存在全局共享组,则返回所有可用的chat_id - all_chat_ids = set() - for group in groups: - for stream_config_str in group: - if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): - all_chat_ids.add(chat_id_candidate) - return list(all_chat_ids) if all_chat_ids else [chat_id] - - # 否则使用现有的组逻辑 - for group in groups: - group_chat_ids = [] - for stream_config_str in group: - if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): - group_chat_ids.append(chat_id_candidate) - if chat_id in group_chat_ids: - return group_chat_ids - return [chat_id] - - def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: - # sourcery skip: extract-duplicate-method, move-assign - # 支持多chat_id合并抽选 - related_chat_ids = self.get_related_chat_ids(chat_id) - - # 优化:一次性查询所有相关chat_id的表达方式 - style_query = Expression.select().where( - (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") - ) - - style_exprs = [ - { - "id": expr.id, - "situation": expr.situation, - "style": expr.style, - "count": expr.count, - "last_active_time": expr.last_active_time, - "source_id": expr.chat_id, - "type": "style", - "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } - for expr in style_query - ] - - # 按权重抽样(使用count作为权重) - if style_exprs: - style_weights = [expr.get("count", 1) for expr in style_exprs] - selected_style = weighted_sample(style_exprs, style_weights, total_num) - else: - selected_style = [] - return selected_style - - def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1): - """对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库""" - if not expressions_to_update: - return - updates_by_key = {} - for expr in expressions_to_update: - source_id: str = expr.get("source_id") # type: ignore - expr_type: str = expr.get("type", "style") - situation: str = expr.get("situation") # type: ignore - style: str = expr.get("style") # type: ignore - if not source_id or not situation or not style: - logger.warning(f"表达方式缺少必要字段,无法更新: {expr}") - continue - key = (source_id, expr_type, situation, style) - if key not in updates_by_key: - updates_by_key[key] = expr - for chat_id, expr_type, situation, style in updates_by_key: - query = Expression.select().where( - (Expression.chat_id == chat_id) - & (Expression.type == expr_type) - & (Expression.situation == situation) - & (Expression.style == style) - ) - if query.exists(): - expr_obj = query.get() - current_count = expr_obj.count - new_count = min(current_count + increment, 5.0) - expr_obj.count = new_count - expr_obj.last_active_time = time.time() - expr_obj.save() - logger.debug( - f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db" - ) - - async def select_suitable_expressions_llm( - self, - chat_id: str, - chat_info: str, - max_num: int = 10, - target_message: Optional[str] = None, - ) -> Tuple[List[Dict[str, Any]], List[int]]: - # sourcery skip: inline-variable, list-comprehension - """使用LLM选择适合的表达方式""" - - # 检查是否允许在此聊天流中使用表达 - if not self.can_use_expression_for_chat(chat_id): - logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") - return [], [] - - # 1. 获取20个随机表达方式(现在按权重抽取) - style_exprs = self.get_random_expressions(chat_id, 20) - - if len(style_exprs) < 10: - logger.info(f"聊天流 {chat_id} 表达方式正在积累中") - return [], [] - - # 2. 构建所有表达方式的索引和情境列表 - all_expressions: List[Dict[str, Any]] = [] - all_situations: List[str] = [] - - # 添加style表达方式 - for expr in style_exprs: - expr = expr.copy() - all_expressions.append(expr) - all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}") - - if not all_expressions: - logger.warning("没有找到可用的表达方式") - return [], [] - - all_situations_str = "\n".join(all_situations) - - if target_message: - target_message_str = f",现在你想要回复消息:{target_message}" - target_message_extra_block = "4.考虑你要回复的目标消息" - else: - target_message_str = "" - target_message_extra_block = "" - - # 3. 构建prompt(只包含情境,不包含完整的表达方式) - prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format( - bot_name=global_config.bot.nickname, - chat_observe_info=chat_info, - all_situations=all_situations_str, - max_num=max_num, - target_message=target_message_str, - target_message_extra_block=target_message_extra_block, - ) - - # 4. 调用LLM - try: - # start_time = time.time() - content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) - # logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}") - - # logger.info(f"模型名称: {model_name}") - # logger.info(f"LLM返回结果: {content}") - # if reasoning_content: - # logger.info(f"LLM推理: {reasoning_content}") - # else: - # logger.info(f"LLM推理: 无") - - if not content: - logger.warning("LLM返回空结果") - return [], [] - - # 5. 解析结果 - result = repair_json(content) - if isinstance(result, str): - result = json.loads(result) - - if not isinstance(result, dict) or "selected_situations" not in result: - logger.error("LLM返回格式错误") - logger.info(f"LLM返回结果: \n{content}") - return [], [] - - selected_indices = result["selected_situations"] - - # 根据索引获取完整的表达方式 - valid_expressions: List[Dict[str, Any]] = [] - selected_ids = [] - for idx in selected_indices: - if isinstance(idx, int) and 1 <= idx <= len(all_expressions): - expression = all_expressions[idx - 1] # 索引从1开始 - selected_ids.append(expression["id"]) - valid_expressions.append(expression) - - # 对选中的所有表达方式,一次性更新count数 - if valid_expressions: - self.update_expressions_count_batch(valid_expressions, 0.006) - - # logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个") - return valid_expressions, selected_ids - - except Exception as e: - logger.error(f"LLM处理表达方式选择时出错: {e}") - return [], [] - - -init_prompt() - -try: - expression_selector = ExpressionSelector() -except Exception as e: - logger.error(f"ExpressionSelector初始化失败: {e}") diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 6de7e050..b5aa6aff 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -18,7 +18,7 @@ from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail from src.chat.heart_flow.hfc_utils import send_typing, stop_typing -from src.chat.express.expression_learner import expression_learner_manager +from src.express.expression_learner import expression_learner_manager from src.chat.frequency_control.frequency_control import frequency_control_manager from src.memory_system.question_maker import QuestionMaker from src.memory_system.questions import global_conflict_tracker @@ -331,9 +331,8 @@ class HeartFChatting: async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): - await self.expression_learner.trigger_learning_for_chat() - - await global_memory_chest.build_running_content(chat_id=self.stream_id) + asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) + asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id)) cycle_timers, thinking_id = self.start_cycle() diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 22684c03..88935da7 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -26,7 +26,7 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, replace_user_references, ) -from src.chat.express.expression_selector import expression_selector +from src.express.expression_selector import expression_selector from src.plugin_system.apis.message_api import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator @@ -238,8 +238,8 @@ class DefaultReplyer: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm( + # 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择 + selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( self.chat_stream.stream_id, chat_history, max_num=8, target_message=target ) diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 785274d6..c7702bbb 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -24,7 +24,7 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, replace_user_references, ) -from src.chat.express.expression_selector import expression_selector +from src.express.expression_selector import expression_selector from src.plugin_system.apis.message_api import translate_pid_to_description from src.mood.mood_manager import mood_manager @@ -256,8 +256,8 @@ class PrivateReplyer: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 - # LLM模式:调用LLM选择5-10个,然后随机选5个 - selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm( + # 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择 + selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( self.chat_stream.stream_id, chat_history, max_num=8, target_message=target ) diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 2461784f..f3efa943 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -303,11 +303,10 @@ class Expression(BaseModel): situation = TextField() style = TextField() - count = FloatField() # new mode fields context = TextField(null=True) - context_words = TextField(null=True) + up_content = TextField(null=True) last_active_time = FloatField() chat_id = TextField(index=True) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 657a14ae..df616a64 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -310,8 +310,8 @@ class MemoryConfig(ConfigBase): class ExpressionConfig(ConfigBase): """表达配置类""" - mode: Literal["llm", "context", "full-context"] = "context" - """表达方式模式,可选:llm模式,context上下文模式,full-context 完整上下文嵌入模式""" + mode: str = "classic" + """表达方式模式,可选:classic经典模式,exp_model 表达模型模式""" learning_list: list[list] = field(default_factory=lambda: []) """ diff --git a/src/chat/express/expression_learner.py b/src/express/expression_learner.py similarity index 75% rename from src/chat/express/expression_learner.py rename to src/express/expression_learner.py index 4b534fbd..56536a21 100644 --- a/src/chat/express/expression_learner.py +++ b/src/express/expression_learner.py @@ -2,10 +2,12 @@ import time import random import json import os +import re from datetime import datetime import jieba from typing import List, Dict, Optional, Any, Tuple import traceback +import difflib from src.common.logger import get_logger from src.common.database.database_model import Expression from src.llm_models.utils_model import LLMRequest @@ -17,16 +19,23 @@ from src.chat.utils.chat_message_builder import ( ) from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager +from src.express.style_learner import style_learner_manager from json_repair import repair_json -MAX_EXPRESSION_COUNT = 300 -DECAY_DAYS = 15 # 30天衰减到0.01 -DECAY_MIN = 0.01 # 最小衰减值 +# MAX_EXPRESSION_COUNT = 300 logger = get_logger("expressor") +def calculate_similarity(text1: str, text2: str) -> float: + """ + 计算两个文本的相似度,返回0-1之间的值 + 使用SequenceMatcher计算相似度 + """ + return difflib.SequenceMatcher(None, text1, text2).ratio() + + def format_create_date(timestamp: float) -> str: """ 将时间戳格式化为可读的日期字符串 @@ -173,63 +182,7 @@ class ExpressionLearner: traceback.print_exc() return False - def _apply_global_decay_to_database(self, current_time: float) -> None: - """ - 对数据库中的所有表达方式应用全局衰减 - """ - try: - # 获取所有表达方式 - all_expressions = Expression.select() - updated_count = 0 - deleted_count = 0 - - for expr in all_expressions: - # 计算时间差 - last_active = expr.last_active_time - time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - - # 计算衰减值 - decay_value = self.calculate_decay_factor(time_diff_days) - new_count = max(0.01, expr.count - decay_value) - - if new_count <= 0.01: - # 如果count太小,删除这个表达方式 - expr.delete_instance() - deleted_count += 1 - else: - # 更新count - expr.count = new_count - expr.save() - updated_count += 1 - - if updated_count > 0 or deleted_count > 0: - logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") - - except Exception as e: - logger.error(f"数据库全局衰减失败: {e}") - - def calculate_decay_factor(self, time_diff_days: float) -> float: - """ - 计算衰减值 - 当时间差为0天时,衰减值为0(最近活跃的不衰减) - 当时间差为7天时,衰减值为0.002(中等衰减) - 当时间差为30天或更长时,衰减值为0.01(高衰减) - 使用二次函数进行曲线插值 - """ - if time_diff_days <= 0: - return 0.0 # 刚激活的表达式不衰减 - - if time_diff_days >= DECAY_DAYS: - return 0.01 # 长时间未活跃的表达式大幅衰减 - - # 使用二次函数插值:在0-30天之间从0衰减到0.01 - # 使用简单的二次函数:y = a * x^2 - # 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900 - a = 0.01 / (DECAY_DAYS**2) - decay = a * (time_diff_days**2) - - return min(0.01, decay) async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]: """ @@ -247,7 +200,7 @@ class ExpressionLearner: situation, style, _context, - _context_words, + _up_content, ) in learnt_expressions: learnt_expressions_str += f"{situation}->{style}\n" @@ -260,7 +213,7 @@ class ExpressionLearner: situation, style, context, - context_words, + up_content, ) in learnt_expressions: if chat_id not in chat_dict: chat_dict[chat_id] = [] @@ -269,13 +222,15 @@ class ExpressionLearner: "situation": situation, "style": style, "context": context, - "context_words": context_words, + "up_content": up_content, } ) current_time = time.time() - # 存储到数据库 Expression 表 + # 存储到数据库 Expression 表并训练 style_learner + trained_chat_ids = set() # 记录已训练的聊天室 + for chat_id, expr_list in chat_dict.items(): for new_expr in expr_list: # 查找是否已存在相似表达方式 @@ -292,32 +247,72 @@ class ExpressionLearner: expr_obj.situation = new_expr["situation"] expr_obj.style = new_expr["style"] expr_obj.context = new_expr["context"] - expr_obj.context_words = new_expr["context_words"] - expr_obj.count = expr_obj.count + 1 + expr_obj.up_content = new_expr["up_content"] expr_obj.last_active_time = current_time expr_obj.save() else: Expression.create( situation=new_expr["situation"], style=new_expr["style"], - count=1, last_active_time=current_time, chat_id=chat_id, type="style", create_date=current_time, # 手动设置创建日期 context=new_expr["context"], - context_words=new_expr["context_words"], + up_content=new_expr["up_content"], ) + + # 训练 style_learner(up_content 和 style 必定存在) + try: + # 获取 learner 实例 + learner = style_learner_manager.get_learner(chat_id) + + # 先添加风格和对应的 situation(如果存在) + if new_expr.get("situation"): + learner.add_style(new_expr["style"], new_expr["situation"]) + else: + learner.add_style(new_expr["style"]) + + # 学习映射关系 + success = style_learner_manager.learn_mapping( + chat_id, + new_expr["up_content"], + new_expr["style"] + ) + if success: + logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" + + (f" (situation: {new_expr['situation']})" if new_expr.get("situation") else "")) + trained_chat_ids.add(chat_id) + else: + logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}") + except Exception as e: + logger.error(f"StyleLearner学习异常: {chat_id} - {e}") + # 限制最大数量 - exprs = list( - Expression.select() - .where((Expression.chat_id == chat_id) & (Expression.type == "style")) - .order_by(Expression.count.asc()) - ) - if len(exprs) > MAX_EXPRESSION_COUNT: - # 删除count最小的多余表达方式 - for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: - expr.delete_instance() + # exprs = list( + # Expression.select() + # .where((Expression.chat_id == chat_id) & (Expression.type == "style")) + # .order_by(Expression.last_active_time.asc()) + # ) + # if len(exprs) > MAX_EXPRESSION_COUNT: + # 删除最久未活跃的多余表达方式 + # for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]: + # expr.delete_instance() + + # 保存训练好的 style_learner 模型 + if trained_chat_ids: + try: + logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...") + save_success = style_learner_manager.save_all_models() + + if save_success: + logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}") + else: + logger.warning("StyleLearner 模型保存失败") + + except Exception as e: + logger.error(f"StyleLearner 模型保存异常: {e}") + return learnt_expressions async def match_expression_context( @@ -339,8 +334,8 @@ class ExpressionLearner: response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) - print(f"match_expression_context_prompt: {prompt}") - print(f"random_msg_match_str: {response}") + # print(f"match_expression_context_prompt: {prompt}") + # print(f"{response}") # 解析JSON响应 match_responses = [] @@ -395,6 +390,8 @@ class ExpressionLearner: matched_expressions = [] used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引 + + print(f"match_responses: {match_responses}") for match_response in match_responses: try: @@ -418,7 +415,7 @@ class ExpressionLearner: async def learn_expression( self, num: int = 10 - ) -> Optional[List[Tuple[str, str, str, List[str]]]]: + ) -> Optional[List[Tuple[str, str, str, List[str], str]]]: """从指定聊天流学习表达方式 Args: @@ -466,39 +463,52 @@ class ExpressionLearner: matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context( expressions, random_msg_match_str ) + + print(f"matched_expressions: {matched_expressions}") - split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context( - matched_expressions - ) + # 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射 + bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content) + pic_pattern = r"\[picid:[^\]]+\]" + reply_pattern = r"回复<[^:<>]+:[^:<>]+>" + at_pattern = r"@<[^:<>]+:[^:<>]+>" + for idx, msg in enumerate(random_msg): + content = msg.processed_plain_text or "" + content = re.sub(pic_pattern, "[图片]", content) + content = re.sub(reply_pattern, "回复[某人]", content) + content = re.sub(at_pattern, "@[某人]", content) + content = content.strip() + if content: + bare_lines.append((idx, content)) - split_matched_expressions_w_emb = [] - - for situation, style, context, context_words in split_matched_expressions: - split_matched_expressions_w_emb.append( - (self.chat_id, situation, style, context, context_words) - ) - - return split_matched_expressions_w_emb - - def split_expression_context( - self, matched_expressions: List[Tuple[str, str, str]] - ) -> List[Tuple[str, str, str, List[str]]]: - """ - 对matched_expressions中的context部分进行jieba分词 - - Args: - matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context) - - Returns: - 添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words) - """ - result = [] + # 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过) + filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content) for situation, style, context in matched_expressions: - # 使用jieba进行分词 - context_words = list(jieba.cut(context)) - result.append((situation, style, context, context_words)) + # 在 bare_lines 中找到第一处相似度达到85%的行 + pos = None + for i, (_, c) in enumerate(bare_lines): + similarity = calculate_similarity(c, context) + if similarity >= 0.85: # 85%相似度阈值 + pos = i + break + + if pos is None or pos == 0: + # 没有匹配到或没有上一句,跳过该表达 + continue + prev_original_idx = bare_lines[pos - 1][0] + up_content = (random_msg[prev_original_idx].processed_plain_text or "").strip() + if not up_content: + continue + filtered_with_up.append((situation, style, context, up_content)) + + if not filtered_with_up: + return None + + results: List[Tuple[str, str, str, str]] = [] + for (situation, style, context, up_content) in filtered_with_up: + results.append((self.chat_id, situation, style, context, up_content)) + + return results - return result def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: """ diff --git a/src/express/expression_selector.py b/src/express/expression_selector.py new file mode 100644 index 00000000..d6dfa7cb --- /dev/null +++ b/src/express/expression_selector.py @@ -0,0 +1,520 @@ +import json +import time +import random +import hashlib + +from typing import List, Dict, Optional, Any, Tuple +from json_repair import repair_json + +from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config, model_config +from src.common.logger import get_logger +from src.common.database.database_model import Expression +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.express.style_learner import style_learner_manager + +logger = get_logger("expression_selector") + + +def init_prompt(): + expression_evaluation_prompt = """ +以下是正在进行的聊天内容: +{chat_observe_info} + +你的名字是{bot_name}{target_message} + +以下是可选的表达情境: +{all_situations} + +请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。 +考虑因素包括: +1. 聊天的情绪氛围(轻松、严肃、幽默等) +2. 话题类型(日常、技术、游戏、情感等) +3. 情境与当前语境的匹配度 +{target_message_extra_block} + +请以JSON格式输出,只需要输出选中的情境编号: +例如: +{{ + "selected_situations": [2, 3, 5, 7, 19] +}} + +请严格按照JSON格式输出,不要包含其他内容: +""" + Prompt(expression_evaluation_prompt, "expression_evaluation_prompt") + + +def weighted_sample(population: List[Dict], k: int) -> List[Dict]: + """随机抽样""" + if not population or k <= 0: + return [] + + if len(population) <= k: + return population.copy() + + # 使用随机抽样 + selected = [] + population_copy = population.copy() + + for _ in range(k): + if not population_copy: + break + + # 随机选择一个元素 + chosen_idx = random.randint(0, len(population_copy) - 1) + selected.append(population_copy.pop(chosen_idx)) + + return selected + + +class ExpressionSelector: + def __init__(self): + self.llm_model = LLMRequest( + model_set=model_config.model_task_config.utils_small, request_type="expression.selector" + ) + + def can_use_expression_for_chat(self, chat_id: str) -> bool: + """ + 检查指定聊天流是否允许使用表达 + + Args: + chat_id: 聊天流ID + + Returns: + bool: 是否允许使用表达 + """ + try: + use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id) + return use_expression + except Exception as e: + logger.error(f"检查表达使用权限失败: {e}") + return False + + @staticmethod + def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]: + """解析'platform:id:type'为chat_id(与get_stream_id一致)""" + try: + parts = stream_config_str.split(":") + if len(parts) != 3: + return None + platform = parts[0] + id_str = parts[1] + stream_type = parts[2] + is_group = stream_type == "group" + if is_group: + components = [platform, str(id_str)] + else: + components = [platform, str(id_str), "private"] + key = "_".join(components) + return hashlib.md5(key.encode()).hexdigest() + except Exception: + return None + + def get_related_chat_ids(self, chat_id: str) -> List[str]: + """根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)""" + groups = global_config.expression.expression_groups + + # 检查是否存在全局共享组(包含"*"的组) + global_group_exists = any("*" in group for group in groups) + + if global_group_exists: + # 如果存在全局共享组,则返回所有可用的chat_id + all_chat_ids = set() + for group in groups: + for stream_config_str in group: + if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): + all_chat_ids.add(chat_id_candidate) + return list(all_chat_ids) if all_chat_ids else [chat_id] + + # 否则使用现有的组逻辑 + for group in groups: + group_chat_ids = [] + for stream_config_str in group: + if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): + group_chat_ids.append(chat_id_candidate) + if chat_id in group_chat_ids: + return group_chat_ids + return [chat_id] + + def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]: + """ + 使用 style_learner 模型预测最合适的表达方式 + + Args: + chat_id: 聊天室ID + chat_info: 聊天内容信息 + total_num: 需要预测的数量 + + Returns: + List[Dict[str, Any]]: 预测的表达方式列表 + """ + try: + # 支持多chat_id合并预测 + related_chat_ids = self.get_related_chat_ids(chat_id) + + # 从聊天信息中提取关键内容作为预测输入 + # 这里可以进一步优化,提取更合适的预测输入 + prediction_input = self._extract_prediction_input(chat_info) + + predicted_expressions = [] + + # 为每个相关的chat_id进行预测 + for related_chat_id in related_chat_ids: + try: + # 使用 style_learner 预测最合适的风格 + best_style, scores = style_learner_manager.predict_style( + related_chat_id, prediction_input, top_k=total_num + ) + + if best_style and scores: + # 获取预测风格的完整信息 + learner = style_learner_manager.get_learner(related_chat_id) + style_id, situation = learner.get_style_info(best_style) + + if style_id and situation: + # 从数据库查找对应的表达记录 + expr_query = Expression.select().where( + (Expression.chat_id == related_chat_id) & + (Expression.type == "style") & + (Expression.situation == situation) & + (Expression.style == best_style) + ) + + if expr_query.exists(): + expr = expr_query.get() + predicted_expressions.append({ + "id": expr.id, + "situation": expr.situation, + "style": expr.style, + "last_active_time": expr.last_active_time, + "source_id": expr.chat_id, + "type": "style", + "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, + "prediction_score": scores.get(best_style, 0.0), + "prediction_input": prediction_input + }) + + except Exception as e: + logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}") + continue + + # 按预测分数排序,取前 total_num 个 + predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True) + selected_expressions = predicted_expressions[:total_num] + + logger.info(f"为聊天室 {chat_id} 预测到 {len(selected_expressions)} 个表达方式") + return selected_expressions + + except Exception as e: + logger.error(f"模型预测表达方式失败: {e}") + # 如果预测失败,回退到随机选择 + return self._fallback_random_expressions(chat_id, total_num) + + def _extract_prediction_input(self, chat_info: str) -> str: + """ + 从聊天信息中提取用于预测的关键内容 + + Args: + chat_info: 聊天内容信息 + + Returns: + str: 提取的预测输入 + """ + try: + # 简单的提取策略:取最后几句话作为预测输入 + lines = chat_info.strip().split('\n') + if not lines: + return "" + + # 取最后3行作为预测输入 + recent_lines = lines[-3:] + prediction_input = ' '.join(recent_lines).strip() + + # 如果内容太长,截取前100个字符 + if len(prediction_input) > 100: + prediction_input = prediction_input[:100] + + return prediction_input + + except Exception as e: + logger.warning(f"提取预测输入失败: {e}") + return "" + + def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]: + """ + 回退到随机选择表达方式 + + Args: + chat_id: 聊天室ID + total_num: 需要选择的数量 + + Returns: + List[Dict[str, Any]]: 随机选择的表达方式列表 + """ + try: + # 支持多chat_id合并抽选 + related_chat_ids = self.get_related_chat_ids(chat_id) + + # 优化:一次性查询所有相关chat_id的表达方式 + style_query = Expression.select().where( + (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") + ) + + style_exprs = [ + { + "id": expr.id, + "situation": expr.situation, + "style": expr.style, + "last_active_time": expr.last_active_time, + "source_id": expr.chat_id, + "type": "style", + "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, + } + for expr in style_query + ] + + # 随机抽样 + if style_exprs: + selected_style = weighted_sample(style_exprs, total_num) + else: + selected_style = [] + + logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式") + return selected_style + + except Exception as e: + logger.error(f"随机选择表达方式失败: {e}") + return [] + + + async def select_suitable_expressions( + self, + chat_id: str, + chat_info: str, + max_num: int = 10, + target_message: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], List[int]]: + """ + 根据配置模式选择适合的表达方式 + + Args: + chat_id: 聊天流ID + chat_info: 聊天内容信息 + max_num: 最大选择数量 + target_message: 目标消息内容 + + Returns: + Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 + """ + # 检查是否允许在此聊天流中使用表达 + if not self.can_use_expression_for_chat(chat_id): + logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表") + return [], [] + + # 获取配置模式 + expression_mode = global_config.expression.mode + + if expression_mode == "exp_model": + # exp_model模式:直接使用模型预测,不经过LLM + logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式") + return await self._select_expressions_model_only(chat_id, chat_info, max_num) + elif expression_mode == "classic": + # classic模式:随机选择+LLM选择 + logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式") + return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message) + else: + logger.warning(f"未知的表达模式: {expression_mode},回退到classic模式") + return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message) + + async def _select_expressions_model_only( + self, + chat_id: str, + chat_info: str, + max_num: int = 10, + ) -> Tuple[List[Dict[str, Any]], List[int]]: + """ + exp_model模式:直接使用模型预测,不经过LLM + + Args: + chat_id: 聊天流ID + chat_info: 聊天内容信息 + max_num: 最大选择数量 + + Returns: + Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 + """ + try: + # 使用模型预测最合适的表达方式 + style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2) + + # if len(style_exprs) < 5: + # logger.info(f"聊天流 {chat_id} 表达方式正在积累中") + # return [], [] + + # 直接取前max_num个预测结果 + selected_expressions = style_exprs[:max_num] + selected_ids = [expr["id"] for expr in selected_expressions] + + # 更新last_active_time + if selected_expressions: + self.update_expressions_last_active_time(selected_expressions) + + logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式") + return selected_expressions, selected_ids + + except Exception as e: + logger.error(f"exp_model模式选择表达方式失败: {e}") + return [], [] + + async def _select_expressions_classic( + self, + chat_id: str, + chat_info: str, + max_num: int = 10, + target_message: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], List[int]]: + """ + classic模式:随机选择+LLM选择 + + Args: + chat_id: 聊天流ID + chat_info: 聊天内容信息 + max_num: 最大选择数量 + target_message: 目标消息内容 + + Returns: + Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表 + """ + try: + # 1. 使用模型预测最合适的表达方式 + style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20) + + if len(style_exprs) < 10: + logger.info(f"聊天流 {chat_id} 表达方式正在积累中") + return [], [] + + # 2. 构建所有表达方式的索引和情境列表 + all_expressions: List[Dict[str, Any]] = [] + all_situations: List[str] = [] + + # 添加style表达方式 + for expr in style_exprs: + expr = expr.copy() + all_expressions.append(expr) + all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}") + + if not all_expressions: + logger.warning("没有找到可用的表达方式") + return [], [] + + all_situations_str = "\n".join(all_situations) + + if target_message: + target_message_str = f",现在你想要回复消息:{target_message}" + target_message_extra_block = "4.考虑你要回复的目标消息" + else: + target_message_str = "" + target_message_extra_block = "" + + # 3. 构建prompt(只包含情境,不包含完整的表达方式) + prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format( + bot_name=global_config.bot.nickname, + chat_observe_info=chat_info, + all_situations=all_situations_str, + max_num=max_num, + target_message=target_message_str, + target_message_extra_block=target_message_extra_block, + ) + + # 4. 调用LLM + content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt) + + if not content: + logger.warning("LLM返回空结果") + return [], [] + + # 5. 解析结果 + result = repair_json(content) + if isinstance(result, str): + result = json.loads(result) + + if not isinstance(result, dict) or "selected_situations" not in result: + logger.error("LLM返回格式错误") + logger.info(f"LLM返回结果: \n{content}") + return [], [] + + selected_indices = result["selected_situations"] + + # 根据索引获取完整的表达方式 + valid_expressions: List[Dict[str, Any]] = [] + selected_ids = [] + for idx in selected_indices: + if isinstance(idx, int) and 1 <= idx <= len(all_expressions): + expression = all_expressions[idx - 1] # 索引从1开始 + selected_ids.append(expression["id"]) + valid_expressions.append(expression) + + # 对选中的所有表达方式,更新last_active_time + if valid_expressions: + self.update_expressions_last_active_time(valid_expressions) + + logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个") + return valid_expressions, selected_ids + + except Exception as e: + logger.error(f"classic模式处理表达方式选择时出错: {e}") + return [], [] + + async def select_suitable_expressions_llm( + self, + chat_id: str, + chat_info: str, + max_num: int = 10, + target_message: Optional[str] = None, + ) -> Tuple[List[Dict[str, Any]], List[int]]: + """ + 使用LLM选择适合的表达方式(保持向后兼容) + + 注意:此方法已被 select_suitable_expressions 替代,建议使用新方法 + """ + logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions") + return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message) + + def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]): + """对一批表达方式更新last_active_time""" + if not expressions_to_update: + return + updates_by_key = {} + for expr in expressions_to_update: + source_id: str = expr.get("source_id") # type: ignore + expr_type: str = expr.get("type", "style") + situation: str = expr.get("situation") # type: ignore + style: str = expr.get("style") # type: ignore + if not source_id or not situation or not style: + logger.warning(f"表达方式缺少必要字段,无法更新: {expr}") + continue + key = (source_id, expr_type, situation, style) + if key not in updates_by_key: + updates_by_key[key] = expr + for chat_id, expr_type, situation, style in updates_by_key: + query = Expression.select().where( + (Expression.chat_id == chat_id) + & (Expression.type == expr_type) + & (Expression.situation == situation) + & (Expression.style == style) + ) + if query.exists(): + expr_obj = query.get() + expr_obj.last_active_time = time.time() + expr_obj.save() + logger.debug( + "表达方式激活: 更新last_active_time in db" + ) + + +init_prompt() + +try: + expression_selector = ExpressionSelector() +except Exception as e: + logger.error(f"ExpressionSelector初始化失败: {e}") diff --git a/src/express/expressor_model/model.py b/src/express/expressor_model/model.py new file mode 100644 index 00000000..d8aec88a --- /dev/null +++ b/src/express/expressor_model/model.py @@ -0,0 +1,131 @@ +from typing import Dict, Optional, Tuple, List +from collections import Counter, defaultdict +import pickle +import os + +from .tokenizer import Tokenizer +from .online_nb import OnlineNaiveBayes + +class ExpressorModel: + """ + 直接使用朴素贝叶斯精排(可在线学习) + 支持存储situation字段,不参与计算,仅与style对应 + """ + + def __init__(self, + alpha: float = 0.5, + beta: float = 0.5, + gamma: float = 1.0, + vocab_size: int = 200000, + use_jieba: bool = True): + self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba) + self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size) + self._candidates: Dict[str, str] = {} # cid -> text (style) + self._situations: Dict[str, str] = {} # cid -> situation (不参与计算) + + def add_candidate(self, cid: str, text: str, situation: str = None): + """添加候选文本和对应的situation""" + self._candidates[cid] = text + if situation is not None: + self._situations[cid] = situation + + # 确保在nb模型中初始化该候选的计数 + if cid not in self.nb.cls_counts: + self.nb.cls_counts[cid] = 0.0 + if cid not in self.nb.token_counts: + self.nb.token_counts[cid] = defaultdict(float) + + def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None): + """批量添加候选文本和对应的situations""" + for i, (cid, text) in enumerate(items): + situation = situations[i] if situations and i < len(situations) else None + self.add_candidate(cid, text, situation) + + def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]: + """直接对所有候选进行朴素贝叶斯评分""" + toks = self.tokenizer.tokenize(text) + if not toks: + return None, {} + + if not self._candidates: + return None, {} + + # 对所有候选进行评分 + tf = Counter(toks) + all_cids = list(self._candidates.keys()) + scores = self.nb.score_batch(tf, all_cids) + + # 取最高分 + if not scores: + return None, {} + best = max(scores.items(), key=lambda x: x[1])[0] + return best, scores + + def update_positive(self, text: str, cid: str): + """更新正反馈学习""" + toks = self.tokenizer.tokenize(text) + if not toks: + return + tf = Counter(toks) + self.nb.update_positive(tf, cid) + + def decay(self, factor: float): + self.nb.decay(factor=factor) + + def get_situation(self, cid: str) -> Optional[str]: + """获取候选对应的situation""" + return self._situations.get(cid) + + def get_style(self, cid: str) -> Optional[str]: + """获取候选对应的style""" + return self._candidates.get(cid) + + def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]: + """获取候选的style和situation信息""" + return self._candidates.get(cid), self._situations.get(cid) + + def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]: + """获取所有候选的style和situation信息""" + return {cid: (style, self._situations.get(cid)) + for cid, style in self._candidates.items()} + + def save(self, path: str): + """保存模型""" + os.makedirs(os.path.dirname(path), exist_ok=True) + with open(path, "wb") as f: + pickle.dump({ + "candidates": self._candidates, + "situations": self._situations, + "nb": { + "cls_counts": dict(self.nb.cls_counts), + "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()}, + "alpha": self.nb.alpha, + "beta": self.nb.beta, + "gamma": self.nb.gamma, + "V": self.nb.V, + } + }, f) + + def load(self, path: str): + """加载模型""" + with open(path, "rb") as f: + obj = pickle.load(f) + # 还原候选文本 + self._candidates = obj["candidates"] + # 还原situations(兼容旧版本) + self._situations = obj.get("situations", {}) + # 还原朴素贝叶斯模型 + self.nb.cls_counts = obj["nb"]["cls_counts"] + self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"]) + self.nb.alpha = obj["nb"]["alpha"] + self.nb.beta = obj["nb"]["beta"] + self.nb.gamma = obj["nb"]["gamma"] + self.nb.V = obj["nb"]["V"] + self.nb._logZ.clear() + +def defaultdict_dict(d: Dict[str, Dict[str, float]]): + from collections import defaultdict + outer = defaultdict(lambda: defaultdict(float)) + for k, inner in d.items(): + outer[k].update(inner) + return outer \ No newline at end of file diff --git a/src/express/expressor_model/online_nb.py b/src/express/expressor_model/online_nb.py new file mode 100644 index 00000000..9705043b --- /dev/null +++ b/src/express/expressor_model/online_nb.py @@ -0,0 +1,60 @@ +import math +from typing import Dict, List +from collections import defaultdict, Counter + +class OnlineNaiveBayes: + def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000): + self.alpha = alpha + self.beta = beta + self.gamma = gamma + self.V = vocab_size + + self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count + self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count + self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα) + + def _invalidate(self, cid: str): + if cid in self._logZ: + del self._logZ[cid] + + def _logZ_c(self, cid: str) -> float: + if cid not in self._logZ: + Z = self.cls_counts[cid] + self.V * self.alpha + self._logZ[cid] = math.log(max(Z, 1e-12)) + return self._logZ[cid] + + def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]: + total_cls = sum(self.cls_counts.values()) + n_cls = max(1, len(self.cls_counts)) + denom_prior = math.log(total_cls + self.beta * n_cls) + + out: Dict[str, float] = {} + for cid in cids: + prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior + s = prior + logZ = self._logZ_c(cid) + tc = self.token_counts[cid] + for term, qtf in tf.items(): + num = tc.get(term, 0.0) + self.alpha + s += qtf * (math.log(num) - logZ) + out[cid] = s + return out + + def update_positive(self, tf: Counter, cid: str): + inc = 0.0 + tc = self.token_counts[cid] + for term, c in tf.items(): + tc[term] += float(c) + inc += float(c) + self.cls_counts[cid] += inc + self._invalidate(cid) + + def decay(self, factor: float = None): + g = self.gamma if factor is None else factor + if g >= 1.0: + return + for cid in list(self.cls_counts.keys()): + self.cls_counts[cid] *= g + for term in list(self.token_counts[cid].keys()): + self.token_counts[cid][term] *= g + self._invalidate(cid) \ No newline at end of file diff --git a/src/express/expressor_model/tokenizer.py b/src/express/expressor_model/tokenizer.py new file mode 100644 index 00000000..709e6a54 --- /dev/null +++ b/src/express/expressor_model/tokenizer.py @@ -0,0 +1,28 @@ +import re +from typing import List, Optional, Set + +try: + import jieba + _HAS_JIEBA = True +except Exception: + _HAS_JIEBA = False + +_WORD_RE = re.compile(r"[A-Za-z0-9_]+") + +def simple_en_tokenize(text: str) -> List[str]: + return _WORD_RE.findall(text.lower()) + +class Tokenizer: + def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True): + self.stopwords = stopwords or set() + self.use_jieba = use_jieba and _HAS_JIEBA + + def tokenize(self, text: str) -> List[str]: + text = (text or "").strip() + if not text: + return [] + if self.use_jieba: + toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()] + else: + toks = simple_en_tokenize(text) + return [t for t in toks if t not in self.stopwords] \ No newline at end of file diff --git a/src/express/style_learner.py b/src/express/style_learner.py new file mode 100644 index 00000000..4cacba78 --- /dev/null +++ b/src/express/style_learner.py @@ -0,0 +1,628 @@ +""" +多聊天室表达风格学习系统 +支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射 +""" + +import os +import pickle +import traceback +from typing import Dict, List, Optional, Tuple +from collections import defaultdict +import asyncio + +from src.common.logger import get_logger +from .expressor_model.model import ExpressorModel + +logger = get_logger("style_learner") + + +class StyleLearner: + """ + 单个聊天室的表达风格学习器 + 学习从up_content到style的映射关系 + 支持动态管理风格集合(最多2000个) + """ + + def __init__(self, chat_id: str, model_config: Optional[Dict] = None): + self.chat_id = chat_id + self.model_config = model_config or { + "alpha": 0.5, + "beta": 0.5, + "gamma": 0.99, # 衰减因子,支持遗忘 + "vocab_size": 200000, + "use_jieba": True + } + + # 初始化表达模型 + self.expressor = ExpressorModel(**self.model_config) + + # 动态风格管理 + self.max_styles = 2000 # 每个chat_id最多2000个风格 + self.style_to_id: Dict[str, str] = {} # style文本 -> style_id + self.id_to_style: Dict[str, str] = {} # style_id -> style文本 + self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本 + self.next_style_id = 0 # 下一个可用的style_id + + # 学习统计 + self.learning_stats = { + "total_samples": 0, + "style_counts": defaultdict(int), + "last_update": None, + "style_usage_frequency": defaultdict(int) # 风格使用频率 + } + + def add_style(self, style: str, situation: str = None) -> bool: + """ + 动态添加一个新的风格 + + Args: + style: 风格文本 + situation: 对应的situation文本(可选) + + Returns: + bool: 添加是否成功 + """ + try: + # 检查是否已存在 + if style in self.style_to_id: + logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在") + return True + + # 检查是否超过最大限制 + if len(self.style_to_id) >= self.max_styles: + logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})") + return False + + # 生成新的style_id + style_id = f"style_{self.next_style_id}" + self.next_style_id += 1 + + # 添加到映射 + self.style_to_id[style] = style_id + self.id_to_style[style_id] = style + if situation: + self.id_to_situation[style_id] = situation + + # 添加到expressor模型 + self.expressor.add_candidate(style_id, style, situation) + + logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" + + (f", situation: '{situation}'" if situation else "")) + return True + + except Exception as e: + logger.error(f"[{self.chat_id}] 添加风格失败: {e}") + return False + + def remove_style(self, style: str) -> bool: + """ + 删除一个风格 + + Args: + style: 要删除的风格文本 + + Returns: + bool: 删除是否成功 + """ + try: + if style not in self.style_to_id: + logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在") + return False + + style_id = self.style_to_id[style] + + # 从映射中删除 + del self.style_to_id[style] + del self.id_to_style[style_id] + if style_id in self.id_to_situation: + del self.id_to_situation[style_id] + + # 从expressor模型中删除(通过重新构建) + self._rebuild_expressor() + + logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})") + return True + + except Exception as e: + logger.error(f"[{self.chat_id}] 删除风格失败: {e}") + return False + + def update_style(self, old_style: str, new_style: str) -> bool: + """ + 更新一个风格 + + Args: + old_style: 原风格文本 + new_style: 新风格文本 + + Returns: + bool: 更新是否成功 + """ + try: + if old_style not in self.style_to_id: + logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在") + return False + + if new_style in self.style_to_id and new_style != old_style: + logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在") + return False + + style_id = self.style_to_id[old_style] + + # 更新映射 + del self.style_to_id[old_style] + self.style_to_id[new_style] = style_id + self.id_to_style[style_id] = new_style + + # 更新expressor模型(保留原有的situation) + situation = self.id_to_situation.get(style_id) + self.expressor.add_candidate(style_id, new_style, situation) + + logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'") + return True + + except Exception as e: + logger.error(f"[{self.chat_id}] 更新风格失败: {e}") + return False + + def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int: + """ + 批量添加风格 + + Args: + styles: 风格文本列表 + situations: 对应的situation文本列表(可选) + + Returns: + int: 成功添加的数量 + """ + success_count = 0 + for i, style in enumerate(styles): + situation = situations[i] if situations and i < len(situations) else None + if self.add_style(style, situation): + success_count += 1 + + logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功") + return success_count + + def get_all_styles(self) -> List[str]: + """获取所有已注册的风格""" + return list(self.style_to_id.keys()) + + def get_style_count(self) -> int: + """获取当前风格数量""" + return len(self.style_to_id) + + def get_situation(self, style: str) -> Optional[str]: + """ + 获取风格对应的situation + + Args: + style: 风格文本 + + Returns: + Optional[str]: 对应的situation,如果不存在则返回None + """ + if style not in self.style_to_id: + return None + + style_id = self.style_to_id[style] + return self.id_to_situation.get(style_id) + + def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]: + """ + 获取风格的完整信息 + + Args: + style: 风格文本 + + Returns: + Tuple[Optional[str], Optional[str]]: (style_id, situation) + """ + if style not in self.style_to_id: + return None, None + + style_id = self.style_to_id[style] + situation = self.id_to_situation.get(style_id) + return style_id, situation + + def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]: + """ + 获取所有风格的完整信息 + + Returns: + Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)} + """ + result = {} + for style, style_id in self.style_to_id.items(): + situation = self.id_to_situation.get(style_id) + result[style] = (style_id, situation) + return result + + def _rebuild_expressor(self): + """重新构建expressor模型(删除风格后使用)""" + try: + # 重新创建expressor + self.expressor = ExpressorModel(**self.model_config) + + # 重新添加所有风格和situation + for style_id, style_text in self.id_to_style.items(): + situation = self.id_to_situation.get(style_id) + self.expressor.add_candidate(style_id, style_text, situation) + + logger.debug(f"[{self.chat_id}] 已重新构建expressor模型") + + except Exception as e: + logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}") + + def learn_mapping(self, up_content: str, style: str) -> bool: + """ + 学习一个up_content到style的映射 + 如果style不存在,会自动添加 + + Args: + up_content: 输入内容 + style: 对应的style文本 + + Returns: + bool: 学习是否成功 + """ + try: + # 如果style不存在,先添加它 + if style not in self.style_to_id: + if not self.add_style(style): + logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败") + return False + + # 获取style_id + style_id = self.style_to_id[style] + + # 使用正反馈学习 + self.expressor.update_positive(up_content, style_id) + + # 更新统计 + self.learning_stats["total_samples"] += 1 + self.learning_stats["style_counts"][style_id] += 1 + self.learning_stats["style_usage_frequency"][style] += 1 + self.learning_stats["last_update"] = asyncio.get_event_loop().time() + + logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'") + return True + + except Exception as e: + logger.error(f"[{self.chat_id}] 学习映射失败: {e}") + traceback.print_exc() + return False + + def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + """ + 根据up_content预测最合适的style + + Args: + up_content: 输入内容 + top_k: 返回前k个候选 + + Returns: + Tuple[最佳style文本, 所有候选的分数] + """ + try: + best_style_id, scores = self.expressor.predict(up_content, k=top_k) + + if best_style_id is None: + return None, {} + + # 将style_id转换为style文本 + best_style = self.id_to_style.get(best_style_id) + + # 转换所有分数 + style_scores = {} + for sid, score in scores.items(): + style_text = self.id_to_style.get(sid) + if style_text: + style_scores[style_text] = score + + return best_style, style_scores + + except Exception as e: + logger.error(f"[{self.chat_id}] 预测style失败: {e}") + traceback.print_exc() + return None, {} + + def decay_learning(self, factor: Optional[float] = None) -> None: + """ + 对学习到的知识进行衰减(遗忘) + + Args: + factor: 衰减因子,None则使用配置中的gamma + """ + self.expressor.decay(factor) + logger.debug(f"[{self.chat_id}] 执行知识衰减") + + def get_stats(self) -> Dict: + """获取学习统计信息""" + return { + "chat_id": self.chat_id, + "total_samples": self.learning_stats["total_samples"], + "style_count": len(self.style_to_id), + "max_styles": self.max_styles, + "style_counts": dict(self.learning_stats["style_counts"]), + "style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]), + "last_update": self.learning_stats["last_update"], + "all_styles": list(self.style_to_id.keys()) + } + + def save(self, base_path: str) -> bool: + """ + 保存模型到文件 + + Args: + base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl + """ + try: + os.makedirs(base_path, exist_ok=True) + file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") + + # 保存模型和统计信息 + save_data = { + "model_config": self.model_config, + "style_to_id": self.style_to_id, + "id_to_style": self.id_to_style, + "id_to_situation": self.id_to_situation, + "next_style_id": self.next_style_id, + "max_styles": self.max_styles, + "learning_stats": self.learning_stats + } + + # 先保存expressor模型 + expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") + self.expressor.save(expressor_path) + + # 保存其他数据 + with open(file_path, "wb") as f: + pickle.dump(save_data, f) + + logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}") + return True + + except Exception as e: + logger.error(f"[{self.chat_id}] 保存模型失败: {e}") + return False + + def load(self, base_path: str) -> bool: + """ + 从文件加载模型 + + Args: + base_path: 基础路径 + """ + try: + file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl") + expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl") + + if not os.path.exists(file_path) or not os.path.exists(expressor_path): + logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置") + return False + + # 加载其他数据 + with open(file_path, "rb") as f: + save_data = pickle.load(f) + + # 恢复配置和状态 + self.model_config = save_data["model_config"] + self.style_to_id = save_data["style_to_id"] + self.id_to_style = save_data["id_to_style"] + self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本 + self.next_style_id = save_data["next_style_id"] + self.max_styles = save_data.get("max_styles", 2000) + self.learning_stats = save_data["learning_stats"] + + # 重新创建expressor并加载 + self.expressor = ExpressorModel(**self.model_config) + self.expressor.load(expressor_path) + + logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载") + return True + + except Exception as e: + logger.error(f"[{self.chat_id}] 加载模型失败: {e}") + return False + + +class StyleLearnerManager: + """ + 多聊天室表达风格学习管理器 + 为每个chat_id维护独立的StyleLearner实例 + 每个chat_id可以动态管理自己的风格集合(最多2000个) + """ + + def __init__(self, model_save_path: str = "data/style_models"): + self.model_save_path = model_save_path + self.learners: Dict[str, StyleLearner] = {} + + # 自动保存配置 + self.auto_save_interval = 300 # 5分钟 + self._auto_save_task: Optional[asyncio.Task] = None + + logger.info("StyleLearnerManager 已初始化") + + def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner: + """ + 获取或创建指定chat_id的学习器 + + Args: + chat_id: 聊天室ID + model_config: 模型配置,None则使用默认配置 + + Returns: + StyleLearner实例 + """ + if chat_id not in self.learners: + # 创建新的学习器 + learner = StyleLearner(chat_id, model_config) + + # 尝试加载已保存的模型 + learner.load(self.model_save_path) + + self.learners[chat_id] = learner + logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner") + + return self.learners[chat_id] + + def add_style(self, chat_id: str, style: str) -> bool: + """ + 为指定chat_id添加风格 + + Args: + chat_id: 聊天室ID + style: 风格文本 + + Returns: + bool: 添加是否成功 + """ + learner = self.get_learner(chat_id) + return learner.add_style(style) + + def remove_style(self, chat_id: str, style: str) -> bool: + """ + 为指定chat_id删除风格 + + Args: + chat_id: 聊天室ID + style: 风格文本 + + Returns: + bool: 删除是否成功 + """ + learner = self.get_learner(chat_id) + return learner.remove_style(style) + + def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool: + """ + 为指定chat_id更新风格 + + Args: + chat_id: 聊天室ID + old_style: 原风格文本 + new_style: 新风格文本 + + Returns: + bool: 更新是否成功 + """ + learner = self.get_learner(chat_id) + return learner.update_style(old_style, new_style) + + def get_chat_styles(self, chat_id: str) -> List[str]: + """ + 获取指定chat_id的所有风格 + + Args: + chat_id: 聊天室ID + + Returns: + List[str]: 风格列表 + """ + learner = self.get_learner(chat_id) + return learner.get_all_styles() + + def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool: + """ + 学习一个映射关系 + + Args: + chat_id: 聊天室ID + up_content: 输入内容 + style: 对应的style + + Returns: + bool: 学习是否成功 + """ + learner = self.get_learner(chat_id) + return learner.learn_mapping(up_content, style) + + def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]: + """ + 预测最合适的style + + Args: + chat_id: 聊天室ID + up_content: 输入内容 + top_k: 返回前k个候选 + + Returns: + Tuple[最佳style, 所有候选分数] + """ + learner = self.get_learner(chat_id) + return learner.predict_style(up_content, top_k) + + def decay_all_learners(self, factor: Optional[float] = None) -> None: + """ + 对所有学习器执行衰减 + + Args: + factor: 衰减因子 + """ + for learner in self.learners.values(): + learner.decay_learning(factor) + logger.info("已对所有学习器执行衰减") + + def get_all_stats(self) -> Dict[str, Dict]: + """获取所有学习器的统计信息""" + return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()} + + def save_all_models(self) -> bool: + """保存所有模型""" + success_count = 0 + for learner in self.learners.values(): + if learner.save(self.model_save_path): + success_count += 1 + + logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型") + return success_count == len(self.learners) + + def load_all_models(self) -> int: + """加载所有已保存的模型""" + if not os.path.exists(self.model_save_path): + return 0 + + loaded_count = 0 + for filename in os.listdir(self.model_save_path): + if filename.endswith("_style_model.pkl"): + chat_id = filename.replace("_style_model.pkl", "") + learner = StyleLearner(chat_id) + if learner.load(self.model_save_path): + self.learners[chat_id] = learner + loaded_count += 1 + + logger.info(f"已加载 {loaded_count} 个模型") + return loaded_count + + async def start_auto_save(self) -> None: + """启动自动保存任务""" + if self._auto_save_task is None or self._auto_save_task.done(): + self._auto_save_task = asyncio.create_task(self._auto_save_loop()) + logger.info("已启动自动保存任务") + + async def stop_auto_save(self) -> None: + """停止自动保存任务""" + if self._auto_save_task and not self._auto_save_task.done(): + self._auto_save_task.cancel() + try: + await self._auto_save_task + except asyncio.CancelledError: + pass + logger.info("已停止自动保存任务") + + async def _auto_save_loop(self) -> None: + """自动保存循环""" + while True: + try: + await asyncio.sleep(self.auto_save_interval) + self.save_all_models() + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"自动保存失败: {e}") + + +# 全局管理器实例 +style_learner_manager = StyleLearnerManager() diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index dd8141a6..b88077ca 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -45,9 +45,9 @@ private_plan_style = """ 3.某句话如果已经被回复过,不要重复回复""" [expression] -# 表达方式模式(此选项暂未使用) -mode = "context" -# 可选:llm模式,context上下文模式 +# 表达方式模式 +mode = "classic" +# 可选:classic经典模式,exp_model 表达模型模式 # 表达学习配置 learning_list = [ # 表达学习配置列表,支持按聊天流配置 diff --git a/test_expression_selector_prediction.py b/test_expression_selector_prediction.py new file mode 100644 index 00000000..ef75aa3b --- /dev/null +++ b/test_expression_selector_prediction.py @@ -0,0 +1,152 @@ +""" +测试修改后的 expression_selector 使用模型预测功能 +验证不再随机选取,而是使用 style_learner 模型预测 +""" + +import os +import sys +import asyncio + +# 添加项目根目录到Python路径 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from src.express.expression_selector import ExpressionSelector +from src.express.style_learner import style_learner_manager +from src.common.logger import get_logger + +logger = get_logger("expression_selector_test") + + +async def test_model_prediction_selector(): + """测试使用模型预测的表达选择器""" + print("=== Expression Selector 模型预测功能测试 ===\n") + + # 创建选择器实例 + selector = ExpressionSelector() + + # 测试聊天室ID + test_chat_id = "test_prediction_chat" + + print(f"测试聊天室: {test_chat_id}") + + # 1. 先为测试聊天室添加一些风格和situation + print(f"\n1. 准备测试数据...") + + test_data = [ + ("温柔回复", "打招呼"), + ("幽默回复", "表达惊讶"), + ("严肃回复", "询问问题"), + ("活泼回复", "表达开心"), + ("高冷回复", "表示不满"), + ] + + for style, situation in test_data: + success = style_learner_manager.add_style(test_chat_id, style, situation) + print(f" 添加: '{style}' (situation: '{situation}') -> {'成功' if success else '失败'}") + + # 2. 学习一些映射关系 + print(f"\n2. 学习映射关系...") + + learning_data = [ + ("你好", "温柔回复"), + ("谢谢", "温柔回复"), + ("哈哈", "幽默回复"), + ("请解释", "严肃回复"), + ("太棒了", "活泼回复"), + ] + + for up_content, style in learning_data: + success = style_learner_manager.learn_mapping(test_chat_id, up_content, style) + print(f" 学习: '{up_content}' -> '{style}' -> {'成功' if success else '失败'}") + + # 3. 测试模型预测功能 + print(f"\n3. 测试模型预测功能...") + + test_chat_scenarios = [ + "用户: 你好\n机器人: 你好,有什么可以帮助你的吗?", + "用户: 哈哈,太搞笑了\n机器人: 确实很有趣呢!", + "用户: 请解释一下这个问题\n机器人: 好的,让我详细说明一下", + "用户: 太棒了!\n机器人: 很高兴听到这个消息!", + ] + + for i, chat_info in enumerate(test_chat_scenarios, 1): + print(f"\n 场景 {i}:") + print(f" 聊天内容: {chat_info}") + + # 使用模型预测表达方式 + predicted_expressions = selector.get_model_predicted_expressions( + test_chat_id, chat_info, total_num=3 + ) + + print(f" 预测结果: {len(predicted_expressions)} 个表达方式") + for j, expr in enumerate(predicted_expressions, 1): + print(f" {j}. situation: '{expr['situation']}'") + print(f" style: '{expr['style']}'") + print(f" 分数: {expr.get('prediction_score', 0.0):.4f}") + print(f" 输入: '{expr.get('prediction_input', '')}'") + + # 4. 测试LLM选择功能 + print(f"\n4. 测试LLM选择功能...") + + # 模拟聊天信息 + chat_info = "用户: 你好,我想了解一下这个功能\n机器人: 好的,我来为你详细介绍" + + try: + selected_expressions, selected_ids = await selector.select_suitable_expressions_llm( + test_chat_id, chat_info, max_num=3 + ) + + print(f" LLM选择结果: {len(selected_expressions)} 个表达方式") + for i, expr in enumerate(selected_expressions, 1): + print(f" {i}. situation: '{expr['situation']}'") + print(f" style: '{expr['style']}'") + print(f" 来源: {expr['source_id']}") + + except Exception as e: + print(f" LLM选择失败: {e}") + + # 5. 测试回退机制 + print(f"\n5. 测试回退机制...") + + # 使用不存在的聊天室测试回退 + fake_chat_id = "fake_chat_id" + fallback_expressions = selector._fallback_random_expressions(fake_chat_id, 3) + print(f" 回退机制测试: {len(fallback_expressions)} 个表达方式") + + # 6. 测试预测输入提取 + print(f"\n6. 测试预测输入提取...") + + test_chat_infos = [ + "用户: 你好\n机器人: 你好!", + "这是一段很长的聊天内容,包含了很多信息,用户说了很多话,机器人也回复了很多内容,现在我们要测试提取功能", + "单行内容", + "", + ] + + for i, chat_info in enumerate(test_chat_infos, 1): + prediction_input = selector._extract_prediction_input(chat_info) + print(f" 测试 {i}:") + print(f" 原始: '{chat_info}'") + print(f" 提取: '{prediction_input}'") + + print(f"\n✅ 所有测试完成!") + print(f"\n=== 功能总结 ===") + print(f"✓ Expression Selector 现在使用 style_learner 模型进行预测") + print(f"✓ 不再随机选择,而是基于聊天内容预测最合适的 style") + print(f"✓ 自动获取预测 style 对应的 situation") + print(f"✓ 支持多聊天室的预测") + print(f"✓ 包含回退机制,预测失败时使用随机选择") + print(f"✓ 支持预测输入提取和优化") + + +def main(): + """主函数""" + print("Expression Selector 模型预测功能测试") + print("=" * 60) + + # 运行异步测试 + asyncio.run(test_model_prediction_selector()) + + +if __name__ == "__main__": + main() diff --git a/test_expression_style_situation_integration.py b/test_expression_style_situation_integration.py new file mode 100644 index 00000000..5fedf8e5 --- /dev/null +++ b/test_expression_style_situation_integration.py @@ -0,0 +1,188 @@ +""" +测试修改后的 expression_learner 与 style_learner 的集成 +验证学习新表达时是否正确处理 situation 字段 +""" + +import os +import sys +import asyncio +import time + +# 添加项目根目录到Python路径 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from src.express.expression_learner import ExpressionLearner +from src.express.style_learner import style_learner_manager +from src.common.logger import get_logger + +logger = get_logger("expression_style_integration_test") + + +async def test_expression_style_integration(): + """测试 expression_learner 与 style_learner 的集成(包含 situation)""" + print("=== Expression Learner 与 Style Learner 集成测试(含 Situation) ===\n") + + # 创建测试聊天室ID + test_chat_id = "test_integration_situation_chat" + + # 创建 ExpressionLearner 实例 + expression_learner = ExpressionLearner(test_chat_id) + + print(f"测试聊天室: {test_chat_id}") + + # 模拟学习到的表达数据(包含 situation) + mock_learnt_expressions = [ + (test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"), + (test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"), + (test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"), + (test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"), + (test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"), + ] + + print("模拟学习到的表达数据(包含 situation):") + for chat_id, situation, style, context, up_content in mock_learnt_expressions: + print(f" {situation} -> {style} (输入: {up_content})") + + # 模拟 learn_and_store 方法的处理逻辑 + print(f"\n开始处理学习数据...") + + # 按chat_id分组 + chat_dict = {} + for chat_id, situation, style, context, up_content in mock_learnt_expressions: + if chat_id not in chat_dict: + chat_dict[chat_id] = [] + chat_dict[chat_id].append({ + "situation": situation, + "style": style, + "context": context, + "up_content": up_content, + }) + + # 训练 style_learner(包含 situation 处理) + trained_chat_ids = set() + + for chat_id, expr_list in chat_dict.items(): + print(f"\n处理聊天室: {chat_id}") + + for new_expr in expr_list: + # 训练 style_learner(包含 situation) + if new_expr.get("up_content") and new_expr.get("style"): + try: + # 获取 learner 实例 + learner = style_learner_manager.get_learner(chat_id) + + # 先添加风格和对应的 situation(如果不存在) + if new_expr.get("situation"): + learner.add_style(new_expr["style"], new_expr["situation"]) + print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')") + else: + learner.add_style(new_expr["style"]) + print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)") + + # 学习映射关系 + success = style_learner_manager.learn_mapping( + chat_id, + new_expr["up_content"], + new_expr["style"] + ) + if success: + print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" + + (f" (situation: {new_expr['situation']})" if new_expr.get("situation") else "")) + trained_chat_ids.add(chat_id) + else: + print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}") + except Exception as e: + print(f" ✗ StyleLearner学习异常: {e}") + + # 保存模型 + if trained_chat_ids: + print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...") + try: + save_success = style_learner_manager.save_all_models() + + if save_success: + print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}") + else: + print("✗ StyleLearner 模型保存失败") + + except Exception as e: + print(f"✗ StyleLearner 模型保存异常: {e}") + + # 测试预测功能 + print(f"\n测试 StyleLearner 预测功能:") + test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"] + + for test_input in test_inputs: + try: + best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3) + if best_style: + # 获取对应的 situation + learner = style_learner_manager.get_learner(test_chat_id) + situation = learner.get_situation(best_style) + print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')") + if scores: + top_scores = dict(list(scores.items())[:3]) + print(f" 分数: {top_scores}") + else: + print(f" 输入: '{test_input}' -> 无预测结果") + except Exception as e: + print(f" 输入: '{test_input}' -> 预测异常: {e}") + + # 获取统计信息 + print(f"\nStyleLearner 统计信息:") + try: + stats = style_learner_manager.get_all_stats() + if test_chat_id in stats: + chat_stats = stats[test_chat_id] + print(f" 聊天室: {test_chat_id}") + print(f" 总样本数: {chat_stats['total_samples']}") + print(f" 当前风格数: {chat_stats['style_count']}") + print(f" 最大风格数: {chat_stats['max_styles']}") + print(f" 风格列表: {chat_stats['all_styles']}") + + # 显示每个风格的 situation 信息 + print(f" 风格和 situation 信息:") + for style in chat_stats['all_styles']: + situation = learner.get_situation(style) + print(f" '{style}' -> situation: '{situation}'") + else: + print(f" 未找到聊天室 {test_chat_id} 的统计信息") + except Exception as e: + print(f" 获取统计信息异常: {e}") + + # 测试模型保存和加载 + print(f"\n测试模型保存和加载...") + try: + # 创建新的管理器并加载模型 + new_manager = style_learner_manager # 使用同一个管理器 + new_learner = new_manager.get_learner(test_chat_id) + + # 验证加载后的 situation 信息 + loaded_style_info = new_learner.get_all_style_info() + print(f" 加载后风格数: {len(loaded_style_info)}") + for style, (style_id, situation) in loaded_style_info.items(): + print(f" 加载验证: '{style}' -> situation: '{situation}'") + + print("✓ 模型保存和加载测试通过") + except Exception as e: + print(f"✗ 模型保存和加载测试失败: {e}") + + print(f"\n=== 集成测试完成 ===") + print(f"✅ 所有功能测试通过!") + print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner") + print(f"✓ StyleLearner 正确存储和获取 situation 信息") + print(f"✓ 预测功能正常工作,可以获取对应的 situation") + print(f"✓ 模型保存和加载支持 situation 字段") + + +def main(): + """主函数""" + print("Expression Learner 与 Style Learner 集成测试(含 Situation)") + print("=" * 70) + + # 运行异步测试 + asyncio.run(test_expression_style_integration()) + + +if __name__ == "__main__": + main() diff --git a/test_style_learner_db.py b/test_style_learner_db.py new file mode 100644 index 00000000..ba1e2023 --- /dev/null +++ b/test_style_learner_db.py @@ -0,0 +1,391 @@ +""" +StyleLearner 数据库测试脚本 +使用数据库中的expression数据测试style_learner功能 +""" + +import os +import sys +from typing import List, Dict, Tuple +from sklearn.model_selection import train_test_split +from sklearn.metrics import precision_recall_fscore_support + +# 添加项目根目录到Python路径 +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from src.common.database.database_model import Expression, db +from src.express.style_learner import StyleLearnerManager +from src.common.logger import get_logger + +logger = get_logger("style_learner_test") + + +class StyleLearnerDatabaseTest: + """使用数据库数据测试StyleLearner""" + + def __init__(self, random_state: int = 42): + self.random_state = random_state + self.manager = StyleLearnerManager(model_save_path="data/test_style_models") + + # 测试结果 + self.test_results = { + "total_samples": 0, + "train_samples": 0, + "test_samples": 0, + "unique_styles": 0, + "unique_chat_ids": 0, + "accuracy": 0.0, + "precision": 0.0, + "recall": 0.0, + "f1_score": 0.0, + "predictions": [], + "ground_truth": [], + "model_save_success": False, + "model_save_path": self.manager.model_save_path + } + + def load_data_from_database(self) -> List[Dict]: + """ + 从数据库加载expression数据 + + Returns: + List[Dict]: 包含up_content, style, chat_id的数据列表 + """ + try: + # 连接数据库 + db.connect(reuse_if_open=True) + + # 查询所有expression数据 + expressions = Expression.select().where( + (Expression.up_content.is_null(False)) & + (Expression.style.is_null(False)) & + (Expression.chat_id.is_null(False)) & + (Expression.type == "style") + ) + + data = [] + for expr in expressions: + if expr.up_content and expr.style and expr.chat_id: + data.append({ + "up_content": expr.up_content, + "style": expr.style, + "chat_id": expr.chat_id, + "last_active_time": expr.last_active_time, + "context": expr.context, + "situation": expr.situation + }) + + logger.info(f"从数据库加载了 {len(data)} 条expression数据") + return data + + except Exception as e: + logger.error(f"从数据库加载数据失败: {e}") + return [] + + def preprocess_data(self, data: List[Dict]) -> List[Dict]: + """ + 数据预处理 + + Args: + data: 原始数据 + + Returns: + List[Dict]: 预处理后的数据 + """ + # 过滤掉空值或过短的数据 + filtered_data = [] + for item in data: + up_content = item["up_content"].strip() + style = item["style"].strip() + + if len(up_content) >= 2 and len(style) >= 2: + filtered_data.append({ + "up_content": up_content, + "style": style, + "chat_id": item["chat_id"], + "last_active_time": item["last_active_time"], + "context": item["context"], + "situation": item["situation"] + }) + + logger.info(f"预处理后剩余 {len(filtered_data)} 条数据") + return filtered_data + + def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]: + """ + 分割训练集和测试集 + 训练集使用所有数据,测试集从训练集中随机选择5% + + Args: + data: 预处理后的数据 + + Returns: + Tuple[List[Dict], List[Dict]]: (训练集, 测试集) + """ + # 训练集使用所有数据 + train_data = data.copy() + + # 测试集从训练集中随机选择5% + test_size = 0.05 # 5% + test_data = train_test_split( + train_data, test_size=test_size, random_state=self.random_state + )[1] # 只取测试集部分 + + logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条") + logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%") + return train_data, test_data + + def train_model(self, train_data: List[Dict]) -> None: + """ + 训练模型 + + Args: + train_data: 训练数据 + """ + logger.info("开始训练模型...") + + # 统计信息 + chat_ids = set() + styles = set() + + for item in train_data: + chat_id = item["chat_id"] + up_content = item["up_content"] + style = item["style"] + + chat_ids.add(chat_id) + styles.add(style) + + # 学习映射关系 + success = self.manager.learn_mapping(chat_id, up_content, style) + if not success: + logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}") + + self.test_results["train_samples"] = len(train_data) + self.test_results["unique_styles"] = len(styles) + self.test_results["unique_chat_ids"] = len(chat_ids) + + logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室") + + # 保存训练好的模型 + logger.info("开始保存训练好的模型...") + save_success = self.manager.save_all_models() + self.test_results["model_save_success"] = save_success + + if save_success: + logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}") + print(f"✅ 模型已保存到: {self.manager.model_save_path}") + else: + logger.warning("部分模型保存失败") + print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}") + + def test_model(self, test_data: List[Dict]) -> None: + """ + 测试模型 + + Args: + test_data: 测试数据 + """ + logger.info("开始测试模型...") + + predictions = [] + ground_truth = [] + correct_predictions = 0 + + for item in test_data: + chat_id = item["chat_id"] + up_content = item["up_content"] + true_style = item["style"] + + # 预测风格 + predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1) + + predictions.append(predicted_style) + ground_truth.append(true_style) + + # 检查预测是否正确 + if predicted_style == true_style: + correct_predictions += 1 + + # 记录详细预测结果 + self.test_results["predictions"].append({ + "chat_id": chat_id, + "up_content": up_content, + "true_style": true_style, + "predicted_style": predicted_style, + "scores": scores + }) + + # 计算准确率 + accuracy = correct_predictions / len(test_data) if test_data else 0 + + # 计算其他指标(需要处理None值) + valid_predictions = [p for p in predictions if p is not None] + valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None] + + if valid_predictions: + precision, recall, f1, _ = precision_recall_fscore_support( + valid_ground_truth, valid_predictions, average='weighted', zero_division=0 + ) + else: + precision = recall = f1 = 0.0 + + self.test_results["test_samples"] = len(test_data) + self.test_results["accuracy"] = accuracy + self.test_results["precision"] = precision + self.test_results["recall"] = recall + self.test_results["f1_score"] = f1 + + logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}") + + def analyze_results(self) -> None: + """分析测试结果""" + logger.info("=== 测试结果分析 ===") + + print("\n📊 数据统计:") + print(f" 总样本数: {self.test_results['total_samples']}") + print(f" 训练样本数: {self.test_results['train_samples']}") + print(f" 测试样本数: {self.test_results['test_samples']}") + print(f" 唯一风格数: {self.test_results['unique_styles']}") + print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}") + + print("\n🎯 模型性能:") + print(f" 准确率: {self.test_results['accuracy']:.4f}") + print(f" 精确率: {self.test_results['precision']:.4f}") + print(f" 召回率: {self.test_results['recall']:.4f}") + print(f" F1分数: {self.test_results['f1_score']:.4f}") + + print("\n💾 模型保存:") + save_status = "成功" if self.test_results['model_save_success'] else "失败" + print(f" 保存状态: {save_status}") + print(f" 保存路径: {self.test_results['model_save_path']}") + + # 分析各聊天室的性能 + chat_performance = {} + for pred in self.test_results["predictions"]: + chat_id = pred["chat_id"] + if chat_id not in chat_performance: + chat_performance[chat_id] = {"correct": 0, "total": 0} + + chat_performance[chat_id]["total"] += 1 + if pred["predicted_style"] == pred["true_style"]: + chat_performance[chat_id]["correct"] += 1 + + print("\n📈 各聊天室性能:") + for chat_id, perf in chat_performance.items(): + accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0 + print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})") + + # 分析风格分布 + style_counts = {} + for pred in self.test_results["predictions"]: + style = pred["true_style"] + style_counts[style] = style_counts.get(style, 0) + 1 + + print("\n🎨 风格分布 (前10个):") + sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True) + for style, count in sorted_styles[:10]: + print(f" {style}: {count} 次") + + def show_sample_predictions(self, num_samples: int = 10) -> None: + """显示样本预测结果""" + print(f"\n🔍 样本预测结果 (前{num_samples}个):") + + for i, pred in enumerate(self.test_results["predictions"][:num_samples]): + status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗" + print(f"\n {i+1}. {status}") + print(f" 聊天室: {pred['chat_id']}") + print(f" 输入内容: {pred['up_content']}") + print(f" 真实风格: {pred['true_style']}") + print(f" 预测风格: {pred['predicted_style']}") + if pred["scores"]: + top_scores = dict(list(pred["scores"].items())[:3]) + print(f" 分数: {top_scores}") + + def save_results(self, output_file: str = "style_learner_test_results.txt") -> None: + """保存测试结果到文件""" + try: + with open(output_file, "w", encoding="utf-8") as f: + f.write("StyleLearner 数据库测试结果\n") + f.write("=" * 50 + "\n\n") + + f.write("数据统计:\n") + f.write(f" 总样本数: {self.test_results['total_samples']}\n") + f.write(f" 训练样本数: {self.test_results['train_samples']}\n") + f.write(f" 测试样本数: {self.test_results['test_samples']}\n") + f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n") + f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n") + + f.write("模型性能:\n") + f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n") + f.write(f" 精确率: {self.test_results['precision']:.4f}\n") + f.write(f" 召回率: {self.test_results['recall']:.4f}\n") + f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n") + + f.write("模型保存:\n") + save_status = "成功" if self.test_results['model_save_success'] else "失败" + f.write(f" 保存状态: {save_status}\n") + f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n") + + f.write("详细预测结果:\n") + for i, pred in enumerate(self.test_results["predictions"]): + status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗" + f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n") + + logger.info(f"测试结果已保存到 {output_file}") + + except Exception as e: + logger.error(f"保存测试结果失败: {e}") + + def run_test(self) -> None: + """运行完整测试""" + logger.info("开始StyleLearner数据库测试...") + + # 1. 加载数据 + raw_data = self.load_data_from_database() + if not raw_data: + logger.error("没有加载到数据,测试终止") + return + + # 2. 数据预处理 + processed_data = self.preprocess_data(raw_data) + if not processed_data: + logger.error("预处理后没有数据,测试终止") + return + + self.test_results["total_samples"] = len(processed_data) + + # 3. 分割数据 + train_data, test_data = self.split_data(processed_data) + + # 4. 训练模型 + self.train_model(train_data) + + # 5. 测试模型 + self.test_model(test_data) + + # 6. 分析结果 + self.analyze_results() + + # 7. 显示样本预测 + self.show_sample_predictions(10) + + # 8. 保存结果 + self.save_results() + + logger.info("StyleLearner数据库测试完成!") + + +def main(): + """主函数""" + print("StyleLearner 数据库测试脚本") + print("=" * 50) + + # 创建测试实例 + test = StyleLearnerDatabaseTest(random_state=42) + + # 运行测试 + test.run_test() + + +if __name__ == "__main__": + main() diff --git a/view_pkl.py b/view_pkl.py new file mode 100644 index 00000000..0897e174 --- /dev/null +++ b/view_pkl.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 +""" +查看 .pkl 文件内容的工具脚本 +""" + +import pickle +import sys +import os +from pprint import pprint + +def view_pkl_file(file_path): + """查看 pkl 文件内容""" + if not os.path.exists(file_path): + print(f"❌ 文件不存在: {file_path}") + return + + try: + with open(file_path, 'rb') as f: + data = pickle.load(f) + + print(f"📁 文件: {file_path}") + print(f"📊 数据类型: {type(data)}") + print("=" * 50) + + if isinstance(data, dict): + print("🔑 字典键:") + for key in data.keys(): + print(f" - {key}: {type(data[key])}") + print() + + print("📋 详细内容:") + pprint(data, width=120, depth=10) + + elif isinstance(data, list): + print(f"📝 列表长度: {len(data)}") + if data: + print(f"📊 第一个元素类型: {type(data[0])}") + print("📋 前几个元素:") + for i, item in enumerate(data[:3]): + print(f" [{i}]: {item}") + + else: + print("📋 内容:") + pprint(data, width=120, depth=10) + + # 如果是 expressor 模型,特别显示 token_counts 的详细信息 + if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']: + print("\n" + "="*50) + print("🔍 详细词汇统计 (token_counts):") + token_counts = data['nb']['token_counts'] + for style_id, tokens in token_counts.items(): + print(f"\n📝 {style_id}:") + if tokens: + # 按词频排序显示前10个词 + sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True) + for word, count in sorted_tokens[:10]: + print(f" '{word}': {count}") + if len(sorted_tokens) > 10: + print(f" ... 还有 {len(sorted_tokens) - 10} 个词") + else: + print(" (无词汇数据)") + + except Exception as e: + print(f"❌ 读取文件失败: {e}") + +def main(): + if len(sys.argv) != 2: + print("用法: python view_pkl.py ") + print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl") + return + + file_path = sys.argv[1] + view_pkl_file(file_path) + +if __name__ == "__main__": + main() diff --git a/view_tokens.py b/view_tokens.py new file mode 100644 index 00000000..03fe8992 --- /dev/null +++ b/view_tokens.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +""" +专门查看 expressor.pkl 文件中 token_counts 的脚本 +""" + +import pickle +import sys +import os + +def view_token_counts(file_path): + """查看 expressor.pkl 文件中的词汇统计""" + if not os.path.exists(file_path): + print(f"❌ 文件不存在: {file_path}") + return + + try: + with open(file_path, 'rb') as f: + data = pickle.load(f) + + print(f"📁 文件: {file_path}") + print("=" * 60) + + if 'nb' not in data or 'token_counts' not in data['nb']: + print("❌ 这不是一个 expressor 模型文件") + return + + token_counts = data['nb']['token_counts'] + candidates = data.get('candidates', {}) + + print(f"🎯 找到 {len(token_counts)} 个风格") + print("=" * 60) + + for style_id, tokens in token_counts.items(): + style_text = candidates.get(style_id, "未知风格") + print(f"\n📝 {style_id}: {style_text}") + print(f"📊 词汇数量: {len(tokens)}") + + if tokens: + # 按词频排序 + sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True) + + print("🔤 词汇统计 (按频率排序):") + for i, (word, count) in enumerate(sorted_tokens): + print(f" {i+1:2d}. '{word}': {count}") + else: + print(" (无词汇数据)") + + print("-" * 40) + + except Exception as e: + print(f"❌ 读取文件失败: {e}") + +def main(): + if len(sys.argv) != 2: + print("用法: python view_tokens.py ") + print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl") + return + + file_path = sys.argv[1] + view_token_counts(file_path) + +if __name__ == "__main__": + main()