feat:表达方式更新,现在会训练朴素贝叶斯模型来预测使用什么表达

This commit is contained in:
SengokuCola
2025-10-11 02:03:03 +08:00
parent 400296ade1
commit 958d6e04ee
20 changed files with 2372 additions and 443 deletions

View File

@@ -16,7 +16,7 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager

View File

@@ -1,316 +0,0 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("expression_selector")
def init_prompt():
expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], weights: List[float], k: int) -> List[Dict]:
"""按权重随机抽样"""
if not population or not weights or k <= 0:
return []
if len(population) <= k:
return population.copy()
# 使用累积权重的方法进行加权抽样
selected = []
population_copy = population.copy()
weights_copy = weights.copy()
for _ in range(k):
if not population_copy:
break
# 选择一个元素
chosen_idx = random.choices(range(len(population_copy)), weights=weights_copy)[0]
selected.append(population_copy.pop(chosen_idx))
weights_copy.pop(chosen_idx)
return selected
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组则返回所有可用的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
return list(all_chat_ids) if all_chat_ids else [chat_id]
# 否则使用现有的组逻辑
for group in groups:
group_chat_ids = []
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
return [chat_id]
def get_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
)
style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"count": expr.count,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
]
# 按权重抽样使用count作为权重
if style_exprs:
style_weights = [expr.get("count", 1) for expr in style_exprs]
selected_style = weighted_sample(style_exprs, style_weights, total_num)
else:
selected_style = []
return selected_style
def update_expressions_count_batch(self, expressions_to_update: List[Dict[str, Any]], increment: float = 0.1):
"""对一批表达方式更新count值按chat_id+type分组后一次性写入数据库"""
if not expressions_to_update:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
current_count = expr_obj.count
new_count = min(current_count + increment, 5.0)
expr_obj.count = new_count
expr_obj.last_active_time = time.time()
expr_obj.save()
logger.debug(
f"表达方式激活: 原count={current_count:.3f}, 增量={increment}, 新count={new_count:.3f} in db"
)
async def select_suitable_expressions_llm(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
# sourcery skip: inline-variable, list-comprehension
"""使用LLM选择适合的表达方式"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 1. 获取20个随机表达方式现在按权重抽取
style_exprs = self.get_random_expressions(chat_id, 20)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
# 2. 构建所有表达方式的索引和情境列表
all_expressions: List[Dict[str, Any]] = []
all_situations: List[str] = []
# 添加style表达方式
for expr in style_exprs:
expr = expr.copy()
all_expressions.append(expr)
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
if not all_expressions:
logger.warning("没有找到可用的表达方式")
return [], []
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要回复消息:{target_message}"
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""
target_message_extra_block = ""
# 3. 构建prompt只包含情境不包含完整的表达方式
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
bot_name=global_config.bot.nickname,
chat_observe_info=chat_info,
all_situations=all_situations_str,
max_num=max_num,
target_message=target_message_str,
target_message_extra_block=target_message_extra_block,
)
# 4. 调用LLM
try:
# start_time = time.time()
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"LLM请求时间: {model_name} {time.time() - start_time} \n{prompt}")
# logger.info(f"模型名称: {model_name}")
# logger.info(f"LLM返回结果: {content}")
# if reasoning_content:
# logger.info(f"LLM推理: {reasoning_content}")
# else:
# logger.info(f"LLM推理: 无")
if not content:
logger.warning("LLM返回空结果")
return [], []
# 5. 解析结果
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict) or "selected_situations" not in result:
logger.error("LLM返回格式错误")
logger.info(f"LLM返回结果: \n{content}")
return [], []
selected_indices = result["selected_situations"]
# 根据索引获取完整的表达方式
valid_expressions: List[Dict[str, Any]] = []
selected_ids = []
for idx in selected_indices:
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
expression = all_expressions[idx - 1] # 索引从1开始
selected_ids.append(expression["id"])
valid_expressions.append(expression)
# 对选中的所有表达方式一次性更新count数
if valid_expressions:
self.update_expressions_count_batch(valid_expressions, 0.006)
# logger.info(f"LLM从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
return valid_expressions, selected_ids
except Exception as e:
logger.error(f"LLM处理表达方式选择时出错: {e}")
return [], []
init_prompt()
try:
expression_selector = ExpressionSelector()
except Exception as e:
logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@@ -18,7 +18,7 @@ from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.chat.heart_flow.hfc_utils import send_typing, stop_typing
from src.chat.express.expression_learner import expression_learner_manager
from src.express.expression_learner import expression_learner_manager
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.memory_system.question_maker import QuestionMaker
from src.memory_system.questions import global_conflict_tracker
@@ -331,9 +331,8 @@ class HeartFChatting:
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
await self.expression_learner.trigger_learning_for_chat()
await global_memory_chest.build_running_content(chat_id=self.stream_id)
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(global_memory_chest.build_running_content(chat_id=self.stream_id))
cycle_timers, thinking_id = self.start_cycle()

View File

@@ -26,7 +26,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
@@ -238,8 +238,8 @@ class DefaultReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)

View File

@@ -24,7 +24,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.express.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.mood.mood_manager import mood_manager
@@ -256,8 +256,8 @@ class PrivateReplyer:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# LLM模式调用LLM选择5-10个然后随机选5个
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions_llm(
# 根据配置模式选择表达方式exp_model模式直接使用模型预测classic模式使用LLM选择
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target
)

View File

@@ -303,11 +303,10 @@ class Expression(BaseModel):
situation = TextField()
style = TextField()
count = FloatField()
# new mode fields
context = TextField(null=True)
context_words = TextField(null=True)
up_content = TextField(null=True)
last_active_time = FloatField()
chat_id = TextField(index=True)

View File

@@ -310,8 +310,8 @@ class MemoryConfig(ConfigBase):
class ExpressionConfig(ConfigBase):
"""表达配置类"""
mode: Literal["llm", "context", "full-context"] = "context"
"""表达方式模式,可选:llm模式context上下文模式full-context 完整上下文嵌入模式"""
mode: str = "classic"
"""表达方式模式,可选:classic经典模式exp_model 表达模型模式"""
learning_list: list[list] = field(default_factory=lambda: [])
"""

View File

@@ -2,10 +2,12 @@ import time
import random
import json
import os
import re
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
import traceback
import difflib
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
@@ -17,16 +19,23 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.style_learner import style_learner_manager
from json_repair import repair_json
MAX_EXPRESSION_COUNT = 300
DECAY_DAYS = 15 # 30天衰减到0.01
DECAY_MIN = 0.01 # 最小衰减值
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
@@ -173,63 +182,7 @@ class ExpressionLearner:
traceback.print_exc()
return False
def _apply_global_decay_to_database(self, current_time: float) -> None:
"""
对数据库中的所有表达方式应用全局衰减
"""
try:
# 获取所有表达方式
all_expressions = Expression.select()
updated_count = 0
deleted_count = 0
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01:
# 如果count太小删除这个表达方式
expr.delete_instance()
deleted_count += 1
else:
# 更新count
expr.count = new_count
expr.save()
updated_count += 1
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
def calculate_decay_factor(self, time_diff_days: float) -> float:
"""
计算衰减值
当时间差为0天时衰减值为0最近活跃的不衰减
当时间差为7天时衰减值为0.002中等衰减
当时间差为30天或更长时衰减值为0.01高衰减
使用二次函数进行曲线插值
"""
if time_diff_days <= 0:
return 0.0 # 刚激活的表达式不衰减
if time_diff_days >= DECAY_DAYS:
return 0.01 # 长时间未活跃的表达式大幅衰减
# 使用二次函数插值在0-30天之间从0衰减到0.01
# 使用简单的二次函数y = a * x^2
# 当x=30时y=0.01,所以 a = 0.01 / (30^2) = 0.01 / 900
a = 0.01 / (DECAY_DAYS**2)
decay = a * (time_diff_days**2)
return min(0.01, decay)
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
@@ -247,7 +200,7 @@ class ExpressionLearner:
situation,
style,
_context,
_context_words,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
@@ -260,7 +213,7 @@ class ExpressionLearner:
situation,
style,
context,
context_words,
up_content,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
@@ -269,13 +222,15 @@ class ExpressionLearner:
"situation": situation,
"style": style,
"context": context,
"context_words": context_words,
"up_content": up_content,
}
)
current_time = time.time()
# 存储到数据库 Expression 表
# 存储到数据库 Expression 表并训练 style_learner
trained_chat_ids = set() # 记录已训练的聊天室
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
@@ -292,32 +247,72 @@ class ExpressionLearner:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.context_words = new_expr["context_words"]
expr_obj.count = expr_obj.count + 1
expr_obj.up_content = new_expr["up_content"]
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=new_expr["situation"],
style=new_expr["style"],
count=1,
last_active_time=current_time,
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
context_words=new_expr["context_words"],
up_content=new_expr["up_content"],
)
# 训练 style_learnerup_content 和 style 必定存在)
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
else:
learner.add_style(new_expr["style"])
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
# 限制最大数量
exprs = list(
Expression.select()
.where((Expression.chat_id == chat_id) & (Expression.type == "style"))
.order_by(Expression.count.asc())
)
if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除count最小的多余表达方式
for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
expr.delete_instance()
# exprs = list(
# Expression.select()
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
# .order_by(Expression.last_active_time.asc())
# )
# if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除最久未活跃的多余表达方式
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
# expr.delete_instance()
# 保存训练好的 style_learner 模型
if trained_chat_ids:
try:
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
save_success = style_learner_manager.save_all_models()
if save_success:
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
else:
logger.warning("StyleLearner 模型保存失败")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions
async def match_expression_context(
@@ -339,8 +334,8 @@ class ExpressionLearner:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# print(f"match_expression_context_prompt: {prompt}")
# print(f"{response}")
# 解析JSON响应
match_responses = []
@@ -395,6 +390,8 @@ class ExpressionLearner:
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
print(f"match_responses: {match_responses}")
for match_response in match_responses:
try:
@@ -418,7 +415,7 @@ class ExpressionLearner:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str]]]]:
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
"""从指定聊天流学习表达方式
Args:
@@ -466,39 +463,52 @@ class ExpressionLearner:
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
print(f"matched_expressions: {matched_expressions}")
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(
matched_expressions
)
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
pic_pattern = r"\[picid:[^\]]+\]"
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
at_pattern = r"@<[^:<>]+:[^:<>]+>"
for idx, msg in enumerate(random_msg):
content = msg.processed_plain_text or ""
content = re.sub(pic_pattern, "[图片]", content)
content = re.sub(reply_pattern, "回复[某人]", content)
content = re.sub(at_pattern, "@[某人]", content)
content = content.strip()
if content:
bare_lines.append((idx, content))
split_matched_expressions_w_emb = []
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append(
(self.chat_id, situation, style, context, context_words)
)
return split_matched_expressions_w_emb
def split_expression_context(
self, matched_expressions: List[Tuple[str, str, str]]
) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词
Args:
matched_expressions: 匹配到的表达方式列表每个元素为(situation, style, context)
Returns:
添加了分词结果的表达方式列表每个元素为(situation, style, context, context_words)
"""
result = []
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
# 使用jieba进行分词
context_words = list(jieba.cut(context))
result.append((situation, style, context, context_words))
# 在 bare_lines 中找到第一处相似度达到85%的行
pos = None
for i, (_, c) in enumerate(bare_lines):
similarity = calculate_similarity(c, context)
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到或没有上一句,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = (random_msg[prev_original_idx].processed_plain_text or "").strip()
if not up_content:
continue
filtered_with_up.append((situation, style, context, up_content))
if not filtered_with_up:
return None
results: List[Tuple[str, str, str, str]] = []
for (situation, style, context, up_content) in filtered_with_up:
results.append((self.chat_id, situation, style, context, up_content))
return results
return result
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""

View File

@@ -0,0 +1,520 @@
import json
import time
import random
import hashlib
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager
logger = get_logger("expression_selector")
def init_prompt():
expression_evaluation_prompt = """
以下是正在进行的聊天内容:
{chat_observe_info}
你的名字是{bot_name}{target_message}
以下是可选的表达情境:
{all_situations}
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
考虑因素包括:
1. 聊天的情绪氛围(轻松、严肃、幽默等)
2. 话题类型(日常、技术、游戏、情感等)
3. 情境与当前语境的匹配度
{target_message_extra_block}
请以JSON格式输出只需要输出选中的情境编号
例如:
{{
"selected_situations": [2, 3, 5, 7, 19]
}}
请严格按照JSON格式输出不要包含其他内容
"""
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""随机抽样"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
# 使用随机抽样
selected = []
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# 随机选择一个元素
chosen_idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(chosen_idx))
return selected
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id与get_stream_id一致"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组则返回所有可用的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
return list(all_chat_ids) if all_chat_ids else [chat_id]
# 否则使用现有的组逻辑
for group in groups:
group_chat_ids = []
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
"""
使用 style_learner 模型预测最合适的表达方式
Args:
chat_id: 聊天室ID
chat_info: 聊天内容信息
total_num: 需要预测的数量
Returns:
List[Dict[str, Any]]: 预测的表达方式列表
"""
try:
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
# 从聊天信息中提取关键内容作为预测输入
# 这里可以进一步优化,提取更合适的预测输入
prediction_input = self._extract_prediction_input(chat_info)
predicted_expressions = []
# 为每个相关的chat_id进行预测
for related_chat_id in related_chat_ids:
try:
# 使用 style_learner 预测最合适的风格
best_style, scores = style_learner_manager.predict_style(
related_chat_id, prediction_input, top_k=total_num
)
if best_style and scores:
# 获取预测风格的完整信息
learner = style_learner_manager.get_learner(related_chat_id)
style_id, situation = learner.get_style_info(best_style)
if style_id and situation:
# 从数据库查找对应的表达记录
expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) &
(Expression.type == "style") &
(Expression.situation == situation) &
(Expression.style == best_style)
)
if expr_query.exists():
expr = expr_query.get()
predicted_expressions.append({
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": prediction_input
})
except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
continue
# 按预测分数排序,取前 total_num 个
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
selected_expressions = predicted_expressions[:total_num]
logger.info(f"为聊天室 {chat_id} 预测到 {len(selected_expressions)} 个表达方式")
return selected_expressions
except Exception as e:
logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择
return self._fallback_random_expressions(chat_id, total_num)
def _extract_prediction_input(self, chat_info: str) -> str:
"""
从聊天信息中提取用于预测的关键内容
Args:
chat_info: 聊天内容信息
Returns:
str: 提取的预测输入
"""
try:
# 简单的提取策略:取最后几句话作为预测输入
lines = chat_info.strip().split('\n')
if not lines:
return ""
# 取最后3行作为预测输入
recent_lines = lines[-3:]
prediction_input = ' '.join(recent_lines).strip()
# 如果内容太长截取前100个字符
if len(prediction_input) > 100:
prediction_input = prediction_input[:100]
return prediction_input
except Exception as e:
logger.warning(f"提取预测输入失败: {e}")
return ""
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
回退到随机选择表达方式
Args:
chat_id: 聊天室ID
total_num: 需要选择的数量
Returns:
List[Dict[str, Any]]: 随机选择的表达方式列表
"""
try:
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
)
style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
]
# 随机抽样
if style_exprs:
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")
return []
async def select_suitable_expressions(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
根据配置模式选择适合的表达方式
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 获取配置模式
expression_mode = global_config.expression.mode
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_model_only(chat_id, chat_info, max_num)
elif expression_mode == "classic":
# classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
else:
logger.warning(f"未知的表达模式: {expression_mode}回退到classic模式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message)
async def _select_expressions_model_only(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
exp_model模式直接使用模型预测不经过LLM
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 使用模型预测最合适的表达方式
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)
# if len(style_exprs) < 5:
# logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
# return [], []
# 直接取前max_num个预测结果
selected_expressions = style_exprs[:max_num]
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
return selected_expressions, selected_ids
except Exception as e:
logger.error(f"exp_model模式选择表达方式失败: {e}")
return [], []
async def _select_expressions_classic(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
classic模式随机选择+LLM选择
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 1. 使用模型预测最合适的表达方式
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
return [], []
# 2. 构建所有表达方式的索引和情境列表
all_expressions: List[Dict[str, Any]] = []
all_situations: List[str] = []
# 添加style表达方式
for expr in style_exprs:
expr = expr.copy()
all_expressions.append(expr)
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
if not all_expressions:
logger.warning("没有找到可用的表达方式")
return [], []
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要回复消息:{target_message}"
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""
target_message_extra_block = ""
# 3. 构建prompt只包含情境不包含完整的表达方式
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
bot_name=global_config.bot.nickname,
chat_observe_info=chat_info,
all_situations=all_situations_str,
max_num=max_num,
target_message=target_message_str,
target_message_extra_block=target_message_extra_block,
)
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
if not content:
logger.warning("LLM返回空结果")
return [], []
# 5. 解析结果
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict) or "selected_situations" not in result:
logger.error("LLM返回格式错误")
logger.info(f"LLM返回结果: \n{content}")
return [], []
selected_indices = result["selected_situations"]
# 根据索引获取完整的表达方式
valid_expressions: List[Dict[str, Any]] = []
selected_ids = []
for idx in selected_indices:
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
expression = all_expressions[idx - 1] # 索引从1开始
selected_ids.append(expression["id"])
valid_expressions.append(expression)
# 对选中的所有表达方式更新last_active_time
if valid_expressions:
self.update_expressions_last_active_time(valid_expressions)
logger.info(f"classic模式从{len(all_expressions)}个情境中选择了{len(valid_expressions)}")
return valid_expressions, selected_ids
except Exception as e:
logger.error(f"classic模式处理表达方式选择时出错: {e}")
return [], []
async def select_suitable_expressions_llm(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
使用LLM选择适合的表达方式保持向后兼容
注意:此方法已被 select_suitable_expressions 替代,建议使用新方法
"""
logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")
return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message)
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
"""对一批表达方式更新last_active_time"""
if not expressions_to_update:
return
updates_by_key = {}
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
expr_obj.last_active_time = time.time()
expr_obj.save()
logger.debug(
"表达方式激活: 更新last_active_time in db"
)
init_prompt()
try:
expression_selector = ExpressionSelector()
except Exception as e:
logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@@ -0,0 +1,131 @@
from typing import Dict, Optional, Tuple, List
from collections import Counter, defaultdict
import pickle
import os
from .tokenizer import Tokenizer
from .online_nb import OnlineNaiveBayes
class ExpressorModel:
"""
直接使用朴素贝叶斯精排(可在线学习)
支持存储situation字段不参与计算仅与style对应
"""
def __init__(self,
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
vocab_size: int = 200000,
use_jieba: bool = True):
self.tokenizer = Tokenizer(stopwords=set(), use_jieba=use_jieba)
self.nb = OnlineNaiveBayes(alpha=alpha, beta=beta, gamma=gamma, vocab_size=vocab_size)
self._candidates: Dict[str, str] = {} # cid -> text (style)
self._situations: Dict[str, str] = {} # cid -> situation (不参与计算)
def add_candidate(self, cid: str, text: str, situation: str = None):
"""添加候选文本和对应的situation"""
self._candidates[cid] = text
if situation is not None:
self._situations[cid] = situation
# 确保在nb模型中初始化该候选的计数
if cid not in self.nb.cls_counts:
self.nb.cls_counts[cid] = 0.0
if cid not in self.nb.token_counts:
self.nb.token_counts[cid] = defaultdict(float)
def add_candidates_bulk(self, items: List[Tuple[str, str]], situations: List[str] = None):
"""批量添加候选文本和对应的situations"""
for i, (cid, text) in enumerate(items):
situation = situations[i] if situations and i < len(situations) else None
self.add_candidate(cid, text, situation)
def predict(self, text: str, k: int = None) -> Tuple[Optional[str], Dict[str, float]]:
"""直接对所有候选进行朴素贝叶斯评分"""
toks = self.tokenizer.tokenize(text)
if not toks:
return None, {}
if not self._candidates:
return None, {}
# 对所有候选进行评分
tf = Counter(toks)
all_cids = list(self._candidates.keys())
scores = self.nb.score_batch(tf, all_cids)
# 取最高分
if not scores:
return None, {}
best = max(scores.items(), key=lambda x: x[1])[0]
return best, scores
def update_positive(self, text: str, cid: str):
"""更新正反馈学习"""
toks = self.tokenizer.tokenize(text)
if not toks:
return
tf = Counter(toks)
self.nb.update_positive(tf, cid)
def decay(self, factor: float):
self.nb.decay(factor=factor)
def get_situation(self, cid: str) -> Optional[str]:
"""获取候选对应的situation"""
return self._situations.get(cid)
def get_style(self, cid: str) -> Optional[str]:
"""获取候选对应的style"""
return self._candidates.get(cid)
def get_candidate_info(self, cid: str) -> Tuple[Optional[str], Optional[str]]:
"""获取候选的style和situation信息"""
return self._candidates.get(cid), self._situations.get(cid)
def get_all_candidates(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""获取所有候选的style和situation信息"""
return {cid: (style, self._situations.get(cid))
for cid, style in self._candidates.items()}
def save(self, path: str):
"""保存模型"""
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
pickle.dump({
"candidates": self._candidates,
"situations": self._situations,
"nb": {
"cls_counts": dict(self.nb.cls_counts),
"token_counts": {cid: dict(tc) for cid, tc in self.nb.token_counts.items()},
"alpha": self.nb.alpha,
"beta": self.nb.beta,
"gamma": self.nb.gamma,
"V": self.nb.V,
}
}, f)
def load(self, path: str):
"""加载模型"""
with open(path, "rb") as f:
obj = pickle.load(f)
# 还原候选文本
self._candidates = obj["candidates"]
# 还原situations兼容旧版本
self._situations = obj.get("situations", {})
# 还原朴素贝叶斯模型
self.nb.cls_counts = obj["nb"]["cls_counts"]
self.nb.token_counts = defaultdict_dict(obj["nb"]["token_counts"])
self.nb.alpha = obj["nb"]["alpha"]
self.nb.beta = obj["nb"]["beta"]
self.nb.gamma = obj["nb"]["gamma"]
self.nb.V = obj["nb"]["V"]
self.nb._logZ.clear()
def defaultdict_dict(d: Dict[str, Dict[str, float]]):
from collections import defaultdict
outer = defaultdict(lambda: defaultdict(float))
for k, inner in d.items():
outer[k].update(inner)
return outer

View File

@@ -0,0 +1,60 @@
import math
from typing import Dict, List
from collections import defaultdict, Counter
class OnlineNaiveBayes:
def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, vocab_size: int = 200000):
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.V = vocab_size
self.cls_counts: Dict[str, float] = defaultdict(float) # cid -> total token count
self.token_counts: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float)) # cid -> term -> count
self._logZ: Dict[str, float] = {} # cache log(∑counts + Vα)
def _invalidate(self, cid: str):
if cid in self._logZ:
del self._logZ[cid]
def _logZ_c(self, cid: str) -> float:
if cid not in self._logZ:
Z = self.cls_counts[cid] + self.V * self.alpha
self._logZ[cid] = math.log(max(Z, 1e-12))
return self._logZ[cid]
def score_batch(self, tf: Counter, cids: List[str]) -> Dict[str, float]:
total_cls = sum(self.cls_counts.values())
n_cls = max(1, len(self.cls_counts))
denom_prior = math.log(total_cls + self.beta * n_cls)
out: Dict[str, float] = {}
for cid in cids:
prior = math.log(self.cls_counts[cid] + self.beta) - denom_prior
s = prior
logZ = self._logZ_c(cid)
tc = self.token_counts[cid]
for term, qtf in tf.items():
num = tc.get(term, 0.0) + self.alpha
s += qtf * (math.log(num) - logZ)
out[cid] = s
return out
def update_positive(self, tf: Counter, cid: str):
inc = 0.0
tc = self.token_counts[cid]
for term, c in tf.items():
tc[term] += float(c)
inc += float(c)
self.cls_counts[cid] += inc
self._invalidate(cid)
def decay(self, factor: float = None):
g = self.gamma if factor is None else factor
if g >= 1.0:
return
for cid in list(self.cls_counts.keys()):
self.cls_counts[cid] *= g
for term in list(self.token_counts[cid].keys()):
self.token_counts[cid][term] *= g
self._invalidate(cid)

View File

@@ -0,0 +1,28 @@
import re
from typing import List, Optional, Set
try:
import jieba
_HAS_JIEBA = True
except Exception:
_HAS_JIEBA = False
_WORD_RE = re.compile(r"[A-Za-z0-9_]+")
def simple_en_tokenize(text: str) -> List[str]:
return _WORD_RE.findall(text.lower())
class Tokenizer:
def __init__(self, stopwords: Optional[Set[str]] = None, use_jieba: bool = True):
self.stopwords = stopwords or set()
self.use_jieba = use_jieba and _HAS_JIEBA
def tokenize(self, text: str) -> List[str]:
text = (text or "").strip()
if not text:
return []
if self.use_jieba:
toks = [t.strip().lower() for t in jieba.cut(text) if t.strip()]
else:
toks = simple_en_tokenize(text)
return [t for t in toks if t not in self.stopwords]

View File

@@ -0,0 +1,628 @@
"""
多聊天室表达风格学习系统
支持为每个chat_id维护独立的表达模型学习从up_content到style的映射
"""
import os
import pickle
import traceback
from typing import Dict, List, Optional, Tuple
from collections import defaultdict
import asyncio
from src.common.logger import get_logger
from .expressor_model.model import ExpressorModel
logger = get_logger("style_learner")
class StyleLearner:
"""
单个聊天室的表达风格学习器
学习从up_content到style的映射关系
支持动态管理风格集合最多2000个
"""
def __init__(self, chat_id: str, model_config: Optional[Dict] = None):
self.chat_id = chat_id
self.model_config = model_config or {
"alpha": 0.5,
"beta": 0.5,
"gamma": 0.99, # 衰减因子,支持遗忘
"vocab_size": 200000,
"use_jieba": True
}
# 初始化表达模型
self.expressor = ExpressorModel(**self.model_config)
# 动态风格管理
self.max_styles = 2000 # 每个chat_id最多2000个风格
self.style_to_id: Dict[str, str] = {} # style文本 -> style_id
self.id_to_style: Dict[str, str] = {} # style_id -> style文本
self.id_to_situation: Dict[str, str] = {} # style_id -> situation文本
self.next_style_id = 0 # 下一个可用的style_id
# 学习统计
self.learning_stats = {
"total_samples": 0,
"style_counts": defaultdict(int),
"last_update": None,
"style_usage_frequency": defaultdict(int) # 风格使用频率
}
def add_style(self, style: str, situation: str = None) -> bool:
"""
动态添加一个新的风格
Args:
style: 风格文本
situation: 对应的situation文本可选
Returns:
bool: 添加是否成功
"""
try:
# 检查是否已存在
if style in self.style_to_id:
logger.debug(f"[{self.chat_id}] 风格 '{style}' 已存在")
return True
# 检查是否超过最大限制
if len(self.style_to_id) >= self.max_styles:
logger.warning(f"[{self.chat_id}] 已达到最大风格数量限制 ({self.max_styles})")
return False
# 生成新的style_id
style_id = f"style_{self.next_style_id}"
self.next_style_id += 1
# 添加到映射
self.style_to_id[style] = style_id
self.id_to_style[style_id] = style
if situation:
self.id_to_situation[style_id] = situation
# 添加到expressor模型
self.expressor.add_candidate(style_id, style, situation)
logger.info(f"[{self.chat_id}] 已添加风格: '{style}' (ID: {style_id})" +
(f", situation: '{situation}'" if situation else ""))
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 添加风格失败: {e}")
return False
def remove_style(self, style: str) -> bool:
"""
删除一个风格
Args:
style: 要删除的风格文本
Returns:
bool: 删除是否成功
"""
try:
if style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 风格 '{style}' 不存在")
return False
style_id = self.style_to_id[style]
# 从映射中删除
del self.style_to_id[style]
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# 从expressor模型中删除通过重新构建
self._rebuild_expressor()
logger.info(f"[{self.chat_id}] 已删除风格: '{style}' (ID: {style_id})")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 删除风格失败: {e}")
return False
def update_style(self, old_style: str, new_style: str) -> bool:
"""
更新一个风格
Args:
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
try:
if old_style not in self.style_to_id:
logger.warning(f"[{self.chat_id}] 原风格 '{old_style}' 不存在")
return False
if new_style in self.style_to_id and new_style != old_style:
logger.warning(f"[{self.chat_id}] 新风格 '{new_style}' 已存在")
return False
style_id = self.style_to_id[old_style]
# 更新映射
del self.style_to_id[old_style]
self.style_to_id[new_style] = style_id
self.id_to_style[style_id] = new_style
# 更新expressor模型保留原有的situation
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, new_style, situation)
logger.info(f"[{self.chat_id}] 已更新风格: '{old_style}' -> '{new_style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 更新风格失败: {e}")
return False
def add_styles_batch(self, styles: List[str], situations: List[str] = None) -> int:
"""
批量添加风格
Args:
styles: 风格文本列表
situations: 对应的situation文本列表可选
Returns:
int: 成功添加的数量
"""
success_count = 0
for i, style in enumerate(styles):
situation = situations[i] if situations and i < len(situations) else None
if self.add_style(style, situation):
success_count += 1
logger.info(f"[{self.chat_id}] 批量添加风格: {success_count}/{len(styles)} 成功")
return success_count
def get_all_styles(self) -> List[str]:
"""获取所有已注册的风格"""
return list(self.style_to_id.keys())
def get_style_count(self) -> int:
"""获取当前风格数量"""
return len(self.style_to_id)
def get_situation(self, style: str) -> Optional[str]:
"""
获取风格对应的situation
Args:
style: 风格文本
Returns:
Optional[str]: 对应的situation如果不存在则返回None
"""
if style not in self.style_to_id:
return None
style_id = self.style_to_id[style]
return self.id_to_situation.get(style_id)
def get_style_info(self, style: str) -> Tuple[Optional[str], Optional[str]]:
"""
获取风格的完整信息
Args:
style: 风格文本
Returns:
Tuple[Optional[str], Optional[str]]: (style_id, situation)
"""
if style not in self.style_to_id:
return None, None
style_id = self.style_to_id[style]
situation = self.id_to_situation.get(style_id)
return style_id, situation
def get_all_style_info(self) -> Dict[str, Tuple[str, Optional[str]]]:
"""
获取所有风格的完整信息
Returns:
Dict[str, Tuple[str, Optional[str]]]: {style: (style_id, situation)}
"""
result = {}
for style, style_id in self.style_to_id.items():
situation = self.id_to_situation.get(style_id)
result[style] = (style_id, situation)
return result
def _rebuild_expressor(self):
"""重新构建expressor模型删除风格后使用"""
try:
# 重新创建expressor
self.expressor = ExpressorModel(**self.model_config)
# 重新添加所有风格和situation
for style_id, style_text in self.id_to_style.items():
situation = self.id_to_situation.get(style_id)
self.expressor.add_candidate(style_id, style_text, situation)
logger.debug(f"[{self.chat_id}] 已重新构建expressor模型")
except Exception as e:
logger.error(f"[{self.chat_id}] 重新构建expressor失败: {e}")
def learn_mapping(self, up_content: str, style: str) -> bool:
"""
学习一个up_content到style的映射
如果style不存在会自动添加
Args:
up_content: 输入内容
style: 对应的style文本
Returns:
bool: 学习是否成功
"""
try:
# 如果style不存在先添加它
if style not in self.style_to_id:
if not self.add_style(style):
logger.warning(f"[{self.chat_id}] 无法添加风格 '{style}',学习失败")
return False
# 获取style_id
style_id = self.style_to_id[style]
# 使用正反馈学习
self.expressor.update_positive(up_content, style_id)
# 更新统计
self.learning_stats["total_samples"] += 1
self.learning_stats["style_counts"][style_id] += 1
self.learning_stats["style_usage_frequency"][style] += 1
self.learning_stats["last_update"] = asyncio.get_event_loop().time()
logger.debug(f"[{self.chat_id}] 学习映射: '{up_content}' -> '{style}'")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 学习映射失败: {e}")
traceback.print_exc()
return False
def predict_style(self, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
根据up_content预测最合适的style
Args:
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style文本, 所有候选的分数]
"""
try:
best_style_id, scores = self.expressor.predict(up_content, k=top_k)
if best_style_id is None:
return None, {}
# 将style_id转换为style文本
best_style = self.id_to_style.get(best_style_id)
# 转换所有分数
style_scores = {}
for sid, score in scores.items():
style_text = self.id_to_style.get(sid)
if style_text:
style_scores[style_text] = score
return best_style, style_scores
except Exception as e:
logger.error(f"[{self.chat_id}] 预测style失败: {e}")
traceback.print_exc()
return None, {}
def decay_learning(self, factor: Optional[float] = None) -> None:
"""
对学习到的知识进行衰减(遗忘)
Args:
factor: 衰减因子None则使用配置中的gamma
"""
self.expressor.decay(factor)
logger.debug(f"[{self.chat_id}] 执行知识衰减")
def get_stats(self) -> Dict:
"""获取学习统计信息"""
return {
"chat_id": self.chat_id,
"total_samples": self.learning_stats["total_samples"],
"style_count": len(self.style_to_id),
"max_styles": self.max_styles,
"style_counts": dict(self.learning_stats["style_counts"]),
"style_usage_frequency": dict(self.learning_stats["style_usage_frequency"]),
"last_update": self.learning_stats["last_update"],
"all_styles": list(self.style_to_id.keys())
}
def save(self, base_path: str) -> bool:
"""
保存模型到文件
Args:
base_path: 基础路径,实际文件为 {base_path}/{chat_id}_style_model.pkl
"""
try:
os.makedirs(base_path, exist_ok=True)
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
# 保存模型和统计信息
save_data = {
"model_config": self.model_config,
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
"id_to_situation": self.id_to_situation,
"next_style_id": self.next_style_id,
"max_styles": self.max_styles,
"learning_stats": self.learning_stats
}
# 先保存expressor模型
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
self.expressor.save(expressor_path)
# 保存其他数据
with open(file_path, "wb") as f:
pickle.dump(save_data, f)
logger.info(f"[{self.chat_id}] 模型已保存到 {file_path}")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 保存模型失败: {e}")
return False
def load(self, base_path: str) -> bool:
"""
从文件加载模型
Args:
base_path: 基础路径
"""
try:
file_path = os.path.join(base_path, f"{self.chat_id}_style_model.pkl")
expressor_path = os.path.join(base_path, f"{self.chat_id}_expressor.pkl")
if not os.path.exists(file_path) or not os.path.exists(expressor_path):
logger.warning(f"[{self.chat_id}] 模型文件不存在,将使用默认配置")
return False
# 加载其他数据
with open(file_path, "rb") as f:
save_data = pickle.load(f)
# 恢复配置和状态
self.model_config = save_data["model_config"]
self.style_to_id = save_data["style_to_id"]
self.id_to_style = save_data["id_to_style"]
self.id_to_situation = save_data.get("id_to_situation", {}) # 兼容旧版本
self.next_style_id = save_data["next_style_id"]
self.max_styles = save_data.get("max_styles", 2000)
self.learning_stats = save_data["learning_stats"]
# 重新创建expressor并加载
self.expressor = ExpressorModel(**self.model_config)
self.expressor.load(expressor_path)
logger.info(f"[{self.chat_id}] 模型已从 {file_path} 加载")
return True
except Exception as e:
logger.error(f"[{self.chat_id}] 加载模型失败: {e}")
return False
class StyleLearnerManager:
"""
多聊天室表达风格学习管理器
为每个chat_id维护独立的StyleLearner实例
每个chat_id可以动态管理自己的风格集合最多2000个
"""
def __init__(self, model_save_path: str = "data/style_models"):
self.model_save_path = model_save_path
self.learners: Dict[str, StyleLearner] = {}
# 自动保存配置
self.auto_save_interval = 300 # 5分钟
self._auto_save_task: Optional[asyncio.Task] = None
logger.info("StyleLearnerManager 已初始化")
def get_learner(self, chat_id: str, model_config: Optional[Dict] = None) -> StyleLearner:
"""
获取或创建指定chat_id的学习器
Args:
chat_id: 聊天室ID
model_config: 模型配置None则使用默认配置
Returns:
StyleLearner实例
"""
if chat_id not in self.learners:
# 创建新的学习器
learner = StyleLearner(chat_id, model_config)
# 尝试加载已保存的模型
learner.load(self.model_save_path)
self.learners[chat_id] = learner
logger.info(f"为 chat_id={chat_id} 创建新的StyleLearner")
return self.learners[chat_id]
def add_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id添加风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 添加是否成功
"""
learner = self.get_learner(chat_id)
return learner.add_style(style)
def remove_style(self, chat_id: str, style: str) -> bool:
"""
为指定chat_id删除风格
Args:
chat_id: 聊天室ID
style: 风格文本
Returns:
bool: 删除是否成功
"""
learner = self.get_learner(chat_id)
return learner.remove_style(style)
def update_style(self, chat_id: str, old_style: str, new_style: str) -> bool:
"""
为指定chat_id更新风格
Args:
chat_id: 聊天室ID
old_style: 原风格文本
new_style: 新风格文本
Returns:
bool: 更新是否成功
"""
learner = self.get_learner(chat_id)
return learner.update_style(old_style, new_style)
def get_chat_styles(self, chat_id: str) -> List[str]:
"""
获取指定chat_id的所有风格
Args:
chat_id: 聊天室ID
Returns:
List[str]: 风格列表
"""
learner = self.get_learner(chat_id)
return learner.get_all_styles()
def learn_mapping(self, chat_id: str, up_content: str, style: str) -> bool:
"""
学习一个映射关系
Args:
chat_id: 聊天室ID
up_content: 输入内容
style: 对应的style
Returns:
bool: 学习是否成功
"""
learner = self.get_learner(chat_id)
return learner.learn_mapping(up_content, style)
def predict_style(self, chat_id: str, up_content: str, top_k: int = 5) -> Tuple[Optional[str], Dict[str, float]]:
"""
预测最合适的style
Args:
chat_id: 聊天室ID
up_content: 输入内容
top_k: 返回前k个候选
Returns:
Tuple[最佳style, 所有候选分数]
"""
learner = self.get_learner(chat_id)
return learner.predict_style(up_content, top_k)
def decay_all_learners(self, factor: Optional[float] = None) -> None:
"""
对所有学习器执行衰减
Args:
factor: 衰减因子
"""
for learner in self.learners.values():
learner.decay_learning(factor)
logger.info("已对所有学习器执行衰减")
def get_all_stats(self) -> Dict[str, Dict]:
"""获取所有学习器的统计信息"""
return {chat_id: learner.get_stats() for chat_id, learner in self.learners.items()}
def save_all_models(self) -> bool:
"""保存所有模型"""
success_count = 0
for learner in self.learners.values():
if learner.save(self.model_save_path):
success_count += 1
logger.info(f"已保存 {success_count}/{len(self.learners)} 个模型")
return success_count == len(self.learners)
def load_all_models(self) -> int:
"""加载所有已保存的模型"""
if not os.path.exists(self.model_save_path):
return 0
loaded_count = 0
for filename in os.listdir(self.model_save_path):
if filename.endswith("_style_model.pkl"):
chat_id = filename.replace("_style_model.pkl", "")
learner = StyleLearner(chat_id)
if learner.load(self.model_save_path):
self.learners[chat_id] = learner
loaded_count += 1
logger.info(f"已加载 {loaded_count} 个模型")
return loaded_count
async def start_auto_save(self) -> None:
"""启动自动保存任务"""
if self._auto_save_task is None or self._auto_save_task.done():
self._auto_save_task = asyncio.create_task(self._auto_save_loop())
logger.info("已启动自动保存任务")
async def stop_auto_save(self) -> None:
"""停止自动保存任务"""
if self._auto_save_task and not self._auto_save_task.done():
self._auto_save_task.cancel()
try:
await self._auto_save_task
except asyncio.CancelledError:
pass
logger.info("已停止自动保存任务")
async def _auto_save_loop(self) -> None:
"""自动保存循环"""
while True:
try:
await asyncio.sleep(self.auto_save_interval)
self.save_all_models()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"自动保存失败: {e}")
# 全局管理器实例
style_learner_manager = StyleLearnerManager()