feat:表达方式更新,现在会训练朴素贝叶斯模型来预测使用什么表达

This commit is contained in:
SengokuCola
2025-10-11 02:03:03 +08:00
parent 400296ade1
commit 958d6e04ee
20 changed files with 2372 additions and 443 deletions

2
.gitignore vendored
View File

@@ -323,7 +323,7 @@ run_pet.bat
!/plugins/emoji_manage_plugin
!/plugins/take_picture_plugin
!/plugins/deep_think
!/plugins/MaiFrequencyControl
!/plugins/BetterFrequency
!/plugins/__init__.py
config.toml

View File

@@ -16,7 +16,7 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager

View File

@@ -1,316 +0,0 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("expression_selector")
def init_prompt():
    """Register the expression-evaluation prompt template with the global prompt manager.

    The template is user-facing Chinese; its placeholders are filled in by
    select_suitable_expressions_llm. Called once at module import time.
    """
    expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
    """Draw up to k distinct items from population, weighted by the given weights.

    Sampling is without replacement: each chosen item (and its weight) is removed
    before the next draw. If population has k items or fewer, a shallow copy of
    the whole population is returned.

    Args:
        population: candidate items.
        weights: per-item sampling weight, aligned with population.
        k: number of items to draw.

    Returns:
        At most k sampled items (the dicts themselves are not copied).
    """
    if not population or not weights or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    if len(weights) != len(population):
        # Misaligned inputs would make random.choices raise ValueError; treat as no-op.
        return []
    selected = []
    population_copy = population.copy()
    weights_copy = weights.copy()
    for _ in range(k):
        if not population_copy:
            break
        if sum(w for w in weights_copy if w > 0) <= 0:
            # All remaining weights non-positive: random.choices would raise,
            # so fall back to a uniform draw instead.
            chosen_idx = random.randrange(len(population_copy))
        else:
            chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
        selected.append(population_copy.pop(chosen_idx))
        weights_copy.pop(chosen_idx)
    return selected
class ExpressionSelector:
    """Selects learned expressions ("situation -> style" pairs) suitable for a chat."""

    def __init__(self):
        # Small utility model used to rank candidate situations via LLM.
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
        """Return whether expression usage is enabled for the given chat stream.

        Args:
            chat_id: chat stream id.

        Returns:
            bool: True if allowed; False when disabled or config lookup fails.
        """
        try:
            use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
            return use_expression
        except Exception as e:
            logger.error(f"检查表达使用权限失败: {e}")
            return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
    """Resolve all chat_ids that share expressions with chat_id (itself included), per expression_groups config."""
    groups = global_config.expression.expression_groups
    # A group containing "*" means every configured stream shares expressions globally.
    global_group_exists = any("*" in group for group in groups)
    if global_group_exists:
        # Global sharing: return every parseable chat_id across all groups.
        all_chat_ids = set()
        for group in groups:
            for stream_config_str in group:
                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                    all_chat_ids.add(chat_id_candidate)
        return list(all_chat_ids) if all_chat_ids else [chat_id]
    # Otherwise, return the first group that contains this chat_id.
    for group in groups:
        group_chat_ids = []
        for stream_config_str in group:
            if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                group_chat_ids.append(chat_id_candidate)
        if chat_id in group_chat_ids:
            return group_chat_ids
    # Not in any group: the chat only shares expressions with itself.
    return [chat_id]
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
    # sourcery skip: extract-duplicate-method, move-assign
    """Draw up to total_num "style" expressions, count-weighted, across all related chats."""
    # Merge candidates from every chat sharing an expression group with this one.
    related_chat_ids = self.get_related_chat_ids(chat_id)
    # Single query over all related chat_ids to avoid N round-trips.
    style_query = Expression.select().where(
        (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
    )
    style_exprs = [
        {
            "id": expr.id,
            "situation": expr.situation,
            "style": expr.style,
            "count": expr.count,
            "last_active_time": expr.last_active_time,
            "source_id": expr.chat_id,
            "type": "style",
            # Older rows may predate create_date; fall back to last_active_time.
            "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
        }
        for expr in style_query
    ]
    # Weight by usage count so frequently-activated expressions surface more often.
    if style_exprs:
        style_weights = [expr.get("count", 1) for expr in style_exprs]
        selected_style = weighted_sample(style_exprs, style_weights, total_num)
    else:
        selected_style = []
    return selected_style
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
    """Bump the count of each given expression (deduplicated) and refresh its last_active_time in the DB.

    count is capped at 5.0 to prevent runaway reinforcement.
    """
    if not expressions_to_update:
        return
    # Deduplicate by (chat_id, type, situation, style) so each row is written once.
    updates_by_key = {}
    for expr in expressions_to_update:
        source_id: str = expr.get("source_id")  # type: ignore
        expr_type: str = expr.get("type", "style")
        situation: str = expr.get("situation")  # type: ignore
        style: str = expr.get("style")  # type: ignore
        if not source_id or not situation or not style:
            logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
            continue
        key = (source_id, expr_type, situation, style)
        if key not in updates_by_key:
            updates_by_key[key] = expr
    for chat_id, expr_type, situation, style in updates_by_key:
        query = Expression.select().where(
            (Expression.chat_id == chat_id)
            & (Expression.type == expr_type)
            & (Expression.situation == situation)
            & (Expression.style == style)
        )
        if query.exists():
            expr_obj = query.get()
            current_count = expr_obj.count
            new_count = min(current_count + increment, 5.0)  # hard cap
            expr_obj.count = new_count
            expr_obj.last_active_time = time.time()
            expr_obj.save()
            logger.debug(
                f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
            )
async def select_suitable_expressions_llm(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    # sourcery skip: inline-variable, list-comprehension
    """Pick up to max_num suitable expressions for the chat via LLM ranking.

    Args:
        chat_id: chat stream id.
        chat_info: current chat transcript shown to the LLM.
        max_num: upper bound on selected expressions.
        target_message: optional message being replied to, added to the prompt.

    Returns:
        (expressions, ids); both empty on failure or while the pool is too small.
    """
    # Bail out early if expressions are disabled for this stream.
    if not self.can_use_expression_for_chat(chat_id):
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return [], []
    # 1. Draw 20 count-weighted candidates.
    style_exprs = self.get_random_expressions(chat_id, 20)
    if len(style_exprs) < 10:
        logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
        return [], []
    # 2. Build a 1-based situation index to show the LLM.
    all_expressions: List[Dict[str, Any]] = []
    all_situations: List[str] = []
    for expr in style_exprs:
        expr = expr.copy()
        all_expressions.append(expr)
        all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
    if not all_expressions:
        logger.warning("没有找到可用的表达方式")
        return [], []
    all_situations_str = "\n".join(all_situations)
    if target_message:
        target_message_str = f",现在你想要回复消息:{target_message}"
        target_message_extra_block = "4.考虑你要回复的目标消息"
    else:
        target_message_str = ""
        target_message_extra_block = ""
    # 3. The prompt carries situations only (not full expressions) to keep it short.
    prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
        bot_name=global_config.bot.nickname,
        chat_observe_info=chat_info,
        all_situations=all_situations_str,
        max_num=max_num,
        target_message=target_message_str,
        target_message_extra_block=target_message_extra_block,
    )
    # 4. LLM call.
    try:
        content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
        if not content:
            logger.warning("LLM返回空结果")
            return [], []
        # 5. Parse the (possibly malformed) JSON answer.
        result = repair_json(content)
        if isinstance(result, str):
            result = json.loads(result)  # repair_json may return a JSON string
        if not isinstance(result, dict) or "selected_situations" not in result:
            logger.error("LLM返回格式错误")
            logger.info(f"LLM返回结果: \n{content}")
            return [], []
        selected_indices = result["selected_situations"]
        # Map the 1-based indices back to the full expressions, dropping bad values.
        valid_expressions: List[Dict[str, Any]] = []
        selected_ids = []
        for idx in selected_indices:
            if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                expression = all_expressions[idx - 1]  # indices are 1-based
                selected_ids.append(expression["id"])
                valid_expressions.append(expression)
        # Reinforce every chosen expression in one batch write.
        if valid_expressions:
            self.update_expressions_count_batch(valid_expressions, 0.006)
        return valid_expressions, selected_ids
    except Exception as e:
        logger.error(f"LLM处理表达方式选择时出错: {e}")
        return [], []
# Register prompt templates, then build the module-level singleton.
init_prompt()
# Guarded so an initialization failure doesn't break importing this module.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@@ -18,7 +18,7 @@ from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.memory_system.question_maker import QuestionMaker
from src.memory_system.questions import global_conflict_tracker
@@ -331,9 +331,8 @@ class HeartFChatting:
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
await global_memory_chest.build_running_content(chat_id=self.stream_id)
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
cycle_timers, thinking_id = self.start_cycle()

View File

@@ -26,7 +26,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
@@ -238,8 +238,8 @@ class DefaultReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)

View File

@@ -24,7 +24,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.mood.mood_manager import mood_manager
@@ -256,8 +256,8 @@ class PrivateReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)

View File

@@ -303,11 +303,10 @@ class Expression(BaseModel):
situation = TextField()
style = TextField()
count = FloatField()
# new mode fields
context = TextField(null=True)
context_words = TextField(null=True)
up_content = TextField(null=True)
last_active_time = FloatField()
chat_id = TextField(index=True)

View File

@@ -310,8 +310,8 @@ class MemoryConfig(ConfigBase):
class ExpressionConfig(ConfigBase):
"""表达配置类"""
mode: Literal["llm", "context", "full-context"] = "context"
"""表达方式模式,可选:llm模式context上下文模式full-context 完整上下文嵌入模式"""
mode: str = "classic"
"""表达方式模式,可选:classic经典模式exp_model 表达模型模式"""
learning_list: list[list] = field(default_factory=lambda: [])
"""

View File

@@ -2,10 +2,12 @@ import time
import random
import json
import os
import re
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
import traceback
import difflib
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
@@ -17,16 +19,23 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.style_learner import style_learner_manager
from json_repair import repair_json
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 15 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def calculate_similarity(text1: str, text2: str) -> float:
    """Return a similarity ratio in [0, 1] for the two texts.

    Thin wrapper around difflib.SequenceMatcher.ratio().
    """
    matcher = difflib.SequenceMatcher(None, text1, text2)
    return matcher.ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
@@ -173,63 +182,7 @@ class ExpressionLearner:
traceback.print_exc()
return False
def _apply_global_decay_to_database(self, current_time: float) -> None:
    """
    Apply time-based decay to every expression row; delete rows whose count bottoms out.
    """
    try:
        # Walk every expression regardless of chat.
        all_expressions = Expression.select()
        updated_count = 0
        deleted_count = 0
        for expr in all_expressions:
            # Idle time in days since last activation.
            last_active = expr.last_active_time
            time_diff_days = (current_time - last_active) / (24 * 3600)
            # Decay grows quadratically with idle time (see calculate_decay_factor).
            decay_value = self.calculate_decay_factor(time_diff_days)
            new_count = max(0.01, expr.count - decay_value)
            if new_count <= 0.01:
                # Count hit the floor: the expression is effectively dead, drop it.
                expr.delete_instance()
                deleted_count += 1
            else:
                expr.count = new_count
                expr.save()
                updated_count += 1
        if updated_count > 0 or deleted_count > 0:
            logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
    except Exception as e:
        logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
    """Compute the count decay for an expression idle for time_diff_days days.

    Quadratic ramp from 0 (just active, no decay) up to DECAY_MIN at
    DECAY_DAYS days; anything idle DECAY_DAYS days or longer takes the full
    DECAY_MIN per decay pass. (The old docstring quoted 30 days, which
    contradicted DECAY_DAYS = 15; constants are now the single source of truth.)

    Args:
        time_diff_days: days since the expression was last active.

    Returns:
        Decay value in [0.0, DECAY_MIN] to subtract from the count.
    """
    if time_diff_days <= 0:
        return 0.0  # freshly active expressions do not decay
    if time_diff_days >= DECAY_DAYS:
        return DECAY_MIN  # long-idle expressions take the maximum decay
    # Quadratic interpolation y = a * x^2 chosen so y(DECAY_DAYS) = DECAY_MIN.
    a = DECAY_MIN / (DECAY_DAYS**2)
    decay = a * (time_diff_days**2)
    return min(DECAY_MIN, decay)
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
@@ -247,7 +200,7 @@ class ExpressionLearner:
situation,
style,
_context,
_context_words,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
@@ -260,7 +213,7 @@ class ExpressionLearner:
situation,
style,
context,
context_words,
up_content,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
@@ -269,13 +222,15 @@ class ExpressionLearner:
"situation": situation,
"style": style,
"context": context,
"context_words": context_words,
"up_content": up_content,
}
)
current_time = time.time()
# 存储到数据库 Expression 表
# 存储到数据库 Expression 表并训练 style_learner
trained_chat_ids = set() # 记录已训练的聊天室
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
@@ -292,32 +247,72 @@ class ExpressionLearner:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.context_words = new_expr["context_words"]
expr_obj.count = expr_obj.count + 1
expr_obj.up_content = new_expr["up_content"]
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
last_active_time=current_time,
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
context_words=new_expr["context_words"],
up_content=new_expr["up_content"],
)
# 训练 style_learnerup_content 和 style 必定存在)
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
else:
learner.add_style(new_expr["style"])
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
# exprs = list(
# Expression.select()
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
# .order_by(Expression.last_active_time.asc())
# )
# if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除最久未活跃的多余表达方式
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
# expr.delete_instance()
# 保存训练好的 style_learner 模型
if trained_chat_ids:
try:
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
save_success = style_learner_manager.save_all_models()
if save_success:
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
else:
logger.warning("StyleLearner 模型保存失败")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions
async def match_expression_context(
@@ -339,8 +334,8 @@ class ExpressionLearner:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# print(f"match_expression_context_prompt: {prompt}")
# print(f"{response}")
# 解析JSON响应
match_responses = []
@@ -395,6 +390,8 @@ class ExpressionLearner:
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
print(f"match_responses: {match_responses}")
for match_response in match_responses:
try:
@@ -418,7 +415,7 @@ class ExpressionLearner:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
"""从指定聊天流学习表达方式
Args:
@@ -466,39 +463,52 @@ class ExpressionLearner:
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
print(f"matched_expressions: {matched_expressions}")
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
matched_expressions
)
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
pic_pattern = r"\[picid:[^\]]+\]"
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
at_pattern = r"@<[^:<>]+:[^:<>]+>"
for idx, msg in enumerate(random_msg):
content = msg.processed_plain_text or ""
content = re.sub(pic_pattern, "[图片]", content)
content = re.sub(reply_pattern, "回复[某人]", content)
content = re.sub(at_pattern, "@[某人]", content)
content = content.strip()
if content:
bare_lines.append((idx, content))
split_matched_expressions_w_emb = []
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words)
)
return split_matched_expressions_w_emb
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词
Args:
matched_expressions: 匹配到的表达方式列表每个元素为(situation, style, context)
Returns:
添加了分词结果的表达方式列表每个元素为(situation, style, context, context_words)
"""
result = []
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
# 使用jieba进行分词
context_words = list(jieba.cut(context))
result.append((situation, style, context, context_words))
# 在 bare_lines 中找到第一处相似度达到85%的行
pos = None
for i, (_, c) in enumerate(bare_lines):
similarity = calculate_similarity(c, context)
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到或没有上一句,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = (random_msg[prev_original_idx].processed_plain_text or "").strip()
if not up_content:
continue
filtered_with_up.append((situation, style, context, up_content))
if not filtered_with_up:
return None
results: List[Tuple[str, str, str, str]] = []
for (situation, style, context, up_content) in filtered_with_up:
results.append((self.chat_id, situation, style, context, up_content))
return results
return result
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""

View File

@@ -0,0 +1,520 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager
logger = get_logger("expression_selector")
def init_prompt():
    """Register the expression-evaluation prompt template with the global prompt manager.

    The template is user-facing Chinese; its placeholders are filled in by
    _select_expressions_classic. Called once at module import time.
    """
    expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
    Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """Uniformly sample up to k distinct items from population.

    NOTE: despite the legacy name (kept so existing call sites keep working),
    this variant takes no weights -- every item is equally likely. Sampling is
    without replacement; if population holds k items or fewer, a shallow copy
    of the whole population is returned.

    Args:
        population: candidate items.
        k: number of items to draw.

    Returns:
        At most k sampled items (the dicts themselves are not copied).
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # random.sample already draws k distinct items uniformly without
    # replacement -- no need for the manual pop-loop this replaces.
    return random.sample(population, k)
class ExpressionSelector:
    """Selects learned expressions ("situation -> style" pairs), via style_learner prediction and/or LLM ranking."""

    def __init__(self):
        # Small utility model used to rank candidate situations via LLM (classic mode).
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
        )

    def can_use_expression_for_chat(self, chat_id: str) -> bool:
        """Return whether expression usage is enabled for the given chat stream.

        Args:
            chat_id: chat stream id.

        Returns:
            bool: True if allowed; False when disabled or config lookup fails.
        """
        try:
            use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
            return use_expression
        except Exception as e:
            logger.error(f"检查表达使用权限失败: {e}")
            return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
    """Resolve all chat_ids that share expressions with chat_id (itself included), per expression_groups config."""
    groups = global_config.expression.expression_groups
    # A group containing "*" means every configured stream shares expressions globally.
    global_group_exists = any("*" in group for group in groups)
    if global_group_exists:
        # Global sharing: return every parseable chat_id across all groups.
        all_chat_ids = set()
        for group in groups:
            for stream_config_str in group:
                if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                    all_chat_ids.add(chat_id_candidate)
        return list(all_chat_ids) if all_chat_ids else [chat_id]
    # Otherwise, return the first group that contains this chat_id.
    for group in groups:
        group_chat_ids = []
        for stream_config_str in group:
            if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
                group_chat_ids.append(chat_id_candidate)
        if chat_id in group_chat_ids:
            return group_chat_ids
    # Not in any group: the chat only shares expressions with itself.
    return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
    """
    Predict the best-fitting expressions with the trained style_learner models.

    Args:
        chat_id: chat stream id.
        chat_info: chat transcript; its tail becomes the prediction input.
        total_num: number of expressions to return.

    Returns:
        List[Dict[str, Any]]: predicted expressions sorted by score descending;
        on any top-level failure, falls back to uniform random selection.
    """
    try:
        # Merge predictions across every chat sharing an expression group.
        related_chat_ids = self.get_related_chat_ids(chat_id)
        # Condense the transcript into a short prediction input.
        prediction_input = self._extract_prediction_input(chat_info)
        predicted_expressions = []
        for related_chat_id in related_chat_ids:
            try:
                # NOTE(review): only the single best_style per related chat is
                # consumed even though top_k=total_num is requested -- confirm
                # whether the remaining top-k candidates should be used too.
                best_style, scores = style_learner_manager.predict_style(
                    related_chat_id, prediction_input, top_k=total_num
                )
                if best_style and scores:
                    # Map the predicted style back to its situation metadata.
                    learner = style_learner_manager.get_learner(related_chat_id)
                    style_id, situation = learner.get_style_info(best_style)
                    if style_id and situation:
                        # Look up the matching DB row to get the canonical record.
                        expr_query = Expression.select().where(
                            (Expression.chat_id == related_chat_id) &
                            (Expression.type == "style") &
                            (Expression.situation == situation) &
                            (Expression.style == best_style)
                        )
                        if expr_query.exists():
                            expr = expr_query.get()
                            predicted_expressions.append({
                                "id": expr.id,
                                "situation": expr.situation,
                                "style": expr.style,
                                "last_active_time": expr.last_active_time,
                                "source_id": expr.chat_id,
                                "type": "style",
                                # Older rows may predate create_date; fall back to last_active_time.
                                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                                "prediction_score": scores.get(best_style, 0.0),
                                "prediction_input": prediction_input
                            })
            except Exception as e:
                # A single chat failing must not abort the merged prediction.
                logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
                continue
        # Highest-scoring predictions first; keep at most total_num.
        predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
        selected_expressions = predicted_expressions[:total_num]
        logger.info(f"为聊天室 {chat_id} 预测到 {len(selected_expressions)} 个表达方式")
        return selected_expressions
    except Exception as e:
        logger.error(f"模型预测表达方式失败: {e}")
        # Prediction machinery broke entirely: degrade to random selection.
        return self._fallback_random_expressions(chat_id, total_num)
def _extract_prediction_input(self, chat_info: str) -> str:
"""
从聊天信息中提取用于预测的关键内容
Args:
chat_info: 聊天内容信息
Returns:
str: 提取的预测输入
"""
try:
# 简单的提取策略:取最后几句话作为预测输入
lines = chat_info.strip().split('\n')
if not lines:
return ""
# 取最后3行作为预测输入
recent_lines = lines[-3:]
prediction_input = ' '.join(recent_lines).strip()
# 如果内容太长截取前100个字符
if len(prediction_input) > 100:
prediction_input = prediction_input[:100]
return prediction_input
except Exception as e:
logger.warning(f"提取预测输入失败: {e}")
return ""
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
    """
    Fallback: uniformly sample stored expressions when model prediction fails.

    Args:
        chat_id: chat stream id.
        total_num: number of expressions to select.

    Returns:
        List[Dict[str, Any]]: randomly selected expressions ([] on DB failure).
    """
    try:
        # Merge candidates from every chat sharing an expression group.
        related_chat_ids = self.get_related_chat_ids(chat_id)
        # Single query over all related chat_ids to avoid N round-trips.
        style_query = Expression.select().where(
            (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
        )
        style_exprs = [
            {
                "id": expr.id,
                "situation": expr.situation,
                "style": expr.style,
                "last_active_time": expr.last_active_time,
                "source_id": expr.chat_id,
                "type": "style",
                # Older rows may predate create_date; fall back to last_active_time.
                "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
            }
            for expr in style_query
        ]
        # Uniform sample (weighted_sample here takes no weights despite its name).
        if style_exprs:
            selected_style = weighted_sample(style_exprs, total_num)
        else:
            selected_style = []
        logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
        return selected_style
    except Exception as e:
        logger.error(f"随机选择表达方式失败: {e}")
        return []
async def select_suitable_expressions(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    Select suitable expressions according to global_config.expression.mode.

    Args:
        chat_id: chat stream id.
        chat_info: chat transcript.
        max_num: upper bound on selected expressions.
        target_message: optional message being replied to (classic mode only).

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]: selected expressions and their ids.
    """
    # Bail out early if expressions are disabled for this stream.
    if not self.can_use_expression_for_chat(chat_id):
        logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
        return [], []
    expression_mode = global_config.expression.mode
    if expression_mode == "exp_model":
        # exp_model: trust the trained model directly, skip the LLM pass.
        logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
        return await self._select_expressions_model_only(chat_id, chat_info, max_num)
    elif expression_mode == "classic":
        # classic: model-predicted candidates re-ranked by the LLM.
        logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
        return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
    else:
        # Unknown mode: warn once per call and behave like classic.
        logger.warning(f"未知的表达模式: {expression_mode}回退到classic模式")
        return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
async def _select_expressions_model_only(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    exp_model mode: take the model's top predictions directly, no LLM pass.

    Args:
        chat_id: chat stream id.
        chat_info: chat transcript used as prediction input.
        max_num: upper bound on selected expressions.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]: selected expressions and their ids.
    """
    try:
        # Over-fetch (2x) so the truncation below still has headroom.
        style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)
        # if len(style_exprs) < 5:
        #     logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
        #     return [], []
        # Predictions are already score-sorted; keep the first max_num.
        selected_expressions = style_exprs[:max_num]
        selected_ids = [expr["id"] for expr in selected_expressions]
        # Record the usage so these rows stay out of global decay/cleanup.
        if selected_expressions:
            self.update_expressions_last_active_time(selected_expressions)
        logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
        return selected_expressions, selected_ids
    except Exception as e:
        logger.error(f"exp_model模式选择表达方式失败: {e}")
        return [], []
async def _select_expressions_classic(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """
    classic mode: model-predicted candidates re-ranked by the LLM.

    Args:
        chat_id: chat stream id.
        chat_info: chat transcript shown to the LLM.
        max_num: upper bound on selected expressions.
        target_message: optional message being replied to, added to the prompt.

    Returns:
        Tuple[List[Dict[str, Any]], List[int]]: selected expressions and ids;
        both empty on failure or while the expression pool is too small.
    """
    try:
        # 1. Draw a 20-candidate pool from the style_learner models.
        style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)
        if len(style_exprs) < 10:
            logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
            return [], []
        # 2. Build a 1-based situation index to show the LLM.
        all_expressions: List[Dict[str, Any]] = []
        all_situations: List[str] = []
        for expr in style_exprs:
            expr = expr.copy()
            all_expressions.append(expr)
            all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
        if not all_expressions:
            logger.warning("没有找到可用的表达方式")
            return [], []
        all_situations_str = "\n".join(all_situations)
        if target_message:
            target_message_str = f",现在你想要回复消息:{target_message}"
            target_message_extra_block = "4.考虑你要回复的目标消息"
        else:
            target_message_str = ""
            target_message_extra_block = ""
        # 3. The prompt carries situations only (not full expressions) to keep it short.
        prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
            bot_name=global_config.bot.nickname,
            chat_observe_info=chat_info,
            all_situations=all_situations_str,
            max_num=max_num,
            target_message=target_message_str,
            target_message_extra_block=target_message_extra_block,
        )
        # 4. LLM call.
        content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
        if not content:
            logger.warning("LLM返回空结果")
            return [], []
        # 5. Parse the (possibly malformed) JSON answer.
        result = repair_json(content)
        if isinstance(result, str):
            result = json.loads(result)  # repair_json may return a JSON string
        if not isinstance(result, dict) or "selected_situations" not in result:
            logger.error("LLM返回格式错误")
            logger.info(f"LLM返回结果: \n{content}")
            return [], []
        selected_indices = result["selected_situations"]
        # Map the 1-based indices back to the full expressions, dropping bad values.
        valid_expressions: List[Dict[str, Any]] = []
        selected_ids = []
        for idx in selected_indices:
            if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                expression = all_expressions[idx - 1]  # indices are 1-based
                selected_ids.append(expression["id"])
                valid_expressions.append(expression)
        # Touch last_active_time for everything the LLM picked.
        if valid_expressions:
            self.update_expressions_last_active_time(valid_expressions)
        logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
        return valid_expressions, selected_ids
    except Exception as e:
        logger.error(f"classic模式处理表达方式选择时出错: {e}")
        return [], []
async def select_suitable_expressions_llm(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """Deprecated wrapper kept for backward compatibility.

    Delegates directly to :meth:`select_suitable_expressions`; new code
    should call that method instead.
    """
    logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")
    return await self.select_suitable_expressions(
        chat_id,
        chat_info,
        max_num,
        target_message,
    )
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
    """Refresh ``last_active_time`` in the DB for a batch of selected expressions.

    Deduplicates by ``(source_id, type, situation, style)`` so each matching
    DB row is written at most once per call; rows missing required fields are
    skipped with a warning.
    """
    if not expressions_to_update:
        return
    # A set of lookup keys is all we need; the previous dict stored the full
    # expression dicts as values but never read them.
    unique_keys = set()
    for expr in expressions_to_update:
        source_id: str = expr.get("source_id")  # type: ignore
        expr_type: str = expr.get("type", "style")
        situation: str = expr.get("situation")  # type: ignore
        style: str = expr.get("style")  # type: ignore
        if not source_id or not situation or not style:
            logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
            continue
        unique_keys.add((source_id, expr_type, situation, style))
    # One timestamp for the whole batch instead of one syscall per row.
    now = time.time()
    for chat_id, expr_type, situation, style in unique_keys:
        # Single fetch via .first() instead of exists() + get() (two queries).
        expr_obj = (
            Expression.select()
            .where(
                (Expression.chat_id == chat_id)
                & (Expression.type == expr_type)
                & (Expression.situation == situation)
                & (Expression.style == style)
            )
            .first()
        )
        if expr_obj is not None:
            expr_obj.last_active_time = now
            expr_obj.save()
            logger.debug(
                "表达方式激活: 更新last_active_time in db"
            )
init_prompt()

# Module-level singleton.  Keep the name bound even when construction fails
# (e.g. model config missing) so `from ... import expression_selector` yields
# None instead of raising ImportError; callers can check for None explicitly.
expression_selector = None
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@@ -0,0 +1,131 @@
from typing import Dict, Optional, Tuple, List
from collections import Counter, defaultdict
import pickle
import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
    """Naive-Bayes reranker over candidate styles with online learning.

    Each candidate is keyed by a ``cid`` and carries a style text plus an
    optional ``situation`` string.  Situations never participate in scoring;
    they are metadata kept in sync with the style for lookup/display.
    """

    def __init__(self,
                 alpha: float = 0.5,
                 beta: float = 0.5,
                 gamma: float = 1.0,
                 vocab_size: int = 200000,
                 use_jieba: bool = True):
        # alpha: Laplace smoothing for term likelihoods; beta: prior smoothing;
        # gamma: default decay factor; vocab_size: smoothing denominator size.
        self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
        self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
        self._candidates: Dict[str, str] = {}  # cid -> style text
        self._situations: Dict[str, str] = {}  # cid -> situation (not scored)

    def add_candidate(self, cid: str, text: str, situation: Optional[str] = None):
        """Register (or overwrite) a candidate style and its optional situation."""
        self._candidates[cid] = text
        if situation is not None:
            self._situations[cid] = situation
        # Make sure the NB model has count slots for this candidate.
        if cid not in self.nb.cls_counts:
            self.nb.cls_counts[cid] = 0.0
        if cid not in self.nb.token_counts:
            self.nb.token_counts[cid] = defaultdict(float)

    def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: Optional[List[str]] = None):
        """Register many ``(cid, text)`` pairs; ``situations`` pairs up positionally."""
        for i, (cid, text) in enumerate(items):
            situation = situations[i] if situations and i < len(situations) else None
            self.add_candidate(cid, text, situation)

    def predict(self, text: str, k: Optional[int] = None) -> Tuple[Optional[str], Dict[str, float]]:
        """Score every candidate against ``text``; return ``(best_cid, scores)``.

        ``k`` is accepted for interface compatibility but is currently unused:
        the full score dict is always returned.  Returns ``(None, {})`` when
        the text tokenizes to nothing or there are no candidates.
        """
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return None, {}
        if not self._candidates:
            return None, {}
        tf = Counter(toks)
        all_cids = list(self._candidates.keys())
        scores = self.nb.score_batch(tf, all_cids)
        if not scores:
            return None, {}
        best = max(scores.items(), key=lambda x: x[1])[0]
        return best, scores

    def update_positive(self, text: str, cid: str):
        """Positive-feedback update: reinforce ``cid`` with tokens of ``text``."""
        toks = self.tokenizer.tokenize(text)
        if not toks:
            return
        tf = Counter(toks)
        self.nb.update_positive(tf, cid)

    def decay(self, factor: float):
        """Multiply all learned counts by ``factor`` (forgetting)."""
        self.nb.decay(factor=factor)

    def get_situation(self, cid: str) -> Optional[str]:
        """Return the situation bound to ``cid``, or None."""
        return self._situations.get(cid)

    def get_style(self, cid: str) -> Optional[str]:
        """Return the style text bound to ``cid``, or None."""
        return self._candidates.get(cid)

    def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(style, situation)`` for ``cid`` (either may be None)."""
        return self._candidates.get(cid), self._situations.get(cid)

    def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
        """Return ``{cid: (style, situation)}`` for every registered candidate."""
        return {cid: (style, self._situations.get(cid))
                for cid, style in self._candidates.items()}

    def save(self, path: str):
        """Pickle candidates, situations and NB counts to ``path``."""
        parent = os.path.dirname(path)
        if parent:  # dirname is "" for bare filenames; makedirs("") would raise
            os.makedirs(parent, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "candidates": self._candidates,
                "situations": self._situations,
                "nb": {
                    "cls_counts": dict(self.nb.cls_counts),
                    "token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
                    "alpha": self.nb.alpha,
                    "beta": self.nb.beta,
                    "gamma": self.nb.gamma,
                    "V": self.nb.V,
                }
            }, f)

    def load(self, path: str):
        """Restore a model previously written by :meth:`save`.

        NOTE: uses pickle — only load files this process itself produced;
        never load untrusted model files.
        """
        with open(path, "rb") as f:
            obj = pickle.load(f)
        self._candidates = obj["candidates"]
        # Old model files may predate the situations field.
        self._situations = obj.get("situations", {})
        # Restore counts as defaultdict(float): assigning the plain dict from
        # the pickle would silently break the "missing cid counts as 0.0"
        # invariant that OnlineNaiveBayes relies on.
        restored_cls = defaultdict(float)
        restored_cls.update(obj["nb"]["cls_counts"])
        self.nb.cls_counts = restored_cls
        self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
        self.nb.alpha = obj["nb"]["alpha"]
        self.nb.beta = obj["nb"]["beta"]
        self.nb.gamma = obj["nb"]["gamma"]
        self.nb.V = obj["nb"]["V"]
        # Cached normalizers are stale after a load.
        self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]) -> Dict[str, Dict[str, float]]:
    """Rebuild a nested ``defaultdict(float)`` structure from plain dicts.

    Used when unpickling: restores the "missing key counts as 0.0" behaviour
    that the online naive-Bayes model relies on.  The module-level
    ``defaultdict`` import is used; no local re-import is needed.
    """
    outer: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    for cid, inner in d.items():
        converted = defaultdict(float)
        converted.update(inner)
        outer[cid] = converted
    return outer

View File

@@ -0,0 +1,60 @@
import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
    """Multinomial naive Bayes with incremental updates and exponential decay.

    Counts are per class (``cid``): ``cls_counts[cid]`` is the class's total
    token mass and ``token_counts[cid][term]`` the per-term mass.  Scoring is
    strictly read-only: unknown cids are treated as having zero counts and
    are NOT inserted into the count tables (the original defaultdict lookups
    inserted empty classes on read, inflating ``n_cls`` and skewing priors on
    later calls).
    """

    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
        self.alpha = alpha  # Laplace smoothing for term likelihoods
        self.beta = beta    # smoothing for the class prior
        self.gamma = gamma  # default decay factor (1.0 = no forgetting)
        self.V = vocab_size  # assumed vocabulary size in the smoothing denominator
        self.cls_counts: Dict[str, float] = defaultdict(float)  # cid -> total token count
        self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))  # cid -> term -> count
        self._logZ: Dict[str, float] = {}  # cache of log(cls_count + V*alpha) per cid

    def _invalidate(self, cid: str):
        """Drop the cached normalizer for ``cid`` after its counts changed."""
        self._logZ.pop(cid, None)

    def _logZ_c(self, cid: str) -> float:
        """Return (and cache) log of the smoothed likelihood denominator for ``cid``."""
        if cid not in self._logZ:
            # .get() avoids inserting unseen cids into the defaultdict.
            Z = self.cls_counts.get(cid, 0.0) + self.V * self.alpha
            self._logZ[cid] = math.log(max(Z, 1e-12))
        return self._logZ[cid]

    def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
        """Return log P(cid) + sum_t tf[t] * log P(t | cid) for each requested cid.

        Pure read: never mutates the count tables.
        """
        total_cls = sum(self.cls_counts.values())
        n_cls = max(1, len(self.cls_counts))
        denom_prior = math.log(total_cls + self.beta * n_cls)
        out: Dict[str, float] = {}
        no_counts: Dict[str, float] = {}  # shared read-only stand-in for unseen cids
        for cid in cids:
            prior = math.log(self.cls_counts.get(cid, 0.0) + self.beta) - denom_prior
            s = prior
            logZ = self._logZ_c(cid)
            tc = self.token_counts.get(cid, no_counts)
            for term, qtf in tf.items():
                num = tc.get(term, 0.0) + self.alpha
                s += qtf * (math.log(num) - logZ)
            out[cid] = s
        return out

    def update_positive(self, tf: Counter, cid: str):
        """Add the term frequencies in ``tf`` to class ``cid`` (positive feedback)."""
        inc = 0.0
        tc = self.token_counts[cid]
        for term, c in tf.items():
            tc[term] += float(c)
            inc += float(c)
        self.cls_counts[cid] += inc
        self._invalidate(cid)

    def decay(self, factor=None):
        """Multiply all counts by ``factor`` (float or None → use ``gamma``).

        Values >= 1.0 are a no-op, so the default gamma of 1.0 never forgets.
        """
        g = self.gamma if factor is None else factor
        if g >= 1.0:
            return
        for cid in list(self.cls_counts.keys()):
            self.cls_counts[cid] *= g
            for term in list(self.token_counts[cid].keys()):
                self.token_counts[cid][term] *= g
            self._invalidate(cid)

View File

@@ -0,0 +1,28 @@
import re
from typing import List, Optional, Set
try:
import jieba
_HAS_JIEBA = True
except Exception:
_HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")


def simple_en_tokenize(text: str) -> List[str]:
    """Split text into lowercase alphanumeric/underscore word tokens."""
    lowered = text.lower()
    return [match.group(0) for match in _WORD_RE.finditer(lowered)]
class Tokenizer:
    """Tokenizes text via jieba when available, otherwise a regex fallback."""

    def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
        self.stopwords = stopwords or set()
        # jieba is only used when both requested and actually importable.
        self.use_jieba = use_jieba and _HAS_JIEBA

    def tokenize(self, text: str) -> List[str]:
        stripped = (text or "").strip()
        if not stripped:
            return []
        if self.use_jieba:
            pieces = (piece.strip().lower() for piece in jieba.cut(stripped))
            tokens = [piece for piece in pieces if piece]
        else:
            tokens = simple_en_tokenize(stripped)
        return [tok for tok in tokens if tok not in self.stopwords]

View File

@@ -0,0 +1,628 @@
"""
多聊天室表达风格学习系统
支持为每个chat_id维护独立的表达模型学习从up_content到style的映射
"""
import os
import pickle
import traceback
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import asyncio
from src.common.logger import get_logger
from .expressor_model.model import ExpressorModel
logger = get_logger("style_learner")
class StyleLearner:
"""
单个聊天室的表达风格学习器
学习从up_content到style的映射关系
支持动态管理风格集合最多2000个
"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0 # 下一个可用的style_id
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": defaultdict(int),
"last_update": None,
"style_usage_frequency": defaultdict(int) # 风格使用频率
}
def add_style(self, style: str, situation: Optional[str] = None) -> bool:
    """Dynamically register a new style for this chat.

    Args:
        style: the style text.
        situation: the matching situation text (optional).

    Returns:
        bool: True on success (including when the style already exists);
        False when the per-chat style cap is reached or on error.
    """
    try:
        # Already registered: treat as success.
        if style in self.style_to_id:
            logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
            return True
        # Enforce the per-chat cap.
        if len(self.style_to_id) >= self.max_styles:
            logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
            return False
        # Allocate a fresh style_id.
        style_id = f"style_{self.next_style_id}"
        self.next_style_id += 1
        # Register both directions of the mapping.
        self.style_to_id[style] = style_id
        self.id_to_style[style_id] = style
        if situation:
            self.id_to_situation[style_id] = situation
        # Mirror the candidate into the expressor model.
        self.expressor.add_candidate(style_id, style, situation)
        logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
                    (f", situation: '{situation}'" if situation else ""))
        return True
    except Exception as e:
        logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
        return False
def remove_style(self, style: str) -> bool:
"""
删除一个风格
Args:
style: 要删除的风格文本
Returns:
bool: 删除是否成功
"""
try:
if style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
return False
style_id = self.style_to_id[style]
# 从映射中删除
del self.style_to_id[style]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# 从expressor模型中删除通过重新构建
self._rebuild_expressor()
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
return False
def update_style(self, old_style: str, new_style: str) -> bool:
"""
更新一个风格
Args:
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
try:
if old_style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
return False
if new_style in self.style_to_id and new_style != old_style:
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
return False
style_id = self.style_to_id[old_style]
# 更新映射
del self.style_to_id[old_style]
self.style_to_id[new_style] = style_id
self.id_to_style[style_id] = new_style
# 更新expressor模型保留原有的situation
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, new_style, situation)
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
return False
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
"""
批量添加风格
Args:
styles: 风格文本列表
situations: 对应的situation文本列表可选
Returns:
int: 成功添加的数量
"""
success_count = 0
for i, style in enumerate(styles):
situation = situations[i] if situations and i < len(situations) else None
if self.add_style(style, situation):
success_count += 1
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
return success_count
def get_all_styles(self) -> List[str]:
"""获取所有已注册的风格"""
return list(self.style_to_id.keys())
def get_style_count(self) -> int:
"""获取当前风格数量"""
return len(self.style_to_id)
def get_situation(self, style: str) -> Optional[str]:
"""
获取风格对应的situation
Args:
style: 风格文本
Returns:
Optional[str]: 对应的situation如果不存在则返回None
"""
if style not in self.style_to_id:
return None
style_id = self.style_to_id[style]
return self.id_to_situation.get(style_id)
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取风格的完整信息
Args:
style: 风格文本
Returns:
Tuple[Optional[str], Optional[str]]: (style_id, situation)
"""
if style not in self.style_to_id:
return None, None
style_id = self.style_to_id[style]
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""
获取所有风格的完整信息
Returns:
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
"""
result = {}
for style, style_id in self.style_to_id.items():
situation = self.id_to_situation.get(style_id)
result[style] = (style_id, situation)
return result
def _rebuild_expressor(self):
"""重新构建expressor模型删除风格后使用"""
try:
# 重新创建expressor
self.expressor = ExpressorModel(**self.model_config)
# 重新添加所有风格和situation
for style_id, style_text in self.id_to_style.items():
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, style_text, situation)
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
except Exception as e:
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
def learn_mapping(self, up_content: str, style: str) -> bool:
    """Learn one up_content -> style mapping; auto-registers unknown styles.

    Args:
        up_content: the triggering input text.
        style: the style text it should map to.

    Returns:
        bool: True when the mapping was learned.
    """
    # Wall-clock timestamp for the stats below; asyncio.get_event_loop() is
    # deprecated (and can raise) when no loop is running, and a stat only
    # needs real time anyway.
    import time

    try:
        # Register the style first if it is new.
        if style not in self.style_to_id:
            if not self.add_style(style):
                logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
                return False
        style_id = self.style_to_id[style]
        # Positive-feedback update on the NB model.
        self.expressor.update_positive(up_content, style_id)
        # Bookkeeping.
        self.learning_stats["total_samples"] += 1
        self.learning_stats["style_counts"][style_id] += 1
        self.learning_stats["style_usage_frequency"][style] += 1
        self.learning_stats["last_update"] = time.time()
        logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
        return True
    except Exception as e:
        logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
        traceback.print_exc()
        return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style文本, 所有候选的分数]
"""
try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
return best_style, style_scores
except Exception as e:
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
traceback.print_exc()
return None, {}
def decay_learning(self, factor: Optional[float] = None) -> None:
"""
对学习到的知识进行衰减(遗忘)
Args:
factor: 衰减因子None则使用配置中的gamma
"""
self.expressor.decay(factor)
logger.debug(f"[{self.chat_id}] 执行知识衰减")
def get_stats(self) -> Dict:
"""获取学习统计信息"""
return {
"chat_id": self.chat_id,
"total_samples": self.learning_stats["total_samples"],
"style_count": len(self.style_to_id),
"max_styles": self.max_styles,
"style_counts": dict(self.learning_stats["style_counts"]),
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
"last_update": self.learning_stats["last_update"],
"all_styles": list(self.style_to_id.keys())
}
def save(self, base_path: str) -> bool:
"""
保存模型到文件
Args:
base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl
"""
try:
os.makedirs(base_path, exist_ok=True)
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
# 保存模型和统计信息
save_data = {
"model_config": self.model_config,
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"max_styles": self.max_styles,
"learning_stats": self.learning_stats
}
# 先保存expressor模型
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
self.expressor.save(expressor_path)
# 保存其他数据
with open(file_path, "wb") as f:
pickle.dump(save_data, f)
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载模型
Args:
base_path: 基础路径
"""
try:
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
return False
# 加载其他数据
with open(file_path, "rb") as f:
save_data = pickle.load(f)
# 恢复配置和状态
self.model_config = save_data["model_config"]
self.style_to_id = save_data["style_to_id"]
self.id_to_style = save_data["id_to_style"]
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
self.next_style_id = save_data["next_style_id"]
self.max_styles = save_data.get("max_styles", 2000)
self.learning_stats = save_data["learning_stats"]
# 重新创建expressor并加载
self.expressor = ExpressorModel(**self.model_config)
self.expressor.load(expressor_path)
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
return False
class StyleLearnerManager:
"""
多聊天室表达风格学习管理器
为每个chat_id维护独立的StyleLearner实例
每个chat_id可以动态管理自己的风格集合最多2000个
"""
def __init__(self, model_save_path: str = "data/style_models"):
self.model_save_path = model_save_path
self.learners: Dict[str, StyleLearner] = {}
# 自动保存配置
self.auto_save_interval = 300 # 5分钟
self._auto_save_task: Optional[asyncio.Task] = None
logger.info("StyleLearnerManager 已初始化")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
Args:
chat_id: 聊天室ID
model_config: 模型配置None则使用默认配置
Returns:
StyleLearner实例
"""
if chat_id not in self.learners:
# 创建新的学习器
learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型
learner.load(self.model_save_path)
self.learners[chat_id] = learner
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
return self.learners[chat_id]
def add_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id添加风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 添加是否成功
"""
learner = self.get_learner(chat_id)
return learner.add_style(style)
def remove_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id删除风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 删除是否成功
"""
learner = self.get_learner(chat_id)
return learner.remove_style(style)
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
"""
为指定chat_id更新风格
Args:
chat_id: 聊天室ID
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
learner = self.get_learner(chat_id)
return learner.update_style(old_style, new_style)
def get_chat_styles(self, chat_id: str) -> List[str]:
"""
获取指定chat_id的所有风格
Args:
chat_id: 聊天室ID
Returns:
List[str]: 风格列表
"""
learner = self.get_learner(chat_id)
return learner.get_all_styles()
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
学习一个映射关系
Args:
chat_id: 聊天室ID
up_content: 输入内容
style: 对应的style
Returns:
bool: 学习是否成功
"""
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
预测最合适的style
Args:
chat_id: 聊天室ID
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style, 所有候选分数]
"""
learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k)
def decay_all_learners(self, factor: Optional[float] = None) -> None:
"""
对所有学习器执行衰减
Args:
factor: 衰减因子
"""
for learner in self.learners.values():
learner.decay_learning(factor)
logger.info("已对所有学习器执行衰减")
def get_all_stats(self) -> Dict[str, Dict]:
"""获取所有学习器的统计信息"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
def save_all_models(self) -> bool:
    """Persist every learner's model; True only if all saves succeed."""
    saved = sum(
        1 for learner in self.learners.values()
        if learner.save(self.model_save_path)
    )
    logger.info(f"已保存 {saved}/{len(self.learners)} 个模型")
    return saved == len(self.learners)
def load_all_models(self) -> int:
"""加载所有已保存的模型"""
if not os.path.exists(self.model_save_path):
return 0
loaded_count = 0
for filename in os.listdir(self.model_save_path):
if filename.endswith("_style_model.pkl"):
chat_id = filename.replace("_style_model.pkl", "")
learner = StyleLearner(chat_id)
if learner.load(self.model_save_path):
self.learners[chat_id] = learner
loaded_count += 1
logger.info(f"已加载 {loaded_count} 个模型")
return loaded_count
async def start_auto_save(self) -> None:
"""启动自动保存任务"""
if self._auto_save_task is None or self._auto_save_task.done():
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
logger.info("已启动自动保存任务")
async def stop_auto_save(self) -> None:
"""停止自动保存任务"""
if self._auto_save_task and not self._auto_save_task.done():
self._auto_save_task.cancel()
try:
await self._auto_save_task
except asyncio.CancelledError:
pass
logger.info("已停止自动保存任务")
async def _auto_save_loop(self) -> None:
"""自动保存循环"""
while True:
try:
await asyncio.sleep(self.auto_save_interval)
self.save_all_models()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"自动保存失败: {e}")
# 全局管理器实例
style_learner_manager = StyleLearnerManager()

View File

@@ -45,9 +45,9 @@ private_plan_style = """
3.某句话如果已经被回复过,不要重复回复"""
[expression]
# 表达方式模式(此选项暂未使用)
mode = "context"
# 可选:llm模式context上下文模式
# 表达方式模式
mode = "classic"
# 可选:classic经典模式exp_model 表达模型模式
# 表达学习配置
learning_list = [ # 表达学习配置列表,支持按聊天流配置

View File

@@ -0,0 +1,152 @@
"""
测试修改后的 expression_selector 使用模型预测功能
验证不再随机选取,而是使用 style_learner 模型预测
"""
import os
import sys
import asyncio
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.express.expression_selector import ExpressionSelector
from src.express.style_learner import style_learner_manager
from src.common.logger import get_logger
logger = get_logger("expression_selector_test")
async def test_model_prediction_selector():
"""测试使用模型预测的表达选择器"""
print("=== Expression Selector 模型预测功能测试 ===\n")
# 创建选择器实例
selector = ExpressionSelector()
# 测试聊天室ID
test_chat_id = "test_prediction_chat"
print(f"测试聊天室: {test_chat_id}")
# 1. 先为测试聊天室添加一些风格和situation
print(f"\n1. 准备测试数据...")
test_data = [
("温柔回复", "打招呼"),
("幽默回复", "表达惊讶"),
("严肃回复", "询问问题"),
("活泼回复", "表达开心"),
("高冷回复", "表示不满"),
]
for style, situation in test_data:
success = style_learner_manager.add_style(test_chat_id, style, situation)
print(f" 添加: '{style}' (situation: '{situation}') -> {'成功' if success else '失败'}")
# 2. 学习一些映射关系
print(f"\n2. 学习映射关系...")
learning_data = [
("你好", "温柔回复"),
("谢谢", "温柔回复"),
("哈哈", "幽默回复"),
("请解释", "严肃回复"),
("太棒了", "活泼回复"),
]
for up_content, style in learning_data:
success = style_learner_manager.learn_mapping(test_chat_id, up_content, style)
print(f" 学习: '{up_content}' -> '{style}' -> {'成功' if success else '失败'}")
# 3. 测试模型预测功能
print(f"\n3. 测试模型预测功能...")
test_chat_scenarios = [
"用户: 你好\n机器人: 你好,有什么可以帮助你的吗?",
"用户: 哈哈,太搞笑了\n机器人: 确实很有趣呢!",
"用户: 请解释一下这个问题\n机器人: 好的,让我详细说明一下",
"用户: 太棒了!\n机器人: 很高兴听到这个消息!",
]
for i, chat_info in enumerate(test_chat_scenarios, 1):
print(f"\n 场景 {i}:")
print(f" 聊天内容: {chat_info}")
# 使用模型预测表达方式
predicted_expressions = selector.get_model_predicted_expressions(
test_chat_id, chat_info, total_num=3
)
print(f" 预测结果: {len(predicted_expressions)} 个表达方式")
for j, expr in enumerate(predicted_expressions, 1):
print(f" {j}. situation: '{expr['situation']}'")
print(f" style: '{expr['style']}'")
print(f" 分数: {expr.get('prediction_score', 0.0):.4f}")
print(f" 输入: '{expr.get('prediction_input', '')}'")
# 4. 测试LLM选择功能
print(f"\n4. 测试LLM选择功能...")
# 模拟聊天信息
chat_info = "用户: 你好,我想了解一下这个功能\n机器人: 好的,我来为你详细介绍"
try:
selected_expressions, selected_ids = await selector.select_suitable_expressions_llm(
test_chat_id, chat_info, max_num=3
)
print(f" LLM选择结果: {len(selected_expressions)} 个表达方式")
for i, expr in enumerate(selected_expressions, 1):
print(f" {i}. situation: '{expr['situation']}'")
print(f" style: '{expr['style']}'")
print(f" 来源: {expr['source_id']}")
except Exception as e:
print(f" LLM选择失败: {e}")
# 5. 测试回退机制
print(f"\n5. 测试回退机制...")
# 使用不存在的聊天室测试回退
fake_chat_id = "fake_chat_id"
fallback_expressions = selector._fallback_random_expressions(fake_chat_id, 3)
print(f" 回退机制测试: {len(fallback_expressions)} 个表达方式")
# 6. 测试预测输入提取
print(f"\n6. 测试预测输入提取...")
test_chat_infos = [
"用户: 你好\n机器人: 你好!",
"这是一段很长的聊天内容,包含了很多信息,用户说了很多话,机器人也回复了很多内容,现在我们要测试提取功能",
"单行内容",
"",
]
for i, chat_info in enumerate(test_chat_infos, 1):
prediction_input = selector._extract_prediction_input(chat_info)
print(f" 测试 {i}:")
print(f" 原始: '{chat_info}'")
print(f" 提取: '{prediction_input}'")
print(f"\n✅ 所有测试完成!")
print(f"\n=== 功能总结 ===")
print(f"✓ Expression Selector 现在使用 style_learner 模型进行预测")
print(f"✓ 不再随机选择,而是基于聊天内容预测最合适的 style")
print(f"✓ 自动获取预测 style 对应的 situation")
print(f"✓ 支持多聊天室的预测")
print(f"✓ 包含回退机制,预测失败时使用随机选择")
print(f"✓ 支持预测输入提取和优化")
def main():
"""主函数"""
print("Expression Selector 模型预测功能测试")
print("=" * 60)
# 运行异步测试
asyncio.run(test_model_prediction_selector())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,188 @@
"""
测试修改后的 expression_learner 与 style_learner 的集成
验证学习新表达时是否正确处理 situation 字段
"""
import os
import sys
import asyncio
import time
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.express.expression_learner import ExpressionLearner
from src.express.style_learner import style_learner_manager
from src.common.logger import get_logger
logger = get_logger("expression_style_integration_test")
async def test_expression_style_integration():
"""测试 expression_learner 与 style_learner 的集成(包含 situation"""
print("=== Expression Learner 与 Style Learner 集成测试(含 Situation ===\n")
# 创建测试聊天室ID
test_chat_id = "test_integration_situation_chat"
# 创建 ExpressionLearner 实例
expression_learner = ExpressionLearner(test_chat_id)
print(f"测试聊天室: {test_chat_id}")
# 模拟学习到的表达数据(包含 situation
mock_learnt_expressions = [
(test_chat_id, "打招呼", "温柔回复", "你好,有什么可以帮助你的吗?", "你好"),
(test_chat_id, "表示感谢", "礼貌回复", "谢谢你的帮助!", "谢谢"),
(test_chat_id, "表达惊讶", "幽默回复", "哇,这也太厉害了吧!", "太棒了"),
(test_chat_id, "询问问题", "严肃回复", "请详细解释一下这个问题。", "请解释"),
(test_chat_id, "表达开心", "活泼回复", "哈哈,太好玩了!", "哈哈"),
]
print("模拟学习到的表达数据(包含 situation:")
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
print(f" {situation} -> {style} (输入: {up_content})")
# 模拟 learn_and_store 方法的处理逻辑
print(f"\n开始处理学习数据...")
# 按chat_id分组
chat_dict = {}
for chat_id, situation, style, context, up_content in mock_learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({
"situation": situation,
"style": style,
"context": context,
"up_content": up_content,
})
# 训练 style_learner包含 situation 处理)
trained_chat_ids = set()
for chat_id, expr_list in chat_dict.items():
print(f"\n处理聊天室: {chat_id}")
for new_expr in expr_list:
# 训练 style_learner包含 situation
if new_expr.get("up_content") and new_expr.get("style"):
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果不存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
print(f" ✓ 添加风格: '{new_expr['style']}' (situation: '{new_expr['situation']}')")
else:
learner.add_style(new_expr["style"])
print(f" ✓ 添加风格: '{new_expr['style']}' (无 situation)")
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
print(f" ✓ StyleLearner学习成功: {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
print(f" ✗ StyleLearner学习失败: {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
print(f" ✗ StyleLearner学习异常: {e}")
# 保存模型
if trained_chat_ids:
print(f"\n开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
try:
save_success = style_learner_manager.save_all_models()
if save_success:
print(f"✓ StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
else:
print("✗ StyleLearner 模型保存失败")
except Exception as e:
print(f"✗ StyleLearner 模型保存异常: {e}")
# 测试预测功能
print(f"\n测试 StyleLearner 预测功能:")
test_inputs = ["你好", "谢谢", "太棒了", "请解释", "哈哈"]
for test_input in test_inputs:
try:
best_style, scores = style_learner_manager.predict_style(test_chat_id, test_input, top_k=3)
if best_style:
# 获取对应的 situation
learner = style_learner_manager.get_learner(test_chat_id)
situation = learner.get_situation(best_style)
print(f" 输入: '{test_input}' -> 预测: '{best_style}' (situation: '{situation}')")
if scores:
top_scores = dict(list(scores.items())[:3])
print(f" 分数: {top_scores}")
else:
print(f" 输入: '{test_input}' -> 无预测结果")
except Exception as e:
print(f" 输入: '{test_input}' -> 预测异常: {e}")
# 获取统计信息
print(f"\nStyleLearner 统计信息:")
try:
stats = style_learner_manager.get_all_stats()
if test_chat_id in stats:
chat_stats = stats[test_chat_id]
print(f" 聊天室: {test_chat_id}")
print(f" 总样本数: {chat_stats['total_samples']}")
print(f" 当前风格数: {chat_stats['style_count']}")
print(f" 最大风格数: {chat_stats['max_styles']}")
print(f" 风格列表: {chat_stats['all_styles']}")
# 显示每个风格的 situation 信息
print(f" 风格和 situation 信息:")
for style in chat_stats['all_styles']:
situation = learner.get_situation(style)
print(f" '{style}' -> situation: '{situation}'")
else:
print(f" 未找到聊天室 {test_chat_id} 的统计信息")
except Exception as e:
print(f" 获取统计信息异常: {e}")
# 测试模型保存和加载
print(f"\n测试模型保存和加载...")
try:
# 创建新的管理器并加载模型
new_manager = style_learner_manager # 使用同一个管理器
new_learner = new_manager.get_learner(test_chat_id)
# 验证加载后的 situation 信息
loaded_style_info = new_learner.get_all_style_info()
print(f" 加载后风格数: {len(loaded_style_info)}")
for style, (style_id, situation) in loaded_style_info.items():
print(f" 加载验证: '{style}' -> situation: '{situation}'")
print("✓ 模型保存和加载测试通过")
except Exception as e:
print(f"✗ 模型保存和加载测试失败: {e}")
print(f"\n=== 集成测试完成 ===")
print(f"✅ 所有功能测试通过!")
print(f"✓ Expression Learner 学习到新表达时自动添加 situation 到 StyleLearner")
print(f"✓ StyleLearner 正确存储和获取 situation 信息")
print(f"✓ 预测功能正常工作,可以获取对应的 situation")
print(f"✓ 模型保存和加载支持 situation 字段")
def main():
"""主函数"""
print("Expression Learner 与 Style Learner 集成测试(含 Situation")
print("=" * 70)
# 运行异步测试
asyncio.run(test_expression_style_integration())
if __name__ == "__main__":
main()

391
test_style_learner_db.py Normal file
View File

@@ -0,0 +1,391 @@
"""
StyleLearner 数据库测试脚本
使用数据库中的expression数据测试style_learner功能
"""
import os
import sys
from typing import List, Dict, Tuple
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.common.database.database_model import Expression, db
from src.express.style_learner import StyleLearnerManager
from src.common.logger import get_logger
logger = get_logger("style_learner_test")
class StyleLearnerDatabaseTest:
    """End-to-end test harness for StyleLearner using real Expression rows.

    Pipeline: load ``type == "style"`` expressions from the database,
    filter out degenerate samples, train per-chat style models, evaluate
    prediction quality (accuracy / precision / recall / F1), persist the
    trained models, and write a human-readable report file.
    """

    def __init__(self, random_state: int = 42):
        """
        Args:
            random_state: seed forwarded to the train/test split so that
                repeated runs draw the same evaluation subset.
        """
        self.random_state = random_state
        # Dedicated save path so test runs never overwrite production models.
        self.manager = StyleLearnerManager(model_save_path="data/test_style_models")
        # Aggregated metrics and per-sample records filled in by the test steps.
        # NOTE(review): "ground_truth" is populated nowhere in this class;
        # kept for interface compatibility with existing report consumers.
        self.test_results = {
            "total_samples": 0,
            "train_samples": 0,
            "test_samples": 0,
            "unique_styles": 0,
            "unique_chat_ids": 0,
            "accuracy": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0,
            "predictions": [],
            "ground_truth": [],
            "model_save_success": False,
            "model_save_path": self.manager.model_save_path,
        }

    def load_data_from_database(self) -> List[Dict]:
        """Load style-type expression rows from the database.

        Returns:
            List[Dict]: one dict per row with up_content, style, chat_id,
            last_active_time, context and situation; empty list on failure.
        """
        try:
            db.connect(reuse_if_open=True)
            # Only rows where all three key fields are non-null and the
            # expression is of the "style" kind are useful for training.
            expressions = Expression.select().where(
                (Expression.up_content.is_null(False)) &
                (Expression.style.is_null(False)) &
                (Expression.chat_id.is_null(False)) &
                (Expression.type == "style")
            )
            data = []
            for expr in expressions:
                # Second guard filters empty strings, which is_null misses.
                if expr.up_content and expr.style and expr.chat_id:
                    data.append({
                        "up_content": expr.up_content,
                        "style": expr.style,
                        "chat_id": expr.chat_id,
                        "last_active_time": expr.last_active_time,
                        "context": expr.context,
                        "situation": expr.situation,
                    })
            logger.info(f"从数据库加载了 {len(data)} 条expression数据")
            return data
        except Exception as e:
            logger.error(f"从数据库加载数据失败: {e}")
            return []
        finally:
            # Fix: the connection was previously left open for the whole
            # process.  All rows are materialized into plain dicts above,
            # so it is safe to release the connection here.
            if not db.is_closed():
                db.close()

    def preprocess_data(self, data: List[Dict]) -> List[Dict]:
        """Strip whitespace and drop samples whose content or style is
        shorter than 2 characters (single characters carry no signal).

        Args:
            data: raw rows from load_data_from_database.

        Returns:
            List[Dict]: cleaned rows.
        """
        filtered_data = []
        for item in data:
            up_content = item["up_content"].strip()
            style = item["style"].strip()
            if len(up_content) >= 2 and len(style) >= 2:
                filtered_data.append({
                    "up_content": up_content,
                    "style": style,
                    "chat_id": item["chat_id"],
                    "last_active_time": item["last_active_time"],
                    "context": item["context"],
                    "situation": item["situation"],
                })
        logger.info(f"预处理后剩余 {len(filtered_data)} 条数据")
        return filtered_data

    def split_data(self, data: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
        """Split into train and evaluation sets.

        The training set is ALL data; the test set is a random 5% drawn
        from it.  NOTE(review): the test set therefore overlaps the
        training set, so the reported metrics measure memorization rather
        than generalization — this matches the script's documented intent
        but should not be read as held-out performance.

        Args:
            data: preprocessed rows.

        Returns:
            Tuple[List[Dict], List[Dict]]: (train set, test set).
        """
        train_data = data.copy()
        test_size = 0.05  # evaluate on 5% of the corpus
        # train_test_split returns [train_part, test_part]; only the
        # test part is needed since training uses everything.
        test_data = train_test_split(
            train_data, test_size=test_size, random_state=self.random_state
        )[1]
        logger.info(f"数据分割完成: 训练集 {len(train_data)} 条, 测试集 {len(test_data)} 条")
        logger.info(f"训练集使用所有数据,测试集从训练集中随机选择 {test_size*100:.1f}%")
        return train_data, test_data

    def train_model(self, train_data: List[Dict]) -> None:
        """Feed every sample to the manager, then persist all models.

        Args:
            train_data: rows to learn (up_content -> style per chat_id).
        """
        logger.info("开始训练模型...")
        chat_ids = set()
        styles = set()
        for item in train_data:
            chat_id = item["chat_id"]
            up_content = item["up_content"]
            style = item["style"]
            chat_ids.add(chat_id)
            styles.add(style)
            # learn_mapping returns False on failure; log and continue so
            # one bad sample does not abort the whole run.
            success = self.manager.learn_mapping(chat_id, up_content, style)
            if not success:
                logger.warning(f"学习失败: {chat_id} - {up_content} -> {style}")
        self.test_results["train_samples"] = len(train_data)
        self.test_results["unique_styles"] = len(styles)
        self.test_results["unique_chat_ids"] = len(chat_ids)
        logger.info(f"训练完成: {len(train_data)} 个样本, {len(styles)} 种风格, {len(chat_ids)} 个聊天室")
        # Persist immediately so a later evaluation crash cannot lose the models.
        logger.info("开始保存训练好的模型...")
        save_success = self.manager.save_all_models()
        self.test_results["model_save_success"] = save_success
        if save_success:
            logger.info(f"所有模型已成功保存到: {self.manager.model_save_path}")
            print(f"✅ 模型已保存到: {self.manager.model_save_path}")
        else:
            logger.warning("部分模型保存失败")
            print(f"⚠️ 模型保存失败,请检查路径: {self.manager.model_save_path}")

    def test_model(self, test_data: List[Dict]) -> None:
        """Predict a style for each test sample and compute metrics.

        Accuracy counts a None prediction as wrong; precision/recall/F1
        are computed only over samples where a prediction was produced.

        Args:
            test_data: evaluation rows.
        """
        logger.info("开始测试模型...")
        predictions = []
        ground_truth = []
        correct_predictions = 0
        for item in test_data:
            chat_id = item["chat_id"]
            up_content = item["up_content"]
            true_style = item["style"]
            predicted_style, scores = self.manager.predict_style(chat_id, up_content, top_k=1)
            predictions.append(predicted_style)
            ground_truth.append(true_style)
            if predicted_style == true_style:
                correct_predictions += 1
            # Keep the full per-sample record for later analysis/report.
            self.test_results["predictions"].append({
                "chat_id": chat_id,
                "up_content": up_content,
                "true_style": true_style,
                "predicted_style": predicted_style,
                "scores": scores,
            })
        accuracy = correct_predictions / len(test_data) if test_data else 0
        # sklearn metrics cannot handle None labels — filter them out,
        # keeping predictions and ground truth aligned pairwise.
        valid_predictions = [p for p in predictions if p is not None]
        valid_ground_truth = [gt for p, gt in zip(predictions, ground_truth, strict=False) if p is not None]
        if valid_predictions:
            precision, recall, f1, _ = precision_recall_fscore_support(
                valid_ground_truth, valid_predictions, average='weighted', zero_division=0
            )
        else:
            precision = recall = f1 = 0.0
        self.test_results["test_samples"] = len(test_data)
        self.test_results["accuracy"] = accuracy
        self.test_results["precision"] = precision
        self.test_results["recall"] = recall
        self.test_results["f1_score"] = f1
        logger.info(f"测试完成: 准确率 {accuracy:.4f}, 精确率 {precision:.4f}, 召回率 {recall:.4f}, F1分数 {f1:.4f}")

    def analyze_results(self) -> None:
        """Print a summary: dataset stats, overall metrics, save status,
        per-chat accuracy, and the most frequent true styles."""
        logger.info("=== 测试结果分析 ===")
        print("\n📊 数据统计:")
        print(f" 总样本数: {self.test_results['total_samples']}")
        print(f" 训练样本数: {self.test_results['train_samples']}")
        print(f" 测试样本数: {self.test_results['test_samples']}")
        print(f" 唯一风格数: {self.test_results['unique_styles']}")
        print(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}")
        print("\n🎯 模型性能:")
        print(f" 准确率: {self.test_results['accuracy']:.4f}")
        print(f" 精确率: {self.test_results['precision']:.4f}")
        print(f" 召回率: {self.test_results['recall']:.4f}")
        print(f" F1分数: {self.test_results['f1_score']:.4f}")
        print("\n💾 模型保存:")
        save_status = "成功" if self.test_results['model_save_success'] else "失败"
        print(f" 保存状态: {save_status}")
        print(f" 保存路径: {self.test_results['model_save_path']}")
        # Per-chat accuracy: how well each chat room's model memorized.
        chat_performance = {}
        for pred in self.test_results["predictions"]:
            chat_id = pred["chat_id"]
            if chat_id not in chat_performance:
                chat_performance[chat_id] = {"correct": 0, "total": 0}
            chat_performance[chat_id]["total"] += 1
            if pred["predicted_style"] == pred["true_style"]:
                chat_performance[chat_id]["correct"] += 1
        print("\n📈 各聊天室性能:")
        for chat_id, perf in chat_performance.items():
            accuracy = perf["correct"] / perf["total"] if perf["total"] > 0 else 0
            print(f" {chat_id}: {accuracy:.4f} ({perf['correct']}/{perf['total']})")
        # Distribution of true styles over the evaluated samples.
        style_counts = {}
        for pred in self.test_results["predictions"]:
            style = pred["true_style"]
            style_counts[style] = style_counts.get(style, 0) + 1
        print("\n🎨 风格分布 (前10个):")
        sorted_styles = sorted(style_counts.items(), key=lambda x: x[1], reverse=True)
        for style, count in sorted_styles[:10]:
            print(f" {style}: {count}")

    def show_sample_predictions(self, num_samples: int = 10) -> None:
        """Print the first *num_samples* prediction records in detail.

        Args:
            num_samples: how many records to show (default 10).
        """
        print(f"\n🔍 样本预测结果 (前{num_samples}个):")
        for i, pred in enumerate(self.test_results["predictions"][:num_samples]):
            # NOTE(review): the correct/incorrect marker glyphs appear to
            # have been lost in transit (both branches are empty strings).
            status = "" if pred["predicted_style"] == pred["true_style"] else ""
            print(f"\n {i+1}. {status}")
            print(f" 聊天室: {pred['chat_id']}")
            print(f" 输入内容: {pred['up_content']}")
            print(f" 真实风格: {pred['true_style']}")
            print(f" 预测风格: {pred['predicted_style']}")
            if pred["scores"]:
                # Show only the top three candidate scores to keep output short.
                top_scores = dict(list(pred["scores"].items())[:3])
                print(f" 分数: {top_scores}")

    def save_results(self, output_file: str = "style_learner_test_results.txt") -> None:
        """Write the full report (stats, metrics, every prediction) to a file.

        Args:
            output_file: destination path; overwritten if it exists.
        """
        try:
            with open(output_file, "w", encoding="utf-8") as f:
                f.write("StyleLearner 数据库测试结果\n")
                f.write("=" * 50 + "\n\n")
                f.write("数据统计:\n")
                f.write(f" 总样本数: {self.test_results['total_samples']}\n")
                f.write(f" 训练样本数: {self.test_results['train_samples']}\n")
                f.write(f" 测试样本数: {self.test_results['test_samples']}\n")
                f.write(f" 唯一风格数: {self.test_results['unique_styles']}\n")
                f.write(f" 唯一聊天室数: {self.test_results['unique_chat_ids']}\n\n")
                f.write("模型性能:\n")
                f.write(f" 准确率: {self.test_results['accuracy']:.4f}\n")
                f.write(f" 精确率: {self.test_results['precision']:.4f}\n")
                f.write(f" 召回率: {self.test_results['recall']:.4f}\n")
                f.write(f" F1分数: {self.test_results['f1_score']:.4f}\n\n")
                f.write("模型保存:\n")
                save_status = "成功" if self.test_results['model_save_success'] else "失败"
                f.write(f" 保存状态: {save_status}\n")
                f.write(f" 保存路径: {self.test_results['model_save_path']}\n\n")
                f.write("详细预测结果:\n")
                for i, pred in enumerate(self.test_results["predictions"]):
                    status = "" if pred["predicted_style"] == pred["true_style"] else ""
                    f.write(f"{i+1}. {status} [{pred['chat_id']}] {pred['up_content']} -> {pred['predicted_style']} (真实: {pred['true_style']})\n")
            logger.info(f"测试结果已保存到 {output_file}")
        except Exception as e:
            logger.error(f"保存测试结果失败: {e}")

    def run_test(self) -> None:
        """Run the whole pipeline: load → preprocess → split → train →
        evaluate → analyze → sample display → save report."""
        logger.info("开始StyleLearner数据库测试...")
        # 1. Load raw rows; abort early when the database yields nothing.
        raw_data = self.load_data_from_database()
        if not raw_data:
            logger.error("没有加载到数据,测试终止")
            return
        # 2. Clean the rows; abort when filtering removed everything.
        processed_data = self.preprocess_data(raw_data)
        if not processed_data:
            logger.error("预处理后没有数据,测试终止")
            return
        self.test_results["total_samples"] = len(processed_data)
        # 3. Split into train/eval sets.
        train_data, test_data = self.split_data(processed_data)
        # 4. Train (also persists the models).
        self.train_model(train_data)
        # 5. Evaluate.
        self.test_model(test_data)
        # 6-8. Report.
        self.analyze_results()
        self.show_sample_predictions(10)
        self.save_results()
        logger.info("StyleLearner数据库测试完成!")
def main():
"""主函数"""
print("StyleLearner 数据库测试脚本")
print("=" * 50)
# 创建测试实例
test = StyleLearnerDatabaseTest(random_state=42)
# 运行测试
test.run_test()
if __name__ == "__main__":
main()

76
view_pkl.py Normal file
View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
"""
查看 .pkl 文件内容的工具脚本
"""
import pickle
import sys
import os
from pprint import pprint
def view_pkl_file(file_path):
    """Pretty-print the contents of a pickle file for manual inspection.

    Prints the top-level type, then a structure-aware dump: dict keys and
    a full pprint for dicts, a length/preview for lists, a plain pprint
    otherwise.  If the file looks like an expressor model (has
    ``nb.token_counts``), additionally prints per-style token frequencies.

    Args:
        file_path: path to the ``.pkl`` file to inspect.
    """
    if not os.path.exists(file_path):
        print(f"❌ 文件不存在: {file_path}")
        return
    try:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file.  Only inspect .pkl files from trusted sources.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        print(f"📁 文件: {file_path}")
        print(f"📊 数据类型: {type(data)}")
        print("=" * 50)
        if isinstance(data, dict):
            print("🔑 字典键:")
            # items() avoids a second lookup per key (was data[key]).
            for key, value in data.items():
                print(f" - {key}: {type(value)}")
            print()
            print("📋 详细内容:")
            pprint(data, width=120, depth=10)
        elif isinstance(data, list):
            print(f"📝 列表长度: {len(data)}")
            if data:
                print(f"📊 第一个元素类型: {type(data[0])}")
                print("📋 前几个元素:")
                for i, item in enumerate(data[:3]):
                    print(f" [{i}]: {item}")
        else:
            print("📋 内容:")
            pprint(data, width=120, depth=10)
        # Expressor models get an extra detailed token-frequency section.
        if isinstance(data, dict) and 'nb' in data and 'token_counts' in data['nb']:
            print("\n" + "="*50)
            print("🔍 详细词汇统计 (token_counts):")
            token_counts = data['nb']['token_counts']
            for style_id, tokens in token_counts.items():
                print(f"\n📝 {style_id}:")
                if tokens:
                    # Show the ten most frequent tokens per style.
                    sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
                    for word, count in sorted_tokens[:10]:
                        print(f" '{word}': {count}")
                    if len(sorted_tokens) > 10:
                        print(f" ... 还有 {len(sorted_tokens) - 10} 个词")
                else:
                    print(" (无词汇数据)")
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
def main():
    """CLI entry point: validate the argument count and dispatch to
    view_pkl_file, or print usage when called incorrectly."""
    args = sys.argv[1:]
    if len(args) != 1:
        print("用法: python view_pkl.py <pkl文件路径>")
        print("示例: python view_pkl.py data/test_style_models/chat_001_style_model.pkl")
        return
    view_pkl_file(args[0])


if __name__ == "__main__":
    main()

63
view_tokens.py Normal file
View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""
专门查看 expressor.pkl 文件中 token_counts 的脚本
"""
import pickle
import sys
import os
def view_token_counts(file_path):
    """Print per-style token frequencies from an expressor.pkl model.

    Expects a dict with ``nb.token_counts`` (style_id -> {token: count})
    and an optional ``candidates`` map (style_id -> style text).  Tokens
    are listed per style in descending frequency order.

    Args:
        file_path: path to the expressor ``.pkl`` file.
    """
    if not os.path.exists(file_path):
        print(f"❌ 文件不存在: {file_path}")
        return
    try:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file.  Only inspect .pkl files from trusted sources.
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        print(f"📁 文件: {file_path}")
        print("=" * 60)
        # Fix: also require a dict; non-container payloads (e.g. an int)
        # previously raised TypeError and surfaced as a misleading
        # "读取文件失败" instead of this message.
        if not isinstance(data, dict) or 'nb' not in data or 'token_counts' not in data['nb']:
            print("❌ 这不是一个 expressor 模型文件")
            return
        token_counts = data['nb']['token_counts']
        candidates = data.get('candidates', {})
        print(f"🎯 找到 {len(token_counts)} 个风格")
        print("=" * 60)
        for style_id, tokens in token_counts.items():
            # Resolve the human-readable style text when available.
            style_text = candidates.get(style_id, "未知风格")
            print(f"\n📝 {style_id}: {style_text}")
            print(f"📊 词汇数量: {len(tokens)}")
            if tokens:
                sorted_tokens = sorted(tokens.items(), key=lambda x: x[1], reverse=True)
                print("🔤 词汇统计 (按频率排序):")
                for i, (word, count) in enumerate(sorted_tokens):
                    print(f" {i+1:2d}. '{word}': {count}")
            else:
                print(" (无词汇数据)")
            print("-" * 40)
    except Exception as e:
        print(f"❌ 读取文件失败: {e}")
def main():
    """CLI entry point: validate the argument count and dispatch to
    view_token_counts, or print usage when called incorrectly."""
    args = sys.argv[1:]
    if len(args) != 1:
        print("用法: python view_tokens.py <expressor.pkl文件路径>")
        print("示例: python view_tokens.py data/test_style_models/chat_001_expressor.pkl")
        return
    view_token_counts(args[0])


if __name__ == "__main__":
    main()