This commit is contained in:
墨梓柒
2025-11-13 13:24:55 +08:00
parent e78a070fbd
commit 7839acd25d
52 changed files with 1322 additions and 1408 deletions

View File

@@ -1,7 +1,6 @@
import time
import json
import os
from datetime import datetime
from typing import List, Optional, Tuple
import traceback
from src.common.logger import get_logger
@@ -158,8 +157,6 @@ class ExpressionLearner:
traceback.print_exc()
return
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
@@ -169,7 +166,7 @@ class ExpressionLearner:
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
@@ -186,7 +183,7 @@ class ExpressionLearner:
# 存储到数据库 Expression 表并训练 style_learner
has_new_expressions = False # 记录是否有新的表达方式
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
for (
situation,
style,
@@ -195,9 +192,7 @@ class ExpressionLearner:
) in learnt_expressions:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == self.chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
)
if query.exists():
# 表达方式完全相同,只更新时间戳
@@ -216,39 +211,37 @@ class ExpressionLearner:
up_content=up_content,
)
has_new_expressions = True
# 训练 style_learnerup_content 和 style 必定存在)
try:
learner.add_style(style, situation)
# 学习映射关系
success = style_learner_manager.learn_mapping(
self.chat_id,
up_content,
style
)
success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
if success:
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
logger.debug(
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
+ (f" (situation: {situation})" if situation else "")
)
else:
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
except Exception as e:
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
# 保存当前聊天室的 style_learner 模型
if has_new_expressions:
try:
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
save_success = learner.save(style_learner_manager.model_save_path)
if save_success:
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
else:
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
except Exception as e:
logger.error(f"StyleLearner 模型保存异常: {e}")
return learnt_expressions
async def match_expression_context(
@@ -334,7 +327,7 @@ class ExpressionLearner:
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"match_responses 内容: {match_responses}")
@@ -344,12 +337,12 @@ class ExpressionLearner:
if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue
# 获取表达方式序号
if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
@@ -367,9 +360,7 @@ class ExpressionLearner:
return matched_expressions
async def learn_expression(
self, num: int = 10
) -> Optional[List[Tuple[str, str, str, str]]]:
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
@@ -409,7 +400,6 @@ class ExpressionLearner:
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
@@ -426,17 +416,17 @@ class ExpressionLearner:
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到目标句或没有上一句,跳过该表达
continue
# 检查目标句是否为空
target_content = bare_lines[pos][1]
if not target_content:
# 目标句为空,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content:
@@ -449,7 +439,6 @@ class ExpressionLearner:
return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
@@ -483,21 +472,21 @@ class ExpressionLearner:
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
"""
为每条消息构建精简文本列表,保留到原消息索引的映射
Args:
messages: 消息列表
Returns:
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
"""
bare_lines: List[Tuple[int, str]] = []
for idx, msg in enumerate(messages):
content = msg.processed_plain_text or ""
content = filter_message_content(content)
# 即使content为空也要记录防止错位
bare_lines.append((idx, content))
return bare_lines