better:优化了表达模型的预测,和表达方式的学习逻辑

This commit is contained in:
SengokuCola
2025-10-11 23:44:52 +08:00
parent 4a074ec374
commit d073a215e3
8 changed files with 113 additions and 410 deletions

View File

@@ -1,11 +1,9 @@
import time
import random
import json
import os
import re
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
from typing import List, Optional, Tuple
import traceback
import difflib
from src.common.logger import get_logger
@@ -148,7 +146,7 @@ class ExpressionLearner:
return True
async def trigger_learning_for_chat(self) -> bool:
async def trigger_learning_for_chat(self):
"""
为指定聊天流触发学习
@@ -159,11 +157,10 @@ class ExpressionLearner:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return False
return
try:
logger.info(f"聊天流 {self.chat_name} 触发表达学习")
logger.info(f"聊天流 {self.chat_name} 学习表达方式")
# 学习语言风格
learnt_style = await self.learn_and_store(num=25)
@@ -172,15 +169,13 @@ class ExpressionLearner:
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
return True
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
return False
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
return False
return
@@ -188,127 +183,87 @@ class ExpressionLearner:
"""
学习并存储表达方式
"""
res = await self.learn_expression(num)
learnt_expressions = await self.learn_expression(num)
if res is None:
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
learnt_expressions = res
# 展示学到的表达方式
learnt_expressions_str = ""
for (
_chat_id,
situation,
style,
_context,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
current_time = time.time()
# 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
for (
chat_id,
situation,
style,
context,
up_content,
) in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append(
{
"situation": situation,
"style": style,
"context": context,
"up_content": up_content,
}
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == self.chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
)
current_time = time.time()
# 存储到数据库 Expression 表并训练 style_learner
trained_chat_ids = set() # 记录已训练的聊天室
for chat_id, expr_list in chat_dict.items():
for new_expr in expr_list:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == "style")
& (Expression.situation == new_expr["situation"])
& (Expression.style == new_expr["style"])
if query.exists():
# 表达方式完全相同,只更新时间戳
expr_obj = query.get()
expr_obj.last_active_time = current_time
expr_obj.save()
continue
else:
Expression.create(
situation=situation,
style=style,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time, # 手动设置创建日期
context=context,
up_content=up_content,
)
if query.exists():
expr_obj = query.get()
# 50%概率替换内容
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.up_content = new_expr["up_content"]
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=new_expr["situation"],
style=new_expr["style"],
last_active_time=current_time,
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
up_content=new_expr["up_content"],
)
# 训练 style_learnerup_content 和 style 必定存在)
try:
# 获取 learner 实例
learner = style_learner_manager.get_learner(chat_id)
# 先添加风格和对应的 situation如果存在
if new_expr.get("situation"):
learner.add_style(new_expr["style"], new_expr["situation"])
else:
learner.add_style(new_expr["style"])
# 学习映射关系
success = style_learner_manager.learn_mapping(
chat_id,
new_expr["up_content"],
new_expr["style"]
)
if success:
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
trained_chat_ids.add(chat_id)
else:
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
has_new_expressions = True
# 限制最大数量
# exprs = list(
# Expression.select()
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
# .order_by(Expression.last_active_time.asc())
# )
# if len(exprs) > MAX_EXPRESSION_COUNT:
# 删除最久未活跃的多余表达方式
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
# expr.delete_instance()
# 保存训练好的 style_learner 模型
if trained_chat_ids:
# 训练 style_learnerup_content 和 style 必定存在)
try:
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
save_success = style_learner_manager.save_all_models()
learner.add_style(style, situation)
# 学习映射关系
success = style_learner_manager.learn_mapping(
self.chat_id,
up_content,
style
)
if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型
if has_new_expressions:
try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path)
if save_success:
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else:
logger.warning("StyleLearner 模型保存失败")
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
@@ -415,15 +370,12 @@ class ExpressionLearner:
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
"""
type_str = "语言风格"
prompt = "learn_style_prompt"
current_time = time.time()
# 获取上次学习之后的消息
@@ -436,14 +388,14 @@ class ExpressionLearner:
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 转化成str
_chat_id: str = random_msg[0].chat_id
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
# 学习用
random_msg_str: str = await build_anonymous_messages(random_msg)
# 溯源用
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
prompt,
"learn_style_prompt",
chat_str=random_msg_str,
)
@@ -453,20 +405,18 @@ class ExpressionLearner:
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
print(f"matched_expressions: {matched_expressions}")
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
# 这里有待斟酌,需要进一步处理图片和表情包
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
pic_pattern = r"\[picid:[^\]]+\]"
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
@@ -479,7 +429,6 @@ class ExpressionLearner:
content = content.strip()
if content:
bare_lines.append((idx, content))
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
@@ -503,11 +452,7 @@ class ExpressionLearner:
if not filtered_with_up:
return None
results: List[Tuple[str, str, str, str]] = []
for (situation, style, context, up_content) in filtered_with_up:
results.append((self.chat_id, situation, style, context, up_content))
return results
return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: