feat:表达方式更新,现在会训练朴素贝叶斯模型来预测使用什么表达
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -323,7 +323,7 @@ run_pet.bat
|
||||
!/plugins/emoji_manage_plugin
|
||||
!/plugins/take_picture_plugin
|
||||
!/plugins/deep_think
|
||||
!/plugins/MaiFrequencyControl
|
||||
!/plugins/BetterFrequency
|
||||
!/plugins/__init__.py
|
||||
|
||||
config.toml
|
||||
|
||||
@@ -16,7 +16,7 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.person_info.person_info import Person
|
||||
from src.plugin_system.base.component_types import EventType, ActionInfo
|
||||
from src.plugin_system.core import events_manager
|
||||
|
||||
@@ -1,316 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
|
||||
"""按权重随机抽样"""
|
||||
if not population or not weights or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 使用累积权重的方法进行加权抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
weights_copy = weights.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
|
||||
# 选择一个元素
|
||||
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
|
||||
selected.append(population_copy.pop(chosen_idx))
|
||||
weights_copy.pop(chosen_idx)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
# sourcery skip: extract-duplicate-method, move-assign
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"count": expr.count,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 按权重抽样(使用count作为权重)
|
||||
if style_exprs:
|
||||
style_weights = [expr.get("count", 1) for expr in style_exprs]
|
||||
selected_style = weighted_sample(style_exprs, style_weights, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
return selected_style
|
||||
|
||||
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
|
||||
"""对一批表达方式更新count值,按chat_id+type分组后一次性写入数据库"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
expr_type: str = expr.get("type", "style")
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, expr_type, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
current_count = expr_obj.count
|
||||
new_count = min(current_count + increment, 5.0)
|
||||
expr_obj.count = new_count
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug(
|
||||
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
|
||||
)
|
||||
|
||||
async def select_suitable_expressions_llm(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
# sourcery skip: inline-variable, list-comprehension
|
||||
"""使用LLM选择适合的表达方式"""
|
||||
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 1. 获取20个随机表达方式(现在按权重抽取)
|
||||
style_exprs = self.get_random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要回复消息:{target_message}"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_info,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
try:
|
||||
# start_time = time.time()
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
|
||||
|
||||
# logger.info(f"模型名称: {model_name}")
|
||||
# logger.info(f"LLM返回结果: {content}")
|
||||
# if reasoning_content:
|
||||
# logger.info(f"LLM推理: {reasoning_content}")
|
||||
# else:
|
||||
# logger.info(f"LLM推理: 无")
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,一次性更新count数
|
||||
if valid_expressions:
|
||||
self.update_expressions_count_batch(valid_expressions, 0.006)
|
||||
|
||||
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
@@ -18,7 +18,7 @@ from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.heart_flow.hfc_utils import CycleDetail
|
||||
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
|
||||
from src.chat.express.expression_learner import expression_learner_manager
|
||||
from src.express.expression_learner import expression_learner_manager
|
||||
from src.chat.frequency_control.frequency_control import frequency_control_manager
|
||||
from src.memory_system.question_maker import QuestionMaker
|
||||
from src.memory_system.questions import global_conflict_tracker
|
||||
@@ -331,9 +331,8 @@ class HeartFChatting:
|
||||
|
||||
|
||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||
await self.expression_learner.trigger_learning_for_chat()
|
||||
|
||||
await global_memory_chest.build_running_content(chat_id=self.stream_id)
|
||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
|
||||
|
||||
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
|
||||
@@ -26,7 +26,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
|
||||
# from src.memory_system.memory_activator import MemoryActivator
|
||||
@@ -238,8 +238,8 @@ class DefaultReplyer:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
replace_user_references,
|
||||
)
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.express.expression_selector import expression_selector
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
from src.mood.mood_manager import mood_manager
|
||||
|
||||
@@ -256,8 +256,8 @@ class PrivateReplyer:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# LLM模式:调用LLM选择5-10个,然后随机选5个
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
|
||||
# 根据配置模式选择表达方式:exp_model模式直接使用模型预测,classic模式使用LLM选择
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
|
||||
)
|
||||
|
||||
|
||||
@@ -303,11 +303,10 @@ class Expression(BaseModel):
|
||||
|
||||
situation = TextField()
|
||||
style = TextField()
|
||||
count = FloatField()
|
||||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
context_words = TextField(null=True)
|
||||
up_content = TextField(null=True)
|
||||
|
||||
last_active_time = FloatField()
|
||||
chat_id = TextField(index=True)
|
||||
|
||||
@@ -310,8 +310,8 @@ class MemoryConfig(ConfigBase):
|
||||
class ExpressionConfig(ConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
mode: Literal["llm", "context", "full-context"] = "context"
|
||||
"""表达方式模式,可选:llm模式,context上下文模式,full-context 完整上下文嵌入模式"""
|
||||
mode: str = "classic"
|
||||
"""表达方式模式,可选:classic经典模式,exp_model 表达模型模式"""
|
||||
|
||||
learning_list: list[list] = field(default_factory=lambda: [])
|
||||
"""
|
||||
|
||||
@@ -2,10 +2,12 @@ import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
import jieba
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
import traceback
|
||||
import difflib
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
@@ -17,16 +19,23 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
DECAY_DAYS = 15 # 30天衰减到0.01
|
||||
DECAY_MIN = 0.01 # 最小衰减值
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
@@ -173,63 +182,7 @@ class ExpressionLearner:
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
"""
|
||||
对数据库中的所有表达方式应用全局衰减
|
||||
"""
|
||||
try:
|
||||
# 获取所有表达方式
|
||||
all_expressions = Expression.select()
|
||||
|
||||
updated_count = 0
|
||||
deleted_count = 0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 计算时间差
|
||||
last_active = expr.last_active_time
|
||||
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
|
||||
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 如果count太小,删除这个表达方式
|
||||
expr.delete_instance()
|
||||
deleted_count += 1
|
||||
else:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
expr.save()
|
||||
updated_count += 1
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"数据库全局衰减失败: {e}")
|
||||
|
||||
def calculate_decay_factor(self, time_diff_days: float) -> float:
|
||||
"""
|
||||
计算衰减值
|
||||
当时间差为0天时,衰减值为0(最近活跃的不衰减)
|
||||
当时间差为7天时,衰减值为0.002(中等衰减)
|
||||
当时间差为30天或更长时,衰减值为0.01(高衰减)
|
||||
使用二次函数进行曲线插值
|
||||
"""
|
||||
if time_diff_days <= 0:
|
||||
return 0.0 # 刚激活的表达式不衰减
|
||||
|
||||
if time_diff_days >= DECAY_DAYS:
|
||||
return 0.01 # 长时间未活跃的表达式大幅衰减
|
||||
|
||||
# 使用二次函数插值:在0-30天之间从0衰减到0.01
|
||||
# 使用简单的二次函数:y = a * x^2
|
||||
# 当x=30时,y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
|
||||
a = 0.01 / (DECAY_DAYS**2)
|
||||
decay = a * (time_diff_days**2)
|
||||
|
||||
return min(0.01, decay)
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
@@ -247,7 +200,7 @@ class ExpressionLearner:
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_context_words,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
@@ -260,7 +213,7 @@ class ExpressionLearner:
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
context_words,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
@@ -269,13 +222,15 @@ class ExpressionLearner:
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"context": context,
|
||||
"context_words": context_words,
|
||||
"up_content": up_content,
|
||||
}
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
trained_chat_ids = set() # 记录已训练的聊天室
|
||||
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
@@ -292,32 +247,72 @@ class ExpressionLearner:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.context = new_expr["context"]
|
||||
expr_obj.context_words = new_expr["context_words"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.up_content = new_expr["up_content"]
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=new_expr["context"],
|
||||
context_words=new_expr["context_words"],
|
||||
up_content=new_expr["up_content"],
|
||||
)
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
# 获取 learner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
# 先添加风格和对应的 situation(如果存在)
|
||||
if new_expr.get("situation"):
|
||||
learner.add_style(new_expr["style"], new_expr["situation"])
|
||||
else:
|
||||
learner.add_style(new_expr["style"])
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
chat_id,
|
||||
new_expr["up_content"],
|
||||
new_expr["style"]
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
|
||||
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
|
||||
trained_chat_ids.add(chat_id)
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
|
||||
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
Expression.select()
|
||||
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
.order_by(Expression.count.asc())
|
||||
)
|
||||
if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除count最小的多余表达方式
|
||||
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
expr.delete_instance()
|
||||
# exprs = list(
|
||||
# Expression.select()
|
||||
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
# .order_by(Expression.last_active_time.asc())
|
||||
# )
|
||||
# if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除最久未活跃的多余表达方式
|
||||
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
# expr.delete_instance()
|
||||
|
||||
# 保存训练好的 style_learner 模型
|
||||
if trained_chat_ids:
|
||||
try:
|
||||
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
|
||||
save_success = style_learner_manager.save_all_models()
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
|
||||
else:
|
||||
logger.warning("StyleLearner 模型保存失败")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
@@ -339,8 +334,8 @@ class ExpressionLearner:
|
||||
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
print(f"match_expression_context_prompt: {prompt}")
|
||||
print(f"random_msg_match_str: {response}")
|
||||
# print(f"match_expression_context_prompt: {prompt}")
|
||||
# print(f"{response}")
|
||||
|
||||
# 解析JSON响应
|
||||
match_responses = []
|
||||
@@ -395,6 +390,8 @@ class ExpressionLearner:
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
print(f"match_responses: {match_responses}")
|
||||
|
||||
for match_response in match_responses:
|
||||
try:
|
||||
@@ -418,7 +415,7 @@ class ExpressionLearner:
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
|
||||
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -466,39 +463,52 @@ class ExpressionLearner:
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
|
||||
print(f"matched_expressions: {matched_expressions}")
|
||||
|
||||
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
|
||||
matched_expressions
|
||||
)
|
||||
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
|
||||
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
|
||||
pic_pattern = r"\[picid:[^\]]+\]"
|
||||
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
|
||||
at_pattern = r"@<[^:<>]+:[^:<>]+>"
|
||||
for idx, msg in enumerate(random_msg):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = re.sub(pic_pattern, "[图片]", content)
|
||||
content = re.sub(reply_pattern, "回复[某人]", content)
|
||||
content = re.sub(at_pattern, "@[某人]", content)
|
||||
content = content.strip()
|
||||
if content:
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
split_matched_expressions_w_emb = []
|
||||
|
||||
for situation, style, context, context_words in split_matched_expressions:
|
||||
split_matched_expressions_w_emb.append(
|
||||
(self.chat_id, situation, style, context, context_words)
|
||||
)
|
||||
|
||||
return split_matched_expressions_w_emb
|
||||
|
||||
def split_expression_context(
|
||||
self, matched_expressions: List[Tuple[str, str, str]]
|
||||
) -> List[Tuple[str, str, str, List[str]]]:
|
||||
"""
|
||||
对matched_expressions中的context部分进行jieba分词
|
||||
|
||||
Args:
|
||||
matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context)
|
||||
|
||||
Returns:
|
||||
添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words)
|
||||
"""
|
||||
result = []
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
# 使用jieba进行分词
|
||||
context_words = list(jieba.cut(context))
|
||||
result.append((situation, style, context, context_words))
|
||||
# 在 bare_lines 中找到第一处相似度达到85%的行
|
||||
pos = None
|
||||
for i, (_, c) in enumerate(bare_lines):
|
||||
similarity = calculate_similarity(c, context)
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到或没有上一句,跳过该表达
|
||||
continue
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = (random_msg[prev_original_idx].processed_plain_text or "").strip()
|
||||
if not up_content:
|
||||
continue
|
||||
filtered_with_up.append((situation, style, context, up_content))
|
||||
|
||||
if not filtered_with_up:
|
||||
return None
|
||||
|
||||
results: List[Tuple[str, str, str, str]] = []
|
||||
for (situation, style, context, up_content) in filtered_with_up:
|
||||
results.append((self.chat_id, situation, style, context, up_content))
|
||||
|
||||
return results
|
||||
|
||||
return result
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
520
src/express/expression_selector.py
Normal file
520
src/express/expression_selector.py
Normal file
@@ -0,0 +1,520 @@
|
||||
import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """
|
||||
以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1. 聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2. 话题类型(日常、技术、游戏、情感等)
|
||||
3. 情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""随机抽样"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
# 使用随机抽样
|
||||
selected = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(k):
|
||||
if not population_copy:
|
||||
break
|
||||
|
||||
# 随机选择一个元素
|
||||
chosen_idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(chosen_idx))
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
    """
    Predict the most suitable expressions with the style_learner model.

    Args:
        chat_id: chat stream ID
        chat_info: chat content used to build the prediction query
        total_num: maximum number of expressions to return

    Returns:
        List[Dict[str, Any]]: predicted expression records, highest score first
    """
    try:
        # Prediction may pool several chat_ids (expression sharing groups).
        related_chat_ids = self.get_related_chat_ids(chat_id)

        # Condense the chat into a short prediction query.
        # NOTE(review): currently just the last few lines — could be refined.
        prediction_input = self._extract_prediction_input(chat_info)

        predicted_expressions = []

        # Run one prediction per related chat_id; each contributes at most
        # its single best style.
        for related_chat_id in related_chat_ids:
            try:
                # Ask the style learner for the best-matching style.
                best_style, scores = style_learner_manager.predict_style(
                    related_chat_id, prediction_input, top_k=total_num
                )

                if best_style and scores:
                    # Resolve the predicted style's id and situation text.
                    learner = style_learner_manager.get_learner(related_chat_id)
                    style_id, situation = learner.get_style_info(best_style)

                    if style_id and situation:
                        # Look up the matching expression row in the database.
                        expr_query = Expression.select().where(
                            (Expression.chat_id == related_chat_id) &
                            (Expression.type == "style") &
                            (Expression.situation == situation) &
                            (Expression.style == best_style)
                        )

                        if expr_query.exists():
                            expr = expr_query.get()
                            predicted_expressions.append({
                                "id": expr.id,
                                "situation": expr.situation,
                                "style": expr.style,
                                "last_active_time": expr.last_active_time,
                                "source_id": expr.chat_id,
                                "type": "style",
                                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                                "prediction_score": scores.get(best_style, 0.0),
                                "prediction_input": prediction_input
                            })

            except Exception as e:
                logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
                continue

        # Sort by prediction score and keep the top total_num entries.
        predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
        selected_expressions = predicted_expressions[:total_num]

        logger.info(f"为聊天室 {chat_id} 预测到 {len(selected_expressions)} 个表达方式")
        return selected_expressions

    except Exception as e:
        logger.error(f"模型预测表达方式失败: {e}")
        # Prediction failure degrades to random sampling.
        return self._fallback_random_expressions(chat_id, total_num)
|
||||
|
||||
def _extract_prediction_input(self, chat_info: str) -> str:
|
||||
"""
|
||||
从聊天信息中提取用于预测的关键内容
|
||||
|
||||
Args:
|
||||
chat_info: 聊天内容信息
|
||||
|
||||
Returns:
|
||||
str: 提取的预测输入
|
||||
"""
|
||||
try:
|
||||
# 简单的提取策略:取最后几句话作为预测输入
|
||||
lines = chat_info.strip().split('\n')
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
# 取最后3行作为预测输入
|
||||
recent_lines = lines[-3:]
|
||||
prediction_input = ' '.join(recent_lines).strip()
|
||||
|
||||
# 如果内容太长,截取前100个字符
|
||||
if len(prediction_input) > 100:
|
||||
prediction_input = prediction_input[:100]
|
||||
|
||||
return prediction_input
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"提取预测输入失败: {e}")
|
||||
return ""
|
||||
|
||||
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
    """
    Fallback: randomly sample expressions when model prediction fails.

    Args:
        chat_id: chat stream ID
        total_num: number of expressions to pick

    Returns:
        List[Dict[str, Any]]: sampled expression records (may be empty)
    """
    try:
        # Sampling may pool several chat_ids (expression sharing groups).
        related_chat_ids = self.get_related_chat_ids(chat_id)

        # One query covering all related chat_ids at once.
        style_query = Expression.select().where(
            (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
        )

        style_exprs = [
            {
                "id": expr.id,
                "situation": expr.situation,
                "style": expr.style,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "style",
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in style_query
        ]

        # Random sampling.
        # NOTE(review): weighted_sample is defined elsewhere in this module;
        # presumably it biases the draw — confirm at its definition.
        if style_exprs:
            selected_style = weighted_sample(style_exprs, total_num)
        else:
            selected_style = []

        logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
        return selected_style

    except Exception as e:
        logger.error(f"随机选择表达方式失败: {e}")
        return []
|
||||
|
||||
|
||||
async def select_suitable_expressions(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """Pick expressions for a chat according to the configured mode.

    Args:
        chat_id: chat stream ID
        chat_info: chat content used as selection context
        max_num: upper bound on selected expressions
        target_message: message being replied to, if any

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]: selected expressions and their ids
    """
    # Bail out early when this chat may not use expressions at all.
    if not self.can_use_expression_for_chat(chat_id):
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return [], []

    mode = global_config.expression.mode

    if mode == "exp_model":
        # Pure model prediction, no LLM involved.
        logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
        return await self._select_expressions_model_only(chat_id, chat_info, max_num)

    if mode == "classic":
        # classic: model candidates re-ranked by an LLM.
        logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
    else:
        # Unknown configuration value: fall back to classic behavior.
        logger.warning(f"未知的表达模式: {mode},回退到classic模式")
    return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
|
||||
|
||||
async def _select_expressions_model_only(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    exp_model mode: use the model prediction directly, skipping the LLM.

    Args:
        chat_id: chat stream ID
        chat_info: chat content used as prediction context
        max_num: upper bound on selected expressions

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]: selected expressions and their ids
    """
    try:
        # Over-fetch (2x) so truncation below still has choices.
        style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)

        # if len(style_exprs) < 5:
        #     logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
        #     return [], []

        # Keep the top max_num predictions as-is (already sorted by score).
        selected_expressions = style_exprs[:max_num]
        selected_ids = [expr["id"] for expr in selected_expressions]

        # Touch last_active_time for everything handed out.
        if selected_expressions:
            self.update_expressions_last_active_time(selected_expressions)

        logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
        return selected_expressions, selected_ids

    except Exception as e:
        logger.error(f"exp_model模式选择表达方式失败: {e}")
        return [], []
|
||||
|
||||
async def _select_expressions_classic(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    classic mode: model-predicted candidates re-ranked by the LLM.

    Args:
        chat_id: chat stream ID
        chat_info: chat content shown to the LLM as context
        max_num: upper bound on selected expressions
        target_message: message being replied to, if any

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]: selected expressions and their ids
    """
    try:
        # 1. Let the model shortlist candidate expressions.
        style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)

        if len(style_exprs) < 10:
            logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
            return [], []

        # 2. Build the indexed situation list shown to the LLM.
        all_expressions: List[Dict[str, Any]] = []
        all_situations: List[str] = []

        # 1-based numbering, matching what the prompt asks the LLM to return.
        for expr in style_exprs:
            expr = expr.copy()
            all_expressions.append(expr)
            all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")

        if not all_expressions:
            logger.warning("没有找到可用的表达方式")
            return [], []

        all_situations_str = "\n".join(all_situations)

        if target_message:
            target_message_str = f",现在你想要回复消息:{target_message}"
            target_message_extra_block = "4.考虑你要回复的目标消息"
        else:
            target_message_str = ""
            target_message_extra_block = ""

        # 3. Build the prompt (situations only, not full expression bodies).
        prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
            bot_name=global_config.bot.nickname,
            chat_observe_info=chat_info,
            all_situations=all_situations_str,
            max_num=max_num,
            target_message=target_message_str,
            target_message_extra_block=target_message_extra_block,
        )

        # 4. Query the LLM.
        content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)

        if not content:
            logger.warning("LLM返回空结果")
            return [], []

        # 5. Parse the (possibly malformed) JSON answer.
        result = repair_json(content)
        if isinstance(result, str):
            result = json.loads(result)

        if not isinstance(result, dict) or "selected_situations" not in result:
            logger.error("LLM返回格式错误")
            logger.info(f"LLM返回结果: \n{content}")
            return [], []

        selected_indices = result["selected_situations"]

        # Map the 1-based indices back to full expression records,
        # silently dropping anything out of range or non-integer.
        valid_expressions: List[Dict[str, Any]] = []
        selected_ids = []
        for idx in selected_indices:
            if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                expression = all_expressions[idx - 1]  # indices start at 1
                selected_ids.append(expression["id"])
                valid_expressions.append(expression)

        # Touch last_active_time for everything handed out.
        if valid_expressions:
            self.update_expressions_last_active_time(valid_expressions)

        logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
        return valid_expressions, selected_ids

    except Exception as e:
        logger.error(f"classic模式处理表达方式选择时出错: {e}")
        return [], []
|
||||
|
||||
async def select_suitable_expressions_llm(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    LLM-based expression selection (kept for backward compatibility).

    Deprecated: superseded by select_suitable_expressions, to which this
    simply delegates after logging a deprecation warning.
    """
    logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")
    return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message)
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
    """Refresh last_active_time in the DB for a batch of expressions.

    Deduplicates by (source_id, type, situation, style) first, so each
    distinct expression is written at most once per call. Records missing
    any of the required fields are skipped with a warning.
    """
    if not expressions_to_update:
        return
    updates_by_key = {}
    for expr in expressions_to_update:
        source_id: str = expr.get("source_id")  # type: ignore
        expr_type: str = expr.get("type", "style")
        situation: str = expr.get("situation")  # type: ignore
        style: str = expr.get("style")  # type: ignore
        if not source_id or not situation or not style:
            logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
            continue
        key = (source_id, expr_type, situation, style)
        if key not in updates_by_key:
            updates_by_key[key] = expr
    # Iterating the dict yields its tuple keys; the stored values are unused.
    for chat_id, expr_type, situation, style in updates_by_key:
        query = Expression.select().where(
            (Expression.chat_id == chat_id)
            & (Expression.type == expr_type)
            & (Expression.situation == situation)
            & (Expression.style == style)
        )
        if query.exists():
            expr_obj = query.get()
            expr_obj.last_active_time = time.time()
            expr_obj.save()
            logger.debug(
                "表达方式激活: 更新last_active_time in db"
            )
|
||||
|
||||
|
||||
# Register the prompt template at import time.
init_prompt()

# Module-level singleton used by the rest of the chat pipeline.
# NOTE(review): if construction fails, `expression_selector` is left
# undefined and importers will hit a NameError — confirm this is intended.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
131
src/express/expressor_model/model.py
Normal file
131
src/express/expressor_model/model.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
from collections import Counter, defaultdict
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .online_nb import OnlineNaiveBayes
|
||||
|
||||
class ExpressorModel:
    """
    Naive-Bayes ranker over style candidates, with online learning.

    Each candidate may carry a `situation` string; it is never scored —
    it only rides along with the candidate's style text.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 gamma: float = 1.0,
                 vocab_size: int = 200000,
                 use_jieba: bool = True):
        # alpha/beta are smoothing parameters, gamma the decay factor;
        # all are forwarded to OnlineNaiveBayes.
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        self._candidates: Dict[str, str] = {}  # cid -> text (style)
        self._situations: Dict[str, str] = {}  # cid -> situation (not scored)

    def add_candidate(self, cid: str, text: str, situation: str = None):
        """Register (or overwrite) candidate *cid* with its style text and situation."""
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation

        # Make sure the NB model has zeroed counters for this candidate.
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
        """Bulk-register (cid, text) pairs with an optional parallel situations list."""
        for i, (cid, text) in enumerate(items):
            situation = situations[i] if situations and i < len(situations) else None
            self.add_candidate(cid, text, situation)

    def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
        """Score every candidate against *text* with naive Bayes.

        Returns (best cid, {cid: log-score}); (None, {}) when *text*
        tokenizes to nothing or no candidates are registered.
        NOTE(review): the *k* parameter is currently unused — all
        candidates are scored and all scores returned.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return None, {}

        if not self._candidates:
            return None, {}

        # Score all candidates against the query term frequencies.
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        scores = self.nb.score_batch(tf, all_cids)

        # Pick the argmax.
        if not scores:
            return None, {}
        best = max(scores.items(), key=lambda x: x[1])[0]
        return best, scores

    def update_positive(self, text: str, cid: str):
        """Positive-feedback update: add *text*'s token counts to class *cid*."""
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: float):
        # Forget old evidence; delegates to the NB model.
        self.nb.decay(factor=factor)

    def get_situation(self, cid: str) -> Optional[str]:
        """Situation paired with *cid*, or None."""
        return self._situations.get(cid)

    def get_style(self, cid: str) -> Optional[str]:
        """Style text of *cid*, or None."""
        return self._candidates.get(cid)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """(style text, situation) of *cid*; either may be None."""
        return self._candidates.get(cid), self._situations.get(cid)

    def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """All candidates as {cid: (style text, situation-or-None)}."""
        return {cid: (style, self._situations.get(cid))
                for cid, style in self._candidates.items()}

    def save(self, path: str):
        """Pickle candidates, situations and the NB counts to *path*."""
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "candidates": self._candidates,
                "situations": self._situations,
                "nb": {
                    "cls_counts": dict(self.nb.cls_counts),
                    "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
                    "alpha": self.nb.alpha,
                    "beta": self.nb.beta,
                    "gamma": self.nb.gamma,
                    "V": self.nb.V,
                }
            }, f)

    def load(self, path: str):
        """Restore state pickled by save()."""
        with open(path, "rb") as f:
            obj = pickle.load(f)
        # Restore candidate texts.
        self._candidates = obj["candidates"]
        # Restore situations (key absent in old pickles).
        self._situations = obj.get("situations", {})
        # Restore the naive-Bayes state.
        # NOTE(review): cls_counts comes back as a plain dict (loses the
        # defaultdict(float) behavior); safe as long as every scored cid was
        # present at save time — confirm.
        self.nb.cls_counts = obj["nb"]["cls_counts"]
        self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
        self.nb.alpha = obj["nb"]["alpha"]
        self.nb.beta = obj["nb"]["beta"]
        self.nb.gamma = obj["nb"]["gamma"]
        self.nb.V = obj["nb"]["V"]
        # Cached normalizers are stale after a bulk restore.
        self.nb._logZ.clear()
|
||||
def defaultdict_dict(d: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """Rehydrate nested plain dicts into a defaultdict-of-defaultdict(float).

    Used when loading a pickled model so that token counters regain their
    auto-initializing (missing key -> 0.0) behavior.

    Args:
        d: mapping of candidate id -> {token -> count}

    Returns:
        A defaultdict whose missing outer keys yield a fresh defaultdict(float),
        pre-populated with the contents of *d*.
    """
    # `defaultdict` is already imported at module top; the previous local
    # `from collections import defaultdict` was redundant and is removed.
    outer = defaultdict(lambda: defaultdict(float))
    for key, inner in d.items():
        outer[key].update(inner)
    return outer
|
||||
60
src/express/expressor_model/online_nb.py
Normal file
60
src/express/expressor_model/online_nb.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import math
|
||||
from typing import Dict, List
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
class OnlineNaiveBayes:
    """Multinomial naive Bayes with incremental updates and exponential decay.

    Scores are unnormalized log-posteriors: log P(c) + sum_t tf_t * log P(t|c),
    with additive smoothing (alpha for tokens, beta for class priors).
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
        self.alpha = alpha  # per-token additive smoothing
        self.beta = beta  # class-prior additive smoothing
        self.gamma = gamma  # default decay factor (<1 gradually forgets)
        self.V = vocab_size  # assumed vocabulary size for the smoothing mass

        self.cls_counts: Dict[str, float] = defaultdict(float)  # cid -> total token count
        self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))  # cid -> term -> count
        self._logZ: Dict[str, float] = {}  # cache log(∑counts + Vα)

    def _invalidate(self, cid: str):
        # Drop the cached normalizer after cid's counts change.
        if cid in self._logZ:
            del self._logZ[cid]

    def _logZ_c(self, cid: str) -> float:
        # log of the smoothed likelihood denominator for class cid, cached.
        if cid not in self._logZ:
            Z = self.cls_counts[cid] + self.V * self.alpha
            self._logZ[cid] = math.log(max(Z, 1e-12))
        return self._logZ[cid]

    def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
        """Log-posterior (up to a shared constant) of each cid given query tf."""
        total_cls = sum(self.cls_counts.values())
        n_cls = max(1, len(self.cls_counts))
        # Shared denominator of the smoothed class prior.
        denom_prior = math.log(total_cls + self.beta * n_cls)

        out: Dict[str, float] = {}
        for cid in cids:
            # Smoothed log prior.
            prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
            s = prior
            logZ = self._logZ_c(cid)
            tc = self.token_counts[cid]
            for term, qtf in tf.items():
                # Smoothed log likelihood, weighted by the query term frequency.
                num = tc.get(term, 0.0) + self.alpha
                s += qtf * (math.log(num) - logZ)
            out[cid] = s
        return out

    def update_positive(self, tf: Counter, cid: str):
        """Add tf's counts to class cid (positive-feedback online update)."""
        inc = 0.0
        tc = self.token_counts[cid]
        for term, c in tf.items():
            tc[term] += float(c)
            inc += float(c)
        self.cls_counts[cid] += inc
        self._invalidate(cid)

    def decay(self, factor: float = None):
        """Multiply all counts by *factor* (default: gamma); no-op when >= 1."""
        g = self.gamma if factor is None else factor
        if g >= 1.0:
            return
        for cid in list(self.cls_counts.keys()):
            self.cls_counts[cid] *= g
            for term in list(self.token_counts[cid].keys()):
                self.token_counts[cid][term] *= g
            self._invalidate(cid)
|
||||
28
src/express/expressor_model/tokenizer.py
Normal file
28
src/express/expressor_model/tokenizer.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import re
|
||||
from typing import List, Optional, Set
|
||||
|
||||
try:
|
||||
import jieba
|
||||
_HAS_JIEBA = True
|
||||
except Exception:
|
||||
_HAS_JIEBA = False
|
||||
|
||||
# Runs of ASCII letters, digits and underscore count as one word.
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")


def simple_en_tokenize(text: str) -> List[str]:
    """Lowercase *text* and return its [A-Za-z0-9_] word tokens."""
    lowered = text.lower()
    return _WORD_RE.findall(lowered)
|
||||
|
||||
class Tokenizer:
    """Tokenizer wrapper: jieba segmentation when available, else ASCII words.

    Tokens are lowercased and filtered against the stopword set.
    """

    def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
        self.stopwords = stopwords or set()
        # jieba is used only when both requested and importable.
        self.use_jieba = use_jieba and _HAS_JIEBA

    def tokenize(self, text: str) -> List[str]:
        """Split *text* into lowercase tokens, dropping blanks and stopwords."""
        cleaned = (text or "").strip()
        if not cleaned:
            return []
        if self.use_jieba:
            segments = (piece.strip().lower() for piece in jieba.cut(cleaned))
            toks = [piece for piece in segments if piece]
        else:
            toks = simple_en_tokenize(cleaned)
        return [tok for tok in toks if tok not in self.stopwords]
|
||||
628
src/express/style_learner.py
Normal file
628
src/express/style_learner.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
多聊天室表达风格学习系统
|
||||
支持为每个chat_id维护独立的表达模型,学习从up_content到style的映射
|
||||
"""
|
||||
|
||||
import asyncio
import os
import pickle
import time
import traceback
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .expressor_model.model import ExpressorModel
|
||||
|
||||
logger = get_logger("style_learner")
|
||||
|
||||
|
||||
class StyleLearner:
|
||||
"""
|
||||
单个聊天室的表达风格学习器
|
||||
学习从up_content到style的映射关系
|
||||
支持动态管理风格集合(最多2000个)
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
    # Identity of the chat this learner belongs to.
    self.chat_id = chat_id
    self.model_config = model_config or {
        "alpha": 0.5,
        "beta": 0.5,
        "gamma": 0.99,  # decay factor: enables gradual forgetting
        "vocab_size": 200000,
        "use_jieba": True
    }

    # Underlying naive-Bayes ranker over style candidates.
    self.expressor = ExpressorModel(**self.model_config)

    # Dynamic style registry.
    self.max_styles = 2000  # hard cap on styles per chat_id
    self.style_to_id: Dict[str, str] = {}  # style text -> style_id
    self.id_to_style: Dict[str, str] = {}  # style_id -> style text
    self.id_to_situation: Dict[str, str] = {}  # style_id -> situation text
    self.next_style_id = 0  # next numeric suffix for generated style_ids

    # Learning statistics.
    self.learning_stats = {
        "total_samples": 0,
        "style_counts": defaultdict(int),
        "last_update": None,
        "style_usage_frequency": defaultdict(int)  # keyed by style text
    }
|
||||
|
||||
def add_style(self, style: str, situation: str = None) -> bool:
    """
    Dynamically register a new style.

    Args:
        style: style text
        situation: paired situation text (optional)

    Returns:
        bool: True when the style exists after the call (already present
        counts as success); False when the cap is hit or an error occurs.
    """
    try:
        # Already registered: treat as success.
        if style in self.style_to_id:
            logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
            return True

        # Enforce the per-chat style cap.
        if len(self.style_to_id) >= self.max_styles:
            logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
            return False

        # Mint a fresh style_id.
        style_id = f"style_{self.next_style_id}"
        self.next_style_id += 1

        # Record the mappings.
        self.style_to_id[style] = style_id
        self.id_to_style[style_id] = style
        if situation:
            self.id_to_situation[style_id] = situation

        # Register the candidate with the expressor model.
        self.expressor.add_candidate(style_id, style, situation)

        logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
                    (f", situation: '{situation}'" if situation else ""))
        return True

    except Exception as e:
        logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
        return False
|
||||
|
||||
def remove_style(self, style: str) -> bool:
    """
    Remove a registered style.

    Args:
        style: style text to remove

    Returns:
        bool: True when removed; False when unknown or on error.
    """
    try:
        if style not in self.style_to_id:
            logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
            return False

        style_id = self.style_to_id[style]

        # Drop all mappings for this style.
        del self.style_to_id[style]
        del self.id_to_style[style_id]
        if style_id in self.id_to_situation:
            del self.id_to_situation[style_id]

        # The expressor has no per-candidate delete; rebuild it instead.
        self._rebuild_expressor()

        logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
        return True

    except Exception as e:
        logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
        return False
|
||||
|
||||
def update_style(self, old_style: str, new_style: str) -> bool:
    """
    Rename a style, keeping its id, learned counts and situation.

    Args:
        old_style: current style text
        new_style: replacement style text

    Returns:
        bool: True when renamed; False when old is unknown, new collides
        with another style, or an error occurs.
    """
    try:
        if old_style not in self.style_to_id:
            logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
            return False

        if new_style in self.style_to_id and new_style != old_style:
            logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
            return False

        style_id = self.style_to_id[old_style]

        # Re-point the mappings at the new text.
        del self.style_to_id[old_style]
        self.style_to_id[new_style] = style_id
        self.id_to_style[style_id] = new_style

        # Update the expressor candidate (its situation is preserved).
        situation = self.id_to_situation.get(style_id)
        self.expressor.add_candidate(style_id, new_style, situation)

        logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
        return True

    except Exception as e:
        logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
        return False
|
||||
|
||||
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
    """Register many styles at once.

    Args:
        styles: style texts to add
        situations: optional parallel list of situation texts

    Returns:
        int: how many styles were added successfully
    """
    success_count = 0
    for idx, style_text in enumerate(styles):
        paired_situation = None
        if situations and idx < len(situations):
            paired_situation = situations[idx]
        if self.add_style(style_text, paired_situation):
            success_count += 1

    logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
    return success_count
|
||||
|
||||
def get_all_styles(self) -> List[str]:
    """Return every registered style text (insertion order)."""
    return [*self.style_to_id]
|
||||
|
||||
def get_style_count(self) -> int:
    """Number of styles currently registered for this chat."""
    return len(self.style_to_id)
|
||||
|
||||
def get_situation(self, style: str) -> Optional[str]:
    """
    Look up the situation paired with *style*.

    Args:
        style: style text

    Returns:
        Optional[str]: the situation text, or None when the style (or its
        situation) is not registered.
    """
    sid = self.style_to_id.get(style)
    if sid is None:
        return None
    return self.id_to_situation.get(sid)
|
||||
|
||||
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Full info for a style.

    Args:
        style: style text

    Returns:
        Tuple[Optional[str], Optional[str]]: (style_id, situation);
        (None, None) when the style is unknown.
    """
    sid = self.style_to_id.get(style)
    if sid is None:
        return None, None
    return sid, self.id_to_situation.get(sid)
|
||||
|
||||
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
    """
    Full info for every registered style.

    Returns:
        Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
    """
    return {
        style: (sid, self.id_to_situation.get(sid))
        for style, sid in self.style_to_id.items()
    }
|
||||
|
||||
def _rebuild_expressor(self):
    """Recreate the expressor from scratch (used after deleting a style).

    NOTE(review): rebuilding creates a fresh ExpressorModel, so the NB
    counts learned so far for the remaining styles are discarded —
    confirm this forgetting is intended on delete.
    """
    try:
        # Fresh model with the same configuration.
        self.expressor = ExpressorModel(**self.model_config)

        # Re-register every surviving style and its situation.
        for style_id, style_text in self.id_to_style.items():
            situation = self.id_to_situation.get(style_id)
            self.expressor.add_candidate(style_id, style_text, situation)

        logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")

    except Exception as e:
        logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
|
||||
|
||||
def learn_mapping(self, up_content: str, style: str) -> bool:
    """
    Learn one up_content -> style mapping.

    Auto-registers *style* when it is not yet known.

    Args:
        up_content: input content (the message being responded to)
        style: the style text that was used

    Returns:
        bool: True when the sample was learned successfully
    """
    try:
        # Register the style on first sight.
        if style not in self.style_to_id:
            if not self.add_style(style):
                logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
                return False

        style_id = self.style_to_id[style]

        # Positive-feedback update of the NB counts.
        self.expressor.update_positive(up_content, style_id)

        # Bookkeeping.
        self.learning_stats["total_samples"] += 1
        self.learning_stats["style_counts"][style_id] += 1
        self.learning_stats["style_usage_frequency"][style] += 1
        # Fix: asyncio.get_event_loop() is deprecated outside a running loop
        # (Python 3.10+) and may raise; a plain monotonic timestamp serves
        # the same "last update marker" purpose.
        self.learning_stats["last_update"] = time.monotonic()

        logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
        return True

    except Exception as e:
        logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
        traceback.print_exc()
        return False
|
||||
|
||||
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
    """
    Predict the best style for *up_content*.

    Args:
        up_content: input content
        top_k: requested number of candidates (forwarded to the expressor's
            predict(); see that method for how it is actually handled)

    Returns:
        Tuple[best style text or None, {style text: score}]
    """
    try:
        best_style_id, scores = self.expressor.predict(up_content, k=top_k)

        if best_style_id is None:
            return None, {}

        # Translate the winning style_id back to its text.
        best_style = self.id_to_style.get(best_style_id)

        # Translate every scored id to its text (unknown ids are skipped).
        style_scores = {}
        for sid, score in scores.items():
            style_text = self.id_to_style.get(sid)
            if style_text:
                style_scores[style_text] = score

        return best_style, style_scores

    except Exception as e:
        logger.error(f"[{self.chat_id}] 预测style失败: {e}")
        traceback.print_exc()
        return None, {}
|
||||
|
||||
def decay_learning(self, factor: Optional[float] = None) -> None:
    """
    Decay (forget) the learned knowledge.

    Args:
        factor: decay multiplier; None uses the configured gamma.
    """
    self.expressor.decay(factor)
    logger.debug(f"[{self.chat_id}] 执行知识衰减")
|
||||
|
||||
def get_stats(self) -> Dict:
    """Snapshot of this learner's statistics as a plain dict."""
    stats = self.learning_stats
    return {
        "chat_id": self.chat_id,
        "total_samples": stats["total_samples"],
        "style_count": len(self.style_to_id),
        "max_styles": self.max_styles,
        "style_counts": dict(stats["style_counts"]),
        "style_usage_frequency": dict(stats["style_usage_frequency"]),
        "last_update": stats["last_update"],
        "all_styles": list(self.style_to_id.keys())
    }
|
||||
|
||||
def save(self, base_path: str) -> bool:
    """
    Persist the learner to disk.

    Writes two pickles under *base_path*:
    {chat_id}_style_model.pkl (mappings + stats) and
    {chat_id}_expressor.pkl (the NB model).

    Args:
        base_path: target directory; created if missing.

    Returns:
        bool: True on success.
    """
    try:
        os.makedirs(base_path, exist_ok=True)
        file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")

        # Everything except the expressor itself.
        save_data = {
            "model_config": self.model_config,
            "style_to_id": self.style_to_id,
            "id_to_style": self.id_to_style,
            "id_to_situation": self.id_to_situation,
            "next_style_id": self.next_style_id,
            "max_styles": self.max_styles,
            "learning_stats": self.learning_stats
        }

        # The expressor serializes itself to a sibling file.
        expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
        self.expressor.save(expressor_path)

        # Then the bookkeeping data.
        with open(file_path, "wb") as f:
            pickle.dump(save_data, f)

        logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
        return True

    except Exception as e:
        logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
        return False
|
||||
|
||||
def load(self, base_path: str) -> bool:
    """Restore the model previously written by :meth:`save`.

    Both the style-model pickle and the expressor pickle must exist;
    otherwise the learner keeps its default configuration.

    Args:
        base_path: Directory the model files were saved into.

    Returns:
        bool: True on a successful load, False otherwise.
    """
    try:
        file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
        expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")

        if not os.path.exists(file_path) or not os.path.exists(expressor_path):
            logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
            return False

        # Read the bookkeeping payload.
        with open(file_path, "rb") as handle:
            payload = pickle.load(handle)

        # Restore configuration and state.
        self.model_config = payload["model_config"]
        self.style_to_id = payload["style_to_id"]
        self.id_to_style = payload["id_to_style"]
        # .get(): files from older versions have no situation map.
        self.id_to_situation = payload.get("id_to_situation", {})
        self.next_style_id = payload["next_style_id"]
        self.max_styles = payload.get("max_styles", 2000)
        self.learning_stats = payload["learning_stats"]

        # Rebuild the expressor from the restored config, then load it.
        self.expressor = ExpressorModel(**self.model_config)
        self.expressor.load(expressor_path)

        logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
        return True

    except Exception as e:
        logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
        return False
|
||||
|
||||
|
||||
class StyleLearnerManager:
    """Expression-style learning manager for multiple chat rooms.

    Maintains one independent :class:`StyleLearner` per ``chat_id``; each
    chat_id dynamically manages its own style set (up to 2000 styles).
    Also owns an optional asyncio auto-save loop that periodically
    persists every learner.
    """

    def __init__(self, model_save_path: str = "data/style_models"):
        # Directory where per-chat model files are persisted.
        self.model_save_path = model_save_path
        self.learners: Dict[str, StyleLearner] = {}

        # Auto-save configuration.
        self.auto_save_interval = 300  # 5 minutes
        self._auto_save_task: Optional[asyncio.Task] = None

        logger.info("StyleLearnerManager 已初始化")

    def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
        """Get, or lazily create, the learner for *chat_id*.

        Args:
            chat_id: Chat room id.
            model_config: Model configuration; None uses the default.

        Returns:
            The StyleLearner instance for this chat.
        """
        if chat_id not in self.learners:
            # Create a fresh learner, then try to restore a saved model.
            learner = StyleLearner(chat_id, model_config)
            learner.load(self.model_save_path)

            self.learners[chat_id] = learner
            logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")

        return self.learners[chat_id]

    def add_style(self, chat_id: str, style: str, situation: Optional[str] = None) -> bool:
        """Add a style for *chat_id*.

        Fix (backward-compatible): callers such as the prediction test
        script invoke ``add_style(chat_id, style, situation)``, and the
        underlying learner supports ``add_style(style, situation)``; the
        optional *situation* is now accepted and forwarded.

        Args:
            chat_id: Chat room id.
            style: Style text.
            situation: Optional situation text associated with the style.

        Returns:
            bool: True when the style was added.
        """
        learner = self.get_learner(chat_id)
        if situation is not None:
            return learner.add_style(style, situation)
        return learner.add_style(style)

    def remove_style(self, chat_id: str, style: str) -> bool:
        """Remove a style for *chat_id*.

        Args:
            chat_id: Chat room id.
            style: Style text.

        Returns:
            bool: True when the style was removed.
        """
        learner = self.get_learner(chat_id)
        return learner.remove_style(style)

    def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
        """Replace *old_style* with *new_style* for *chat_id*.

        Args:
            chat_id: Chat room id.
            old_style: Existing style text.
            new_style: Replacement style text.

        Returns:
            bool: True when the update succeeded.
        """
        learner = self.get_learner(chat_id)
        return learner.update_style(old_style, new_style)

    def get_chat_styles(self, chat_id: str) -> List[str]:
        """Return every style known for *chat_id*.

        Args:
            chat_id: Chat room id.

        Returns:
            List[str]: The chat's style list.
        """
        learner = self.get_learner(chat_id)
        return learner.get_all_styles()

    def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
        """Learn one ``up_content -> style`` mapping for *chat_id*.

        Args:
            chat_id: Chat room id.
            up_content: Input content.
            style: Target style.

        Returns:
            bool: True when learning succeeded.
        """
        learner = self.get_learner(chat_id)
        return learner.learn_mapping(up_content, style)

    def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
        """Predict the best style for *up_content* in *chat_id*.

        Args:
            chat_id: Chat room id.
            up_content: Input content.
            top_k: Number of candidates to score.

        Returns:
            Tuple of (best style or None, candidate -> score mapping).
        """
        learner = self.get_learner(chat_id)
        return learner.predict_style(up_content, top_k)

    def decay_all_learners(self, factor: Optional[float] = None) -> None:
        """Apply decay (forgetting) to every managed learner.

        Args:
            factor: Decay factor forwarded to each learner.
        """
        for learner in self.learners.values():
            learner.decay_learning(factor)
        logger.info("已对所有学习器执行衰减")

    def get_all_stats(self) -> Dict[str, Dict]:
        """Return statistics for every managed learner, keyed by chat_id."""
        return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}

    def save_all_models(self) -> bool:
        """Save every learner's model; True only when all saves succeed."""
        success_count = 0
        for learner in self.learners.values():
            if learner.save(self.model_save_path):
                success_count += 1

        logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
        return success_count == len(self.learners)

    def load_all_models(self) -> int:
        """Load every saved model found under the save path.

        Returns:
            int: Number of models successfully loaded.
        """
        if not os.path.exists(self.model_save_path):
            return 0

        loaded_count = 0
        for filename in os.listdir(self.model_save_path):
            if filename.endswith("_style_model.pkl"):
                chat_id = filename.replace("_style_model.pkl", "")
                learner = StyleLearner(chat_id)
                if learner.load(self.model_save_path):
                    self.learners[chat_id] = learner
                    loaded_count += 1

        logger.info(f"已加载 {loaded_count} 个模型")
        return loaded_count

    async def start_auto_save(self) -> None:
        """Start the background auto-save task (idempotent)."""
        if self._auto_save_task is None or self._auto_save_task.done():
            self._auto_save_task = asyncio.create_task(self._auto_save_loop())
            logger.info("已启动自动保存任务")

    async def stop_auto_save(self) -> None:
        """Cancel the auto-save task and wait for it to finish."""
        if self._auto_save_task and not self._auto_save_task.done():
            self._auto_save_task.cancel()
            try:
                await self._auto_save_task
            except asyncio.CancelledError:
                # Expected on cancellation; swallow it.
                pass
            logger.info("已停止自动保存任务")

    async def _auto_save_loop(self) -> None:
        """Periodically save all models until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.auto_save_interval)
                self.save_all_models()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"自动保存失败: {e}")
|
||||
|
||||
|
||||
# Global manager instance shared across the application.
style_learner_manager = StyleLearnerManager()
|
||||
@@ -45,9 +45,9 @@ private_plan_style = """
|
||||
3.某句话如果已经被回复过,不要重复回复"""
|
||||
|
||||
[expression]
|
||||
# 表达方式模式(此选项暂未使用)
|
||||
mode = "context"
|
||||
# 可选:llm模式,context上下文模式
|
||||
# 表达方式模式
|
||||
mode = "classic"
|
||||
# 可选:classic经典模式,exp_model 表达模型模式
|
||||
|
||||
# 表达学习配置
|
||||
learning_list = [ # 表达学习配置列表,支持按聊天流配置
|
||||
|
||||
152
test_expression_selector_prediction.py
Normal file
152
test_expression_selector_prediction.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
测试修改后的 expression_selector 使用模型预测功能
|
||||
验证不再随机选取,而是使用 style_learner 模型预测
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.express.expression_selector import ExpressionSelector
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_selector_test")
|
||||
|
||||
|
||||
async def test_model_prediction_selector():
    """Exercise the model-prediction-driven expression selector.

    Seeds the global style_learner_manager with styles/situations for a
    test chat, learns a few input->style mappings, then drives model
    prediction, LLM selection, the random fallback and prediction-input
    extraction on an ExpressionSelector instance. Output is printed;
    NOTE(review): this mutates global manager state and may write model
    files via the manager's save path.
    """
    print("=== Expression Selector 模型预测功能测试 ===\n")

    # Selector under test.
    selector = ExpressionSelector()

    # Chat id used for this test run.
    test_chat_id = "test_prediction_chat"

    print(f"测试聊天室: {test_chat_id}")

    # 1. Seed the test chat with styles and their situations.
    print(f"\n1. 准备测试数据...")

    test_data = [
        ("温柔回复", "打招呼"),
        ("幽默回复", "表达惊讶"),
        ("严肃回复", "询问问题"),
        ("活泼回复", "表达开心"),
        ("高冷回复", "表示不满"),
    ]

    for style, situation in test_data:
        success = style_learner_manager.add_style(test_chat_id, style, situation)
        print(f"  添加: '{style}' (situation: '{situation}') -> {'成功' if success else '失败'}")

    # 2. Learn some input -> style mappings.
    print(f"\n2. 学习映射关系...")

    learning_data = [
        ("你好", "温柔回复"),
        ("谢谢", "温柔回复"),
        ("哈哈", "幽默回复"),
        ("请解释", "严肃回复"),
        ("太棒了", "活泼回复"),
    ]

    for up_content, style in learning_data:
        success = style_learner_manager.learn_mapping(test_chat_id, up_content, style)
        print(f"  学习: '{up_content}' -> '{style}' -> {'成功' if success else '失败'}")

    # 3. Model-based prediction on a few chat scenarios.
    print(f"\n3. 测试模型预测功能...")

    test_chat_scenarios = [
        "用户: 你好\n机器人: 你好,有什么可以帮助你的吗?",
        "用户: 哈哈,太搞笑了\n机器人: 确实很有趣呢!",
        "用户: 请解释一下这个问题\n机器人: 好的,让我详细说明一下",
        "用户: 太棒了!\n机器人: 很高兴听到这个消息!",
    ]

    for i, chat_info in enumerate(test_chat_scenarios, 1):
        print(f"\n  场景 {i}:")
        print(f"  聊天内容: {chat_info}")

        # Predict expressions with the style model.
        predicted_expressions = selector.get_model_predicted_expressions(
            test_chat_id, chat_info, total_num=3
        )

        print(f"  预测结果: {len(predicted_expressions)} 个表达方式")
        for j, expr in enumerate(predicted_expressions, 1):
            print(f"    {j}. situation: '{expr['situation']}'")
            print(f"       style: '{expr['style']}'")
            print(f"       分数: {expr.get('prediction_score', 0.0):.4f}")
            print(f"       输入: '{expr.get('prediction_input', '')}'")

    # 4. LLM-based selection.
    print(f"\n4. 测试LLM选择功能...")

    # Simulated chat info for the LLM path.
    chat_info = "用户: 你好,我想了解一下这个功能\n机器人: 好的,我来为你详细介绍"

    try:
        selected_expressions, selected_ids = await selector.select_suitable_expressions_llm(
            test_chat_id, chat_info, max_num=3
        )

        print(f"  LLM选择结果: {len(selected_expressions)} 个表达方式")
        for i, expr in enumerate(selected_expressions, 1):
            print(f"    {i}. situation: '{expr['situation']}'")
            print(f"       style: '{expr['style']}'")
            print(f"       来源: {expr['source_id']}")

    except Exception as e:
        # LLM access may be unavailable in test environments.
        print(f"  LLM选择失败: {e}")

    # 5. Fallback path: a non-existent chat id triggers random selection.
    print(f"\n5. 测试回退机制...")

    fake_chat_id = "fake_chat_id"
    fallback_expressions = selector._fallback_random_expressions(fake_chat_id, 3)
    print(f"  回退机制测试: {len(fallback_expressions)} 个表达方式")

    # 6. Prediction-input extraction on varied chat texts (incl. empty).
    print(f"\n6. 测试预测输入提取...")

    test_chat_infos = [
        "用户: 你好\n机器人: 你好!",
        "这是一段很长的聊天内容,包含了很多信息,用户说了很多话,机器人也回复了很多内容,现在我们要测试提取功能",
        "单行内容",
        "",
    ]

    for i, chat_info in enumerate(test_chat_infos, 1):
        prediction_input = selector._extract_prediction_input(chat_info)
        print(f"  测试 {i}:")
        print(f"    原始: '{chat_info}'")
        print(f"    提取: '{prediction_input}'")

    print(f"\n✅ 所有测试完成!")
    print(f"\n=== 功能总结 ===")
    print(f"✓ Expression Selector 现在使用 style_learner 模型进行预测")
    print(f"✓ 不再随机选择,而是基于聊天内容预测最合适的 style")
    print(f"✓ 自动获取预测 style 对应的 situation")
    print(f"✓ 支持多聊天室的预测")
    print(f"✓ 包含回退机制,预测失败时使用随机选择")
    print(f"✓ 支持预测输入提取和优化")
||||
|
||||
|
||||
def main():
    """Entry point: print a banner and run the async test suite."""
    print("Expression Selector 模型预测功能测试")
    print("=" * 60)

    # Drive the async test to completion on a fresh event loop.
    asyncio.run(test_model_prediction_selector())


if __name__ == "__main__":
    main()
|
||||
188
test_expression_style_situation_integration.py
Normal file
188
test_expression_style_situation_integration.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
测试修改后的 expression_learner 与 style_learner 的集成
|
||||
验证学习新表达时是否正确处理 situation 字段
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.express.expression_learner import ExpressionLearner
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_style_integration_test")
|
||||
|
||||
|
||||
async def test_expression_style_integration():
    """Integration test: expression_learner feeding style_learner, with situations.

    Simulates learnt expressions (chat_id, situation, style, context,
    up_content), replays the learn_and_store grouping/training logic
    against the global style_learner_manager, then checks prediction,
    statistics, and save/load round-trips. Output is printed.
    """
    print("=== Expression Learner 与 Style Learner 集成测试(含 Situation) ===\n")

    # Chat id used for this test run.
    test_chat_id = "test_integration_situation_chat"

    # ExpressionLearner instance (constructed but not otherwise driven here).
    expression_learner = ExpressionLearner(test_chat_id)

    print(f"测试聊天室: {test_chat_id}")

    # Simulated learnt expressions, including a situation per entry.
    mock_learnt_expressions = [
        (test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"),
        (test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"),
        (test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"),
        (test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"),
        (test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"),
    ]

    print("模拟学习到的表达数据(包含 situation):")
    for chat_id, situation, style, context, up_content in mock_learnt_expressions:
        print(f"  {situation} -> {style} (输入: {up_content})")

    # Replay the processing logic of learn_and_store.
    print(f"\n开始处理学习数据...")

    # Group entries by chat_id.
    chat_dict = {}
    for chat_id, situation, style, context, up_content in mock_learnt_expressions:
        if chat_id not in chat_dict:
            chat_dict[chat_id] = []
        chat_dict[chat_id].append({
            "situation": situation,
            "style": style,
            "context": context,
            "up_content": up_content,
        })

    # Train the style_learner (including situation handling).
    trained_chat_ids = set()

    for chat_id, expr_list in chat_dict.items():
        print(f"\n处理聊天室: {chat_id}")

        for new_expr in expr_list:
            # Only entries with both an input and a style are trainable.
            if new_expr.get("up_content") and new_expr.get("style"):
                try:
                    # Fetch (or lazily create) the learner for this chat.
                    learner = style_learner_manager.get_learner(chat_id)

                    # Register the style, attaching the situation when present.
                    if new_expr.get("situation"):
                        learner.add_style(new_expr["style"], new_expr["situation"])
                        print(f"  ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')")
                    else:
                        learner.add_style(new_expr["style"])
                        print(f"  ✓ 添加风格: '{new_expr['style']}' (无 situation)")

                    # Learn the input -> style mapping.
                    success = style_learner_manager.learn_mapping(
                        chat_id,
                        new_expr["up_content"],
                        new_expr["style"]
                    )
                    if success:
                        print(f"  ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" +
                              (f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
                        trained_chat_ids.add(chat_id)
                    else:
                        print(f"  ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}")
                except Exception as e:
                    print(f"  ✗ StyleLearner学习异常: {e}")

    # Persist the models that were trained.
    if trained_chat_ids:
        print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
        try:
            save_success = style_learner_manager.save_all_models()

            if save_success:
                print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
            else:
                print("✗ StyleLearner 模型保存失败")

        except Exception as e:
            print(f"✗ StyleLearner 模型保存异常: {e}")

    # Verify prediction and the situation lookup for each trained input.
    print(f"\n测试 StyleLearner 预测功能:")
    test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"]

    for test_input in test_inputs:
        try:
            best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3)
            if best_style:
                # Fetch the situation associated with the predicted style.
                learner = style_learner_manager.get_learner(test_chat_id)
                situation = learner.get_situation(best_style)
                print(f"  输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')")
                if scores:
                    top_scores = dict(list(scores.items())[:3])
                    print(f"    分数: {top_scores}")
            else:
                print(f"  输入: '{test_input}' -> 无预测结果")
        except Exception as e:
            print(f"  输入: '{test_input}' -> 预测异常: {e}")

    # Statistics for the test chat.
    print(f"\nStyleLearner 统计信息:")
    try:
        stats = style_learner_manager.get_all_stats()
        if test_chat_id in stats:
            chat_stats = stats[test_chat_id]
            print(f"  聊天室: {test_chat_id}")
            print(f"    总样本数: {chat_stats['total_samples']}")
            print(f"    当前风格数: {chat_stats['style_count']}")
            print(f"    最大风格数: {chat_stats['max_styles']}")
            print(f"    风格列表: {chat_stats['all_styles']}")

            # Show the situation attached to each style.
            # NOTE(review): `learner` here is whatever instance was last
            # bound in the loops above; it works because the test uses a
            # single chat_id, but would be wrong with multiple chats —
            # confirm before reusing this pattern.
            print(f"  风格和 situation 信息:")
            for style in chat_stats['all_styles']:
                situation = learner.get_situation(style)
                print(f"    '{style}' -> situation: '{situation}'")
        else:
            print(f"  未找到聊天室 {test_chat_id} 的统计信息")
    except Exception as e:
        print(f"  获取统计信息异常: {e}")

    # Save/load round-trip for the situation data.
    print(f"\n测试模型保存和加载...")
    try:
        # Reuse the same (global) manager and re-fetch the learner.
        new_manager = style_learner_manager  # 使用同一个管理器
        new_learner = new_manager.get_learner(test_chat_id)

        # Verify situation info survived persistence.
        loaded_style_info = new_learner.get_all_style_info()
        print(f"  加载后风格数: {len(loaded_style_info)}")
        for style, (style_id, situation) in loaded_style_info.items():
            print(f"    加载验证: '{style}' -> situation: '{situation}'")

        print("✓ 模型保存和加载测试通过")
    except Exception as e:
        print(f"✗ 模型保存和加载测试失败: {e}")

    print(f"\n=== 集成测试完成 ===")
    print(f"✅ 所有功能测试通过!")
    print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner")
    print(f"✓ StyleLearner 正确存储和获取 situation 信息")
    print(f"✓ 预测功能正常工作,可以获取对应的 situation")
    print(f"✓ 模型保存和加载支持 situation 字段")
||||
|
||||
|
||||
def main():
    """Entry point: print a banner and run the integration test."""
    print("Expression Learner 与 Style Learner 集成测试(含 Situation)")
    print("=" * 70)

    # Run the async integration test to completion.
    asyncio.run(test_expression_style_integration())


if __name__ == "__main__":
    main()
|
||||
391
test_style_learner_db.py
Normal file
391
test_style_learner_db.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""
|
||||
StyleLearner 数据库测试脚本
|
||||
使用数据库中的expression数据测试style_learner功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Dict, Tuple
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.common.database.database_model import Expression, db
|
||||
from src.express.style_learner import StyleLearnerManager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("style_learner_test")
|
||||
|
||||
|
||||
class StyleLearnerDatabaseTest:
    """Test harness that trains/evaluates StyleLearner on Expression rows
    loaded from the project database, then reports and persists metrics."""

    def __init__(self, random_state: int = 42):
        # Seed for the train/test split, so runs are reproducible.
        self.random_state = random_state
        # Dedicated save path so test models don't clobber production ones.
        self.manager = StyleLearnerManager(model_save_path="data/test_style_models")

        # Aggregated results of one test run.
        self.test_results = {
            "total_samples": 0,
            "train_samples": 0,
            "test_samples": 0,
            "unique_styles": 0,
            "unique_chat_ids": 0,
            "accuracy": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "predictions": [],
            "ground_truth": [],
            "model_save_success": False,
            "model_save_path": self.manager.model_save_path
        }

    def load_data_from_database(self) -> List[Dict]:
        """
        Load expression rows from the database.

        Returns:
            List[Dict]: Records containing up_content, style, chat_id
            (plus last_active_time/context/situation); empty on error.
        """
        try:
            # Connect, reusing an already-open connection when possible.
            db.connect(reuse_if_open=True)

            # Only complete rows of type "style" are usable for training.
            expressions = Expression.select().where(
                (Expression.up_content.is_null(False)) &
                (Expression.style.is_null(False)) &
                (Expression.chat_id.is_null(False)) &
                (Expression.type == "style")
            )

            data = []
            for expr in expressions:
                # Guard against empty strings as well as SQL NULLs.
                if expr.up_content and expr.style and expr.chat_id:
                    data.append({
                        "up_content": expr.up_content,
                        "style": expr.style,
                        "chat_id": expr.chat_id,
                        "last_active_time": expr.last_active_time,
                        "context": expr.context,
                        "situation": expr.situation
                    })

            logger.info(f"从数据库加载了 {len(data)} 条expression数据")
            return data

        except Exception as e:
            logger.error(f"从数据库加载数据失败: {e}")
            return []

    def preprocess_data(self, data: List[Dict]) -> List[Dict]:
        """
        Preprocess raw rows: trim whitespace and drop entries whose
        input or style is shorter than 2 characters.

        Args:
            data: Raw records from the database.

        Returns:
            List[Dict]: Filtered records.
        """
        filtered_data = []
        for item in data:
            up_content = item["up_content"].strip()
            style = item["style"].strip()

            if len(up_content) >= 2 and len(style) >= 2:
                filtered_data.append({
                    "up_content": up_content,
                    "style": style,
                    "chat_id": item["chat_id"],
                    "last_active_time": item["last_active_time"],
                    "context": item["context"],
                    "situation": item["situation"]
                })

        logger.info(f"预处理后剩余 {len(filtered_data)} 条数据")
        return filtered_data

    def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """
        Split into train and test sets.

        The training set is ALL the data; the test set is a random 5%
        drawn from it. NOTE(review): the test set therefore overlaps the
        training set — metrics measure memorization, not generalization.
        This appears deliberate (see the log message below), but confirm.

        Args:
            data: Preprocessed records.

        Returns:
            Tuple[List[Dict], List[Dict]]: (train set, test set)
        """
        # Training set uses every record.
        train_data = data.copy()

        # Test set: a random 5% sampled from the training data.
        test_size = 0.05  # 5%
        test_data = train_test_split(
            train_data, test_size=test_size, random_state=self.random_state
        )[1]  # keep only the test portion of the split

        logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条")
        logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%")
        return train_data, test_data

    def train_model(self, train_data: List[Dict]) -> None:
        """
        Train the per-chat learners on every (up_content, style) pair,
        record dataset statistics, then persist all models.

        Args:
            train_data: Training records.
        """
        logger.info("开始训练模型...")

        # Collect dataset statistics while training.
        chat_ids = set()
        styles = set()

        for item in train_data:
            chat_id = item["chat_id"]
            up_content = item["up_content"]
            style = item["style"]

            chat_ids.add(chat_id)
            styles.add(style)

            # Learn the mapping for this chat.
            success = self.manager.learn_mapping(chat_id, up_content, style)
            if not success:
                logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}")

        self.test_results["train_samples"] = len(train_data)
        self.test_results["unique_styles"] = len(styles)
        self.test_results["unique_chat_ids"] = len(chat_ids)

        logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室")

        # Persist the trained models.
        logger.info("开始保存训练好的模型...")
        save_success = self.manager.save_all_models()
        self.test_results["model_save_success"] = save_success

        if save_success:
            logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}")
            print(f"✅ 模型已保存到: {self.manager.model_save_path}")
        else:
            logger.warning("部分模型保存失败")
            print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}")

    def test_model(self, test_data: List[Dict]) -> None:
        """
        Evaluate prediction accuracy/precision/recall/F1 on the test set
        and record per-sample predictions.

        Args:
            test_data: Test records.
        """
        logger.info("开始测试模型...")

        predictions = []
        ground_truth = []
        correct_predictions = 0

        for item in test_data:
            chat_id = item["chat_id"]
            up_content = item["up_content"]
            true_style = item["style"]

            # Predict the style for this input.
            predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1)

            predictions.append(predicted_style)
            ground_truth.append(true_style)

            # Count exact matches.
            if predicted_style == true_style:
                correct_predictions += 1

            # Keep the full prediction record for later analysis.
            self.test_results["predictions"].append({
                "chat_id": chat_id,
                "up_content": up_content,
                "true_style": true_style,
                "predicted_style": predicted_style,
                "scores": scores
            })

        # Accuracy over all test samples (0 on an empty test set).
        accuracy = correct_predictions / len(test_data) if test_data else 0

        # sklearn metrics require None predictions to be filtered out.
        valid_predictions = [p for p in predictions if p is not None]
        valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None]

        if valid_predictions:
            precision, recall, f1, _ = precision_recall_fscore_support(
                valid_ground_truth, valid_predictions, average='weighted', zero_division=0
            )
        else:
            precision = recall = f1 = 0.0

        self.test_results["test_samples"] = len(test_data)
        self.test_results["accuracy"] = accuracy
        self.test_results["precision"] = precision
        self.test_results["recall"] = recall
        self.test_results["f1_score"] = f1

        logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}")

    def analyze_results(self) -> None:
        """Print dataset stats, overall metrics, per-chat accuracy and the
        ground-truth style distribution."""
        logger.info("=== 测试结果分析 ===")

        print("\n📊 数据统计:")
        print(f"  总样本数: {self.test_results['total_samples']}")
        print(f"  训练样本数: {self.test_results['train_samples']}")
        print(f"  测试样本数: {self.test_results['test_samples']}")
        print(f"  唯一风格数: {self.test_results['unique_styles']}")
        print(f"  唯一聊天室数: {self.test_results['unique_chat_ids']}")

        print("\n🎯 模型性能:")
        print(f"  准确率: {self.test_results['accuracy']:.4f}")
        print(f"  精确率: {self.test_results['precision']:.4f}")
        print(f"  召回率: {self.test_results['recall']:.4f}")
        print(f"  F1分数: {self.test_results['f1_score']:.4f}")

        print("\n💾 模型保存:")
        save_status = "成功" if self.test_results['model_save_success'] else "失败"
        print(f"  保存状态: {save_status}")
        print(f"  保存路径: {self.test_results['model_save_path']}")

        # Per-chat accuracy.
        chat_performance = {}
        for pred in self.test_results["predictions"]:
            chat_id = pred["chat_id"]
            if chat_id not in chat_performance:
                chat_performance[chat_id] = {"correct": 0, "total": 0}

            chat_performance[chat_id]["total"] += 1
            if pred["predicted_style"] == pred["true_style"]:
                chat_performance[chat_id]["correct"] += 1

        print("\n📈 各聊天室性能:")
        for chat_id, perf in chat_performance.items():
            accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0
            print(f"  {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})")

        # Ground-truth style distribution.
        style_counts = {}
        for pred in self.test_results["predictions"]:
            style = pred["true_style"]
            style_counts[style] = style_counts.get(style, 0) + 1

        print("\n🎨 风格分布 (前10个):")
        sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True)
        for style, count in sorted_styles[:10]:
            print(f"  {style}: {count} 次")

    def show_sample_predictions(self, num_samples: int = 10) -> None:
        """Print the first *num_samples* per-sample prediction records."""
        print(f"\n🔍 样本预测结果 (前{num_samples}个):")

        for i, pred in enumerate(self.test_results["predictions"][:num_samples]):
            status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
            print(f"\n  {i+1}. {status}")
            print(f"     聊天室: {pred['chat_id']}")
            print(f"     输入内容: {pred['up_content']}")
            print(f"     真实风格: {pred['true_style']}")
            print(f"     预测风格: {pred['predicted_style']}")
            if pred["scores"]:
                top_scores = dict(list(pred["scores"].items())[:3])
                print(f"     分数: {top_scores}")

    def save_results(self, output_file: str = "style_learner_test_results.txt") -> None:
        """Write the full test report (stats, metrics, per-sample lines)
        to *output_file* as UTF-8 text."""
        try:
            with open(output_file, "w", encoding="utf-8") as f:
                f.write("StyleLearner 数据库测试结果\n")
                f.write("=" * 50 + "\n\n")

                f.write("数据统计:\n")
                f.write(f"  总样本数: {self.test_results['total_samples']}\n")
                f.write(f"  训练样本数: {self.test_results['train_samples']}\n")
                f.write(f"  测试样本数: {self.test_results['test_samples']}\n")
                f.write(f"  唯一风格数: {self.test_results['unique_styles']}\n")
                f.write(f"  唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n")

                f.write("模型性能:\n")
                f.write(f"  准确率: {self.test_results['accuracy']:.4f}\n")
                f.write(f"  精确率: {self.test_results['precision']:.4f}\n")
                f.write(f"  召回率: {self.test_results['recall']:.4f}\n")
                f.write(f"  F1分数: {self.test_results['f1_score']:.4f}\n\n")

                f.write("模型保存:\n")
                save_status = "成功" if self.test_results['model_save_success'] else "失败"
                f.write(f"  保存状态: {save_status}\n")
                f.write(f"  保存路径: {self.test_results['model_save_path']}\n\n")

                f.write("详细预测结果:\n")
                for i, pred in enumerate(self.test_results["predictions"]):
                    status = "✓" if pred["predicted_style"] == pred["true_style"] else "✗"
                    f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n")

            logger.info(f"测试结果已保存到 {output_file}")

        except Exception as e:
            logger.error(f"保存测试结果失败: {e}")

    def run_test(self) -> None:
        """Run the full pipeline: load, preprocess, split, train, test,
        analyze, sample, persist."""
        logger.info("开始StyleLearner数据库测试...")

        # 1. Load data.
        raw_data = self.load_data_from_database()
        if not raw_data:
            logger.error("没有加载到数据,测试终止")
            return

        # 2. Preprocess.
        processed_data = self.preprocess_data(raw_data)
        if not processed_data:
            logger.error("预处理后没有数据,测试终止")
            return

        self.test_results["total_samples"] = len(processed_data)

        # 3. Split.
        train_data, test_data = self.split_data(processed_data)

        # 4. Train.
        self.train_model(train_data)

        # 5. Evaluate.
        self.test_model(test_data)

        # 6. Analyze.
        self.analyze_results()

        # 7. Show sample predictions.
        self.show_sample_predictions(10)

        # 8. Persist the report.
        self.save_results()

        logger.info("StyleLearner数据库测试完成!")
|
||||
|
||||
|
||||
def main():
    """Entry point: build the DB-backed test harness and run it."""
    print("StyleLearner 数据库测试脚本")
    print("=" * 50)

    # Fixed seed keeps the train/test split reproducible across runs.
    runner = StyleLearnerDatabaseTest(random_state=42)
    runner.run_test()


if __name__ == "__main__":
    main()
|
||||
76
view_pkl.py
Normal file
76
view_pkl.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
查看 .pkl 文件内容的工具脚本
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
def view_pkl_file(file_path):
    """Pretty-print the contents of a pickle file for manual inspection.

    Dicts get a key/type listing plus a pprint dump; lists get a length
    summary and the first three elements; anything else is pprinted as-is.
    Expressor-model dicts additionally get a per-style token_counts report.
    Errors are reported on stdout rather than raised (CLI best-effort tool).
    """
    if not os.path.exists(file_path):
        print(f"❌ 文件不存在: {file_path}")
        return

    try:
        # NOTE(review): pickle.load on arbitrary files executes code —
        # acceptable only for trusted, locally generated model files.
        with open(file_path, 'rb') as fh:
            payload = pickle.load(fh)

        print(f"📁 文件: {file_path}")
        print(f"📊 数据类型: {type(payload)}")
        print("=" * 50)

        if isinstance(payload, dict):
            print("🔑 字典键:")
            for name in payload:
                print(f"  - {name}: {type(payload[name])}")
            print()

            print("📋 详细内容:")
            pprint(payload, width=120, depth=10)
        elif isinstance(payload, list):
            print(f"📝 列表长度: {len(payload)}")
            if payload:
                print(f"📊 第一个元素类型: {type(payload[0])}")
                print("📋 前几个元素:")
                for idx, entry in enumerate(payload[:3]):
                    print(f"  [{idx}]: {entry}")
        else:
            print("📋 内容:")
            pprint(payload, width=120, depth=10)

        # Expressor models carry naive-Bayes word counts; show them per style.
        if isinstance(payload, dict) and 'nb' in payload and 'token_counts' in payload['nb']:
            print("\n" + "="*50)
            print("🔍 详细词汇统计 (token_counts):")
            token_counts = payload['nb']['token_counts']
            for style_id, counts in token_counts.items():
                print(f"\n📝 {style_id}:")
                if not counts:
                    print("  (无词汇数据)")
                    continue
                # Descending by frequency; only the top 10 words are shown.
                ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
                for word, freq in ranked[:10]:
                    print(f"  '{word}': {freq}")
                if len(ranked) > 10:
                    print(f"  ... 还有 {len(ranked) - 10} 个词")
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
|
||||
|
||||
def main():
    """CLI entry: expect exactly one argument, the pkl file path."""
    if len(sys.argv) != 2:
        print("用法: python view_pkl.py <pkl文件路径>")
        print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
        return

    view_pkl_file(sys.argv[1])
|
||||
|
||||
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
63
view_tokens.py
Normal file
63
view_tokens.py
Normal file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
专门查看 expressor.pkl 文件中 token_counts 的脚本
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import sys
|
||||
import os
|
||||
|
||||
def view_token_counts(file_path):
    """Print the per-style vocabulary statistics stored in an expressor.pkl.

    Validates the file looks like an expressor model (has nb.token_counts),
    then lists every style with its candidate text (or "未知风格" when the
    style id has no candidate entry) and its words sorted by frequency.
    Errors are reported on stdout rather than raised (CLI best-effort tool).
    """
    if not os.path.exists(file_path):
        print(f"❌ 文件不存在: {file_path}")
        return

    try:
        # NOTE(review): pickle.load on arbitrary files executes code —
        # acceptable only for trusted, locally generated model files.
        with open(file_path, 'rb') as fh:
            model = pickle.load(fh)

        print(f"📁 文件: {file_path}")
        print("=" * 60)

        if 'nb' not in model or 'token_counts' not in model['nb']:
            print("❌ 这不是一个 expressor 模型文件")
            return

        per_style = model['nb']['token_counts']
        style_texts = model.get('candidates', {})

        print(f"🎯 找到 {len(per_style)} 个风格")
        print("=" * 60)

        for style_id, counts in per_style.items():
            label = style_texts.get(style_id, "未知风格")
            print(f"\n📝 {style_id}: {label}")
            print(f"📊 词汇数量: {len(counts)}")

            if counts:
                # Descending by frequency; unlike view_pkl.py, all words print.
                ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
                print("🔤 词汇统计 (按频率排序):")
                for rank, (word, freq) in enumerate(ranked):
                    print(f"  {rank+1:2d}. '{word}': {freq}")
            else:
                print("  (无词汇数据)")

            print("-" * 40)
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
|
||||
|
||||
def main():
    """CLI entry: expect exactly one argument, the expressor.pkl path."""
    if len(sys.argv) != 2:
        print("用法: python view_tokens.py <expressor.pkl文件路径>")
        print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
        return

    view_token_counts(sys.argv[1])
|
||||
|
||||
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user