feat:新增自动表达优化功能,优化表达方式的提取

This commit is contained in:
SengokuCola
2025-12-27 17:20:11 +08:00
parent ba9b9d26a2
commit 99665e7918
14 changed files with 1177 additions and 837 deletions

View File

@@ -18,10 +18,13 @@ from src.bw_learner.learner_utils import (
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
calculate_style_similarity,
calculate_similarity,
parse_expression_response,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
from src.bw_learner.expression_auto_check_task import (
single_expression_check,
)
# MAX_EXPRESSION_COUNT = 300
@@ -91,6 +94,7 @@ class ExpressionLearner:
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="expression.summary"
)
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
@@ -136,11 +140,10 @@ class ExpressionLearner:
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
expressions: List[Tuple[str, str, str]]
jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
expressions, jargon_entries = parse_expression_response(response)
# 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10:
if len(expressions) > 20:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
expressions = []
@@ -155,7 +158,7 @@ class ExpressionLearner:
# 如果没有表达方式,直接返回
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
logger.info("解析后没有可用的表达方式")
return []
logger.info(f"学习的prompt: {prompt}")
@@ -163,9 +166,60 @@ class ExpressionLearner:
logger.info(f"学习的jargon_entries: {jargon_entries}")
logger.info(f"学习的response: {response}")
# 直接根据 source_id 在 random_msg 中溯源,获取 context
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
learnt_expressions = self._filter_expressions(expressions, random_msg)
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (situation,style) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (situation,style) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
current_time=current_time,
)
return learnt_expressions
def _filter_expressions(
self,
expressions: List[Tuple[str, str, str]],
messages: List[Any],
) -> List[Tuple[str, str, str]]:
"""
过滤表达方式,移除不符合条件的条目
Args:
expressions: 表达方式列表,每个元素是 (situation, style, source_id)
messages: 原始消息列表,用于溯源和验证
Returns:
过滤后的表达方式列表,每个元素是 (situation, style, context)
"""
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
# 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
for situation, style, source_id in expressions:
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
@@ -173,12 +227,12 @@ class ExpressionLearner:
continue
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
if line_index < 0 or line_index >= len(random_msg):
if line_index < 0 or line_index >= len(messages):
# 超出范围,跳过
continue
# 当前行的原始内容
current_msg = random_msg[line_index]
current_msg = messages[line_index]
# 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg):
@@ -195,251 +249,53 @@ class ExpressionLearner:
)
continue
filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
return learnt_expressions
def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
"""
解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表。
期望的 JSON 结构:
[
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
{"content": "词条", "source_id": "12"}, // 黑话
...
]
Returns:
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
第一个列表是表达方式 (situation, style, source_id)
第二个列表是黑话 (content, source_id)
"""
if not response:
return [], []
raw = response.strip()
# 尝试提取 ```json 代码块
json_block_pattern = r"```json\s*(.*?)\s*```"
match = re.search(json_block_pattern, raw, re.DOTALL)
if match:
raw = match.group(1).strip()
else:
# 去掉可能存在的通用 ``` 包裹
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
raw = raw.strip()
parsed = None
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
try:
# 优先尝试直接解析
if raw.startswith("[") and raw.endswith("]"):
parsed = json.loads(raw)
else:
repaired = repair_json(raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception as parse_error:
# 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try:
def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
i = 0
in_string = False
escape_next = False
while i < len(text):
char = text[i]
if escape_next:
# 当前字符是转义字符后的字符,直接添加
result.append(char)
escape_next = False
i += 1
continue
if char == "\\":
# 转义字符
result.append(char)
escape_next = True
i += 1
continue
if char == '"' and not escape_next:
# 遇到英文引号,切换字符串状态
in_string = not in_string
result.append(char)
i += 1
continue
if in_string:
# 在字符串值内部,将中文引号替换为转义的英文引号
if char == '"': # 中文左引号 U+201C
result.append('\\"')
elif char == '"': # 中文右引号 U+201D
result.append('\\"')
else:
result.append(char)
else:
# 不在字符串内,直接添加
result.append(char)
i += 1
return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw)
# 再次尝试解析
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
parsed = json.loads(fixed_raw)
else:
repaired = repair_json(fixed_raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception as fix_error:
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
return [], []
if isinstance(parsed, dict):
parsed_list = [parsed]
elif isinstance(parsed, list):
parsed_list = parsed
else:
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
return [], []
for item in parsed_list:
if not isinstance(item, dict):
# 过滤掉 style 与机器人名称/昵称重复的表达
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() in banned_casefold:
logger.debug(
f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
# 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip()
# 过滤掉包含 "表情:" 或 "表情:" 的内容
if "表情:" in (situation or "") or "表情:" in (situation or "") or \
"表情:" in (style or "") or "表情:" in (style or "") or \
"表情:" in context or "表情:" in context:
logger.info(
f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
if situation and style and source_id:
# 表达方式条目
expressions.append((situation, style, source_id))
elif item.get("content"):
# 黑话条目(有 content 字段)
content = str(item.get("content", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if content and source_id:
jargon_entries.append((content, source_id))
# 过滤掉包含 "[图片" 的内容
if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context:
logger.info(
f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
return expressions, jargon_entries
filtered_expressions.append((situation, style))
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str, str]] = []
removed_count = 0
for situation, style, source_id in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style, source_id))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
return filtered_expressions
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
current_time: float,
) -> None:
# 第一层:检查是否有完全一致的 style检查 style 字段和 style_list
expr_obj = await self._find_exact_style_match(style)
# 检查是否有相似的 situation相似度 >= 0.75,检查 content_list
# 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理
expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75)
if expr_obj:
# 找到完全匹配的 style合并到现有记录使用 LLM 总结
# 根据相似度决定是否使用 LLM 总结
# 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结
use_llm_summary = similarity < 1.0
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
style=style,
context=context,
current_time=current_time,
use_llm_summary=False,
)
return
# 第二层:检查是否有相似的 style相似度 >= 0.75,检查 style 字段和 style_list
similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)
if similar_expr_obj:
# 找到相似的 style合并到现有记录使用 LLM 总结)
await self._update_existing_expression(
expr_obj=similar_expr_obj,
situation=situation,
style=style,
context=context,
current_time=current_time,
use_llm_summary=True,
use_llm_summary=use_llm_summary,
)
return
@@ -447,7 +303,6 @@ class ExpressionLearner:
await self._create_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
@@ -455,7 +310,6 @@ class ExpressionLearner:
self,
situation: str,
style: str,
context: str,
current_time: float,
) -> None:
content_list = [situation]
@@ -466,26 +320,22 @@ class ExpressionLearner:
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
style_list=None, # 新记录初始时 style_list 为空
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
style: str,
context: str,
current_time: float,
use_llm_summary: bool = True,
) -> None:
"""
更新现有 Expression 记录style 完全匹配或相似的情况)
将新的 situation 添加到 content_list将新的 style 添加到 style_list如果不同
更新现有 Expression 记录situation 完全匹配或相似的情况)
将新的 situation 添加到 content_list不合并 style
Args:
use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False相似匹配时为 True
@@ -495,43 +345,24 @@ class ExpressionLearner:
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
# 更新 style_list如果 style 不同,添加到 style_list
style_list = self._parse_style_list(expr_obj.style_list)
# 将原有的 style 也加入 style_list如果还没有的话
if expr_obj.style and expr_obj.style not in style_list:
style_list.append(expr_obj.style)
# 如果新的 style 不在 style_list 中,添加它
if style not in style_list:
style_list.append(style)
expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)
# 更新其他字段
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.checked = False # count 增加时重置 checked 为 False
expr_obj.last_active_time = current_time
expr_obj.context = context
if use_llm_summary:
# 相似匹配时,使用 LLM 重新组合 situation 和 style
# 相似匹配时,使用 LLM 重新组合 situation
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
new_style = await self._compose_style_text(
style_list=style_list,
count=expr_obj.count,
fallback=expr_obj.style or style,
)
expr_obj.style = new_style
else:
# 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变
# 只更新 content_list 和 style_list
pass
expr_obj.save()
# count 增加后,立即进行一次检查
await self._check_expression_immediately(expr_obj)
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
@@ -541,49 +372,19 @@ class ExpressionLearner:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
"""解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
async def _find_similar_situation_expression(self, situation: str, similarity_threshold: float = 0.75) -> Tuple[Optional[Expression], float]:
"""
查找具有完全匹配 style 的 Expression 记录
检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
查找具有相似 situation 的 Expression 记录
检查 content_list 中的每一项
Args:
style: 要查找的 style
Returns:
找到的 Expression 对象,如果没有找到则返回 None
"""
# 查询同一 chat_id 的所有记录
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
for expr in all_expressions:
# 只检查 style_list 中的每一项
style_list = self._parse_style_list(expr.style_list)
if style in style_list:
return expr
return None
async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
"""
查找具有相似 style 的 Expression 记录
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
Args:
style: 要查找的 style
situation: 要查找的 situation
similarity_threshold: 相似度阈值,默认 0.75
Returns:
找到的最相似的 Expression 对象,如果没有找到则返回 None
Tuple[Optional[Expression], float]:
- 找到的最相似的 Expression 对象,如果没有找到则返回 None
- 相似度值(如果找到匹配,范围在 similarity_threshold 到 1.0 之间)
"""
# 查询同一 chat_id 的所有记录
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
@@ -592,96 +393,28 @@ class ExpressionLearner:
best_similarity = 0.0
for expr in all_expressions:
# 检查 style_list 中的每一项
style_list = self._parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(style, existing_style)
# 检查 content_list 中的每一项
content_list = self._parse_content_list(expr.content_list)
for existing_situation in content_list:
similarity = calculate_similarity(situation, existing_situation)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
logger.debug(f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'")
return best_match
return best_match, best_similarity
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
"""
组合 style 文本,如果 style_list 有多个元素则尝试总结
"""
sanitized = [s.strip() for s in style_list if s.strip()]
if len(sanitized) > 1:
# 只有当有多个 style 时才尝试总结
summary = await self._summarize_styles(sanitized)
if summary:
return summary
# 如果只有一个或总结失败,返回第一个或 fallback
return sanitized[0] if sanitized else fallback
async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
"""总结多个 style生成一个概括性的 style 描述"""
if not styles or len(styles) <= 1:
return None
# 计算输入列表中最长项目的长度
max_input_length = max(len(s) for s in styles) if styles else 0
max_summary_length = max_input_length * 2
# 最多重试3次
max_retries = 3
retry_count = 0
while retry_count < max_retries:
# 如果是重试,在 prompt 中强调要更简洁
length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字"
prompt = (
"请阅读以下多个语言风格/表达方式,对其进行总结。"
"不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文"
f"{length_hint}保留共同特点:\n"
f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
# 检查总结长度是否超过限制
if len(summary) <= max_summary_length:
return summary
else:
retry_count += 1
logger.debug(
f"总结长度 {len(summary)} 超过限制 {max_summary_length} "
f"(输入最长项长度: {max_input_length}),重试第 {retry_count}"
)
continue
except Exception as e:
logger.error(f"概括表达风格失败: {e}")
return None
# 如果重试多次后仍然超过长度,返回 None不进行总结
logger.warning(
f"总结多次后仍超过长度限制,放弃总结。"
f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}"
)
return None
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
if not sanitized:
return fallback
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
@@ -691,7 +424,64 @@ class ExpressionLearner:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
return "/".join(sanitized) if sanitized else fallback
async def _init_check_model(self) -> None:
"""初始化检查用的 LLM 实例"""
if self.check_model is None:
try:
self.check_model = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression.check"
)
logger.debug("检查用 LLM 实例初始化成功")
except Exception as e:
logger.error(f"创建检查用 LLM 实例失败: {e}")
async def _check_expression_immediately(self, expr_obj: Expression) -> None:
"""
立即检查表达方式(在 count 增加后调用)
Args:
expr_obj: 要检查的表达方式对象
"""
try:
# 检查是否启用自动检查
if not global_config.expression.expression_self_reflect:
logger.debug("表达方式自动检查未启用,跳过立即检查")
return
# 初始化检查用的 LLM
await self._init_check_model()
if self.check_model is None:
logger.warning("检查用 LLM 实例初始化失败,跳过立即检查")
return
# 执行 LLM 评估
suitable, reason, error = await single_expression_check(
expr_obj.situation,
expr_obj.style
)
# 更新数据库
expr_obj.checked = True
expr_obj.rejected = not suitable # 通过则 rejected=False不通过则 rejected=True
expr_obj.save()
status = "通过" if suitable else "不通过"
logger.info(
f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | "
f"Situation: {expr_obj.situation[:30]}... | "
f"Style: {expr_obj.style[:30]}... | "
f"Reason: {reason[:50] if reason else ''}..."
)
if error:
logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}")
except Exception as e:
logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True)
# 检查失败时,保持 checked=False等待后续自动检查任务处理
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
"""