Ruff fix
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
@@ -158,8 +157,6 @@ class ExpressionLearner:
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
|
||||
|
||||
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
@@ -169,7 +166,7 @@ class ExpressionLearner:
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
@@ -186,7 +183,7 @@ class ExpressionLearner:
|
||||
# 存储到数据库 Expression 表并训练 style_learner
|
||||
has_new_expressions = False # 记录是否有新的表达方式
|
||||
learner = style_learner_manager.get_learner(self.chat_id) # 获取 learner 实例
|
||||
|
||||
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
@@ -195,9 +192,7 @@ class ExpressionLearner:
|
||||
) in learnt_expressions:
|
||||
# 查找是否已存在相似表达方式
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == self.chat_id)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
# 表达方式完全相同,只更新时间戳
|
||||
@@ -216,39 +211,37 @@ class ExpressionLearner:
|
||||
up_content=up_content,
|
||||
)
|
||||
has_new_expressions = True
|
||||
|
||||
|
||||
# 训练 style_learner(up_content 和 style 必定存在)
|
||||
try:
|
||||
learner.add_style(style, situation)
|
||||
|
||||
|
||||
# 学习映射关系
|
||||
success = style_learner_manager.learn_mapping(
|
||||
self.chat_id,
|
||||
up_content,
|
||||
style
|
||||
)
|
||||
success = style_learner_manager.learn_mapping(self.chat_id, up_content, style)
|
||||
if success:
|
||||
logger.debug(f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}" + (f" (situation: {situation})" if situation else ""))
|
||||
logger.debug(
|
||||
f"StyleLearner学习成功: {self.chat_id} - {up_content} -> {style}"
|
||||
+ (f" (situation: {situation})" if situation else "")
|
||||
)
|
||||
else:
|
||||
logger.warning(f"StyleLearner学习失败: {self.chat_id} - {up_content} -> {style}")
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner学习异常: {self.chat_id} - {e}")
|
||||
|
||||
|
||||
|
||||
# 保存当前聊天室的 style_learner 模型
|
||||
if has_new_expressions:
|
||||
try:
|
||||
logger.info(f"开始保存聊天室 {self.chat_id} 的 StyleLearner 模型...")
|
||||
save_success = learner.save(style_learner_manager.model_save_path)
|
||||
|
||||
|
||||
if save_success:
|
||||
logger.info(f"StyleLearner 模型保存成功,聊天室: {self.chat_id}")
|
||||
else:
|
||||
logger.warning(f"StyleLearner 模型保存失败,聊天室: {self.chat_id}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"StyleLearner 模型保存异常: {e}")
|
||||
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
@@ -334,7 +327,7 @@ class ExpressionLearner:
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
|
||||
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
|
||||
logger.debug(f"match_responses 内容: {match_responses}")
|
||||
|
||||
@@ -344,12 +337,12 @@ class ExpressionLearner:
|
||||
if not isinstance(match_response, dict):
|
||||
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
|
||||
continue
|
||||
|
||||
|
||||
# 获取表达方式序号
|
||||
if "expression_pair" not in match_response:
|
||||
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
|
||||
continue
|
||||
|
||||
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
@@ -367,9 +360,7 @@ class ExpressionLearner:
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(
|
||||
self, num: int = 10
|
||||
) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
@@ -409,7 +400,6 @@ class ExpressionLearner:
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
@@ -426,17 +416,17 @@ class ExpressionLearner:
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到目标句或没有上一句,跳过该表达
|
||||
continue
|
||||
|
||||
|
||||
# 检查目标句是否为空
|
||||
target_content = bare_lines[pos][1]
|
||||
if not target_content:
|
||||
# 目标句为空,跳过该表达
|
||||
continue
|
||||
|
||||
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||
if not up_content:
|
||||
@@ -449,7 +439,6 @@ class ExpressionLearner:
|
||||
|
||||
return filtered_with_up
|
||||
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
@@ -483,21 +472,21 @@ class ExpressionLearner:
|
||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
|
||||
"""
|
||||
bare_lines: List[Tuple[int, str]] = []
|
||||
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = filter_message_content(content)
|
||||
# 即使content为空也要记录,防止错位
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
|
||||
return bare_lines
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user