better:优化了表达模型的预测,和表达方式的学习逻辑
This commit is contained in:
@@ -1,11 +1,9 @@
|
||||
import time
|
||||
import random
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
import jieba
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
import difflib
|
||||
from src.common.logger import get_logger
|
||||
@@ -148,7 +146,7 @@ class ExpressionLearner:
|
||||
|
||||
return True
|
||||
|
||||
async def trigger_learning_for_chat(self) -> bool:
|
||||
async def trigger_learning_for_chat(self):
|
||||
"""
|
||||
为指定聊天流触发学习
|
||||
|
||||
@@ -159,11 +157,10 @@ class ExpressionLearner:
|
||||
bool: 是否成功触发学习
|
||||
"""
|
||||
if not self.should_trigger_learning():
|
||||
return False
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"为聊天流 {self.chat_name} 触发表达学习")
|
||||
|
||||
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
|
||||
# 学习语言风格
|
||||
learnt_style = await self.learn_and_store(num=25)
|
||||
|
||||
@@ -172,15 +169,13 @@ class ExpressionLearner:
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -188,127 +183,87 @@ class ExpressionLearner:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
"""
|
||||
res = await self.learn_expression(num)
|
||||
learnt_expressions = await self.learn_expression(num)
|
||||
|
||||
if res is None:
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
learnt_expressions = res
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
_chat_id,
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
has_new_expressions = False # 记录是否有新的表达方式
|
||||
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
|
||||
|
||||
for (
|
||||
chat_id,
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append(
|
||||
{
|
||||
"situation": situation,
|
||||
"style": style,
|
||||
"context": context,
|
||||
"up_content": up_content,
|
||||
}
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == self.chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
trained_chat_ids = set() # 记录已训练的聊天室
|
||||
|
||||
for chat_id, expr_list in chat_dict.items():
|
||||
for new_expr in expr_list:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == "style")
|
||||
& (Expression.situation == new_expr["situation"])
|
||||
& (Expression.style == new_expr["style"])
|
||||
if query.exists():
|
||||
# 表达方式完全相同,只更新时间戳
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
continue
|
||||
else:
|
||||
Expression.create(
|
||||
situation=situation,
|
||||
style=style,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
# 50%概率替换内容
|
||||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.context = new_expr["context"]
|
||||
expr_obj.up_content = new_expr["up_content"]
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
else:
|
||||
Expression.create(
|
||||
situation=new_expr["situation"],
|
||||
style=new_expr["style"],
|
||||
last_active_time=current_time,
|
||||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=new_expr["context"],
|
||||
up_content=new_expr["up_content"],
|
||||
)
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
# 获取 learner 实例
|
||||
learner = style_learner_manager.get_learner(chat_id)
|
||||
|
||||
# 先添加风格和对应的 situation(如果存在)
|
||||
if new_expr.get("situation"):
|
||||
learner.add_style(new_expr["style"], new_expr["situation"])
|
||||
else:
|
||||
learner.add_style(new_expr["style"])
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
chat_id,
|
||||
new_expr["up_content"],
|
||||
new_expr["style"]
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}" +
|
||||
(f" (situation: {new_expr['situation']})" if new_expr.get("situation") else ""))
|
||||
trained_chat_ids.add(chat_id)
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {chat_id} - {new_expr['up_content']} -> {new_expr['style']}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {chat_id} - {e}")
|
||||
has_new_expressions = True
|
||||
|
||||
# 限制最大数量
|
||||
# exprs = list(
|
||||
# Expression.select()
|
||||
# .where((Expression.chat_id == chat_id) & (Expression.type == "style"))
|
||||
# .order_by(Expression.last_active_time.asc())
|
||||
# )
|
||||
# if len(exprs) > MAX_EXPRESSION_COUNT:
|
||||
# 删除最久未活跃的多余表达方式
|
||||
# for expr in exprs[: len(exprs) - MAX_EXPRESSION_COUNT]:
|
||||
# expr.delete_instance()
|
||||
|
||||
# 保存训练好的 style_learner 模型
|
||||
if trained_chat_ids:
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
logger.info(f"开始保存 {len(trained_chat_ids)} 个聊天室的 StyleLearner 模型...")
|
||||
save_success = style_learner_manager.save_all_models()
|
||||
learner.add_style(style, situation)
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
self.chat_id,
|
||||
up_content,
|
||||
style
|
||||
)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||
|
||||
|
||||
# 保存当前聊天室的 style_learner 模型
|
||||
if has_new_expressions:
|
||||
try:
|
||||
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
|
||||
save_success = learner.save(style_learner_manager.model_save_path)
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,涉及聊天室: {list(trained_chat_ids)}")
|
||||
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
|
||||
else:
|
||||
logger.warning("StyleLearner 模型保存失败")
|
||||
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
@@ -415,15 +370,12 @@ class ExpressionLearner:
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, List[str], str]]]:
|
||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
num: 学习数量
|
||||
"""
|
||||
type_str = "语言风格"
|
||||
prompt = "learn_style_prompt"
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 获取上次学习之后的消息
|
||||
@@ -436,14 +388,14 @@ class ExpressionLearner:
|
||||
# print(random_msg)
|
||||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
# 转化成str
|
||||
_chat_id: str = random_msg[0].chat_id
|
||||
# random_msg_str: str = build_readable_messages(random_msg, timestamp_mode="normal")
|
||||
|
||||
# 学习用
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# 溯源用
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
"learn_style_prompt",
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
@@ -453,20 +405,18 @@ class ExpressionLearner:
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习{type_str}失败: {e}")
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
|
||||
print(f"matched_expressions: {matched_expressions}")
|
||||
|
||||
# 为每条消息构建与 build_bare_messages 相同规则的精简文本列表,保留到原消息索引的映射
|
||||
# 这里有待斟酌,需要进一步处理图片和表情包
|
||||
bare_lines: List[Tuple[int, str]] = [] # (original_index, bare_content)
|
||||
pic_pattern = r"\[picid:[^\]]+\]"
|
||||
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
|
||||
@@ -479,7 +429,6 @@ class ExpressionLearner:
|
||||
content = content.strip()
|
||||
if content:
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
@@ -503,11 +452,7 @@ class ExpressionLearner:
|
||||
if not filtered_with_up:
|
||||
return None
|
||||
|
||||
results: List[Tuple[str, str, str, str]] = []
|
||||
for (situation, style, context, up_content) in filtered_with_up:
|
||||
results.append((self.chat_id, situation, style, context, up_content))
|
||||
|
||||
return results
|
||||
return filtered_with_up
|
||||
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
|
||||
Reference in New Issue
Block a user