feat:新增自动表达优化功能,优化表达方式的提取
This commit is contained in:
@@ -18,10 +18,13 @@ from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
calculate_style_similarity,
|
||||
calculate_similarity,
|
||||
parse_expression_response,
|
||||
)
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
from json_repair import repair_json
|
||||
from src.bw_learner.expression_auto_check_task import (
|
||||
single_expression_check,
|
||||
)
|
||||
|
||||
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
@@ -91,6 +94,7 @@ class ExpressionLearner:
|
||||
self.summary_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.summary"
|
||||
)
|
||||
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
@@ -136,11 +140,10 @@ class ExpressionLearner:
|
||||
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
jargon_entries: List[Tuple[str, str]] # (content, source_id)
|
||||
expressions, jargon_entries = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
expressions, jargon_entries = parse_expression_response(response)
|
||||
|
||||
# 检查表达方式数量,如果超过10个则放弃本次表达学习
|
||||
if len(expressions) > 10:
|
||||
if len(expressions) > 20:
|
||||
logger.info(f"表达方式提取数量超过10个(实际{len(expressions)}个),放弃本次表达学习")
|
||||
expressions = []
|
||||
|
||||
@@ -155,7 +158,7 @@ class ExpressionLearner:
|
||||
|
||||
# 如果没有表达方式,直接返回
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
logger.info("解析后没有可用的表达方式")
|
||||
return []
|
||||
|
||||
logger.info(f"学习的prompt: {prompt}")
|
||||
@@ -163,9 +166,60 @@ class ExpressionLearner:
|
||||
logger.info(f"学习的jargon_entries: {jargon_entries}")
|
||||
logger.info(f"学习的response: {response}")
|
||||
|
||||
# 直接根据 source_id 在 random_msg 中溯源,获取 context
|
||||
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
|
||||
learnt_expressions = self._filter_expressions(expressions, random_msg)
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (situation,style) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (situation,style) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
def _filter_expressions(
|
||||
self,
|
||||
expressions: List[Tuple[str, str, str]],
|
||||
messages: List[Any],
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
过滤表达方式,移除不符合条件的条目
|
||||
|
||||
Args:
|
||||
expressions: 表达方式列表,每个元素是 (situation, style, source_id)
|
||||
messages: 原始消息列表,用于溯源和验证
|
||||
|
||||
Returns:
|
||||
过滤后的表达方式列表,每个元素是 (situation, style, context)
|
||||
"""
|
||||
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
|
||||
|
||||
# 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
|
||||
banned_names = set()
|
||||
bot_nickname = (global_config.bot.nickname or "").strip()
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
alias = alias.strip()
|
||||
if alias:
|
||||
banned_names.add(alias)
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
for situation, style, source_id in expressions:
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
@@ -173,12 +227,12 @@ class ExpressionLearner:
|
||||
continue
|
||||
|
||||
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(random_msg):
|
||||
if line_index < 0 or line_index >= len(messages):
|
||||
# 超出范围,跳过
|
||||
continue
|
||||
|
||||
# 当前行的原始内容
|
||||
current_msg = random_msg[line_index]
|
||||
current_msg = messages[line_index]
|
||||
|
||||
# 过滤掉从bot自己发言中提取到的表达方式
|
||||
if is_bot_message(current_msg):
|
||||
@@ -195,251 +249,53 @@ class ExpressionLearner:
|
||||
)
|
||||
continue
|
||||
|
||||
filtered_expressions.append((situation, style, context))
|
||||
|
||||
learnt_expressions = filtered_expressions
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
{"content": "词条", "source_id": "12"}, // 黑话
|
||||
...
|
||||
]
|
||||
|
||||
Returns:
|
||||
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
第一个列表是表达方式 (situation, style, source_id)
|
||||
第二个列表是黑话 (content, source_id)
|
||||
"""
|
||||
if not response:
|
||||
return [], []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
# 尝试提取 ```json 代码块
|
||||
json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
if match:
|
||||
raw = match.group(1).strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as parse_error:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == "\\":
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
if char == '"': # 中文左引号 U+201C
|
||||
result.append('\\"')
|
||||
elif char == '"': # 中文右引号 U+201D
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
# 不在字符串内,直接添加
|
||||
result.append(char)
|
||||
|
||||
i += 1
|
||||
|
||||
return "".join(result)
|
||||
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
else:
|
||||
repaired = repair_json(fixed_raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as fix_error:
|
||||
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
return [], []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return [], []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
# 过滤掉 style 与机器人名称/昵称重复的表达
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() in banned_casefold:
|
||||
logger.debug(
|
||||
f"跳过 style 与机器人名称重复的表达方式: situation={situation}, style={style}, source_id={source_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# 检查是否是表达方式条目(有 situation 和 style)
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
# 过滤掉包含 "表情:" 或 "表情:" 的内容
|
||||
if "表情:" in (situation or "") or "表情:" in (situation or "") or \
|
||||
"表情:" in (style or "") or "表情:" in (style or "") or \
|
||||
"表情:" in context or "表情:" in context:
|
||||
logger.info(
|
||||
f"跳过包含表情标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if situation and style and source_id:
|
||||
# 表达方式条目
|
||||
expressions.append((situation, style, source_id))
|
||||
elif item.get("content"):
|
||||
# 黑话条目(有 content 字段)
|
||||
content = str(item.get("content", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if content and source_id:
|
||||
jargon_entries.append((content, source_id))
|
||||
# 过滤掉包含 "[图片" 的内容
|
||||
if "[图片" in (situation or "") or "[图片" in (style or "") or "[图片" in context:
|
||||
logger.info(
|
||||
f"跳过包含图片标记的表达方式: situation={situation}, style={style}, source_id={source_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
return expressions, jargon_entries
|
||||
filtered_expressions.append((situation, style))
|
||||
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
过滤掉style与机器人名称/昵称重复的表达
|
||||
"""
|
||||
banned_names = set()
|
||||
bot_nickname = (global_config.bot.nickname or "").strip()
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
alias = alias.strip()
|
||||
if alias:
|
||||
banned_names.add(alias)
|
||||
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
filtered: List[Tuple[str, str, str]] = []
|
||||
removed_count = 0
|
||||
for situation, style, source_id in expressions:
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() not in banned_casefold:
|
||||
filtered.append((situation, style, source_id))
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
|
||||
|
||||
return filtered
|
||||
return filtered_expressions
|
||||
|
||||
async def _upsert_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
# 第一层:检查是否有完全一致的 style(检查 style 字段和 style_list)
|
||||
expr_obj = await self._find_exact_style_match(style)
|
||||
# 检查是否有相似的 situation(相似度 >= 0.75,检查 content_list)
|
||||
# 完全匹配(相似度 == 1.0)和相似匹配(相似度 >= 0.75)统一处理
|
||||
expr_obj, similarity = await self._find_similar_situation_expression(situation, similarity_threshold=0.75)
|
||||
|
||||
if expr_obj:
|
||||
# 找到完全匹配的 style,合并到现有记录(不使用 LLM 总结)
|
||||
# 根据相似度决定是否使用 LLM 总结
|
||||
# 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结
|
||||
use_llm_summary = similarity < 1.0
|
||||
await self._update_existing_expression(
|
||||
expr_obj=expr_obj,
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
use_llm_summary=False,
|
||||
)
|
||||
return
|
||||
|
||||
# 第二层:检查是否有相似的 style(相似度 >= 0.75,检查 style 字段和 style_list)
|
||||
similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)
|
||||
|
||||
if similar_expr_obj:
|
||||
# 找到相似的 style,合并到现有记录(使用 LLM 总结)
|
||||
await self._update_existing_expression(
|
||||
expr_obj=similar_expr_obj,
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
use_llm_summary=True,
|
||||
use_llm_summary=use_llm_summary,
|
||||
)
|
||||
return
|
||||
|
||||
@@ -447,7 +303,6 @@ class ExpressionLearner:
|
||||
await self._create_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
@@ -455,7 +310,6 @@ class ExpressionLearner:
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = [situation]
|
||||
@@ -466,26 +320,22 @@ class ExpressionLearner:
|
||||
situation=formatted_situation,
|
||||
style=style,
|
||||
content_list=json.dumps(content_list, ensure_ascii=False),
|
||||
style_list=None, # 新记录初始时 style_list 为空
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def _update_existing_expression(
|
||||
self,
|
||||
expr_obj: Expression,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
use_llm_summary: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
更新现有 Expression 记录(style 完全匹配或相似的情况)
|
||||
将新的 situation 添加到 content_list,将新的 style 添加到 style_list(如果不同)
|
||||
更新现有 Expression 记录(situation 完全匹配或相似的情况)
|
||||
将新的 situation 添加到 content_list,不合并 style
|
||||
|
||||
Args:
|
||||
use_llm_summary: 是否使用 LLM 进行总结,完全匹配时为 False,相似匹配时为 True
|
||||
@@ -495,43 +345,24 @@ class ExpressionLearner:
|
||||
content_list.append(situation)
|
||||
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
|
||||
|
||||
# 更新 style_list(如果 style 不同,添加到 style_list)
|
||||
style_list = self._parse_style_list(expr_obj.style_list)
|
||||
# 将原有的 style 也加入 style_list(如果还没有的话)
|
||||
if expr_obj.style and expr_obj.style not in style_list:
|
||||
style_list.append(expr_obj.style)
|
||||
# 如果新的 style 不在 style_list 中,添加它
|
||||
if style not in style_list:
|
||||
style_list.append(style)
|
||||
expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)
|
||||
|
||||
# 更新其他字段
|
||||
expr_obj.count = (expr_obj.count or 0) + 1
|
||||
expr_obj.checked = False # count 增加时重置 checked 为 False
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.context = context
|
||||
|
||||
if use_llm_summary:
|
||||
# 相似匹配时,使用 LLM 重新组合 situation 和 style
|
||||
# 相似匹配时,使用 LLM 重新组合 situation
|
||||
new_situation = await self._compose_situation_text(
|
||||
content_list=content_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.situation,
|
||||
)
|
||||
expr_obj.situation = new_situation
|
||||
|
||||
new_style = await self._compose_style_text(
|
||||
style_list=style_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.style or style,
|
||||
)
|
||||
expr_obj.style = new_style
|
||||
else:
|
||||
# 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变
|
||||
# 只更新 content_list 和 style_list
|
||||
pass
|
||||
|
||||
expr_obj.save()
|
||||
|
||||
# count 增加后,立即进行一次检查
|
||||
await self._check_expression_immediately(expr_obj)
|
||||
|
||||
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
if not stored_list:
|
||||
return []
|
||||
@@ -541,49 +372,19 @@ class ExpressionLearner:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
"""解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同"""
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
|
||||
async def _find_similar_situation_expression(self, situation: str, similarity_threshold: float = 0.75) -> Tuple[Optional[Expression], float]:
|
||||
"""
|
||||
查找具有完全匹配 style 的 Expression 记录
|
||||
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
|
||||
查找具有相似 situation 的 Expression 记录
|
||||
检查 content_list 中的每一项
|
||||
|
||||
Args:
|
||||
style: 要查找的 style
|
||||
|
||||
Returns:
|
||||
找到的 Expression 对象,如果没有找到则返回 None
|
||||
"""
|
||||
# 查询同一 chat_id 的所有记录
|
||||
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
|
||||
|
||||
for expr in all_expressions:
|
||||
# 只检查 style_list 中的每一项
|
||||
style_list = self._parse_style_list(expr.style_list)
|
||||
if style in style_list:
|
||||
return expr
|
||||
|
||||
return None
|
||||
|
||||
async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
|
||||
"""
|
||||
查找具有相似 style 的 Expression 记录
|
||||
只检查 style_list 中的每一项(不检查 style 字段,因为 style 可能是总结后的概括性描述)
|
||||
|
||||
Args:
|
||||
style: 要查找的 style
|
||||
situation: 要查找的 situation
|
||||
similarity_threshold: 相似度阈值,默认 0.75
|
||||
|
||||
Returns:
|
||||
找到的最相似的 Expression 对象,如果没有找到则返回 None
|
||||
Tuple[Optional[Expression], float]:
|
||||
- 找到的最相似的 Expression 对象,如果没有找到则返回 None
|
||||
- 相似度值(如果找到匹配,范围在 similarity_threshold 到 1.0 之间)
|
||||
"""
|
||||
# 查询同一 chat_id 的所有记录
|
||||
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
|
||||
@@ -592,96 +393,28 @@ class ExpressionLearner:
|
||||
best_similarity = 0.0
|
||||
|
||||
for expr in all_expressions:
|
||||
# 只检查 style_list 中的每一项
|
||||
style_list = self._parse_style_list(expr.style_list)
|
||||
for existing_style in style_list:
|
||||
similarity = calculate_style_similarity(style, existing_style)
|
||||
# 检查 content_list 中的每一项
|
||||
content_list = self._parse_content_list(expr.content_list)
|
||||
for existing_situation in content_list:
|
||||
similarity = calculate_similarity(situation, existing_situation)
|
||||
if similarity >= similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expr
|
||||
|
||||
if best_match:
|
||||
logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
|
||||
logger.debug(f"找到相似的 situation: 相似度={best_similarity:.3f}, 现有='{best_match.situation}', 新='{situation}'")
|
||||
|
||||
return best_match
|
||||
return best_match, best_similarity
|
||||
|
||||
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
|
||||
async def _compose_situation_text(self, content_list: List[str], fallback: str = "") -> str:
|
||||
sanitized = [c.strip() for c in content_list if c.strip()]
|
||||
summary = await self._summarize_situations(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
return "/".join(sanitized) if sanitized else fallback
|
||||
|
||||
async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
|
||||
"""
|
||||
组合 style 文本,如果 style_list 有多个元素则尝试总结
|
||||
"""
|
||||
sanitized = [s.strip() for s in style_list if s.strip()]
|
||||
if len(sanitized) > 1:
|
||||
# 只有当有多个 style 时才尝试总结
|
||||
summary = await self._summarize_styles(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
# 如果只有一个或总结失败,返回第一个或 fallback
|
||||
return sanitized[0] if sanitized else fallback
|
||||
|
||||
async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
|
||||
"""总结多个 style,生成一个概括性的 style 描述"""
|
||||
if not styles or len(styles) <= 1:
|
||||
return None
|
||||
|
||||
# 计算输入列表中最长项目的长度
|
||||
max_input_length = max(len(s) for s in styles) if styles else 0
|
||||
max_summary_length = max_input_length * 2
|
||||
|
||||
# 最多重试3次
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < max_retries:
|
||||
# 如果是重试,在 prompt 中强调要更简洁
|
||||
length_hint = f"长度不超过{max_summary_length}个字符," if retry_count > 0 else "长度不超过20个字,"
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个语言风格/表达方式,对其进行总结。"
|
||||
"不要对其进行语义概括,而是尽可能找出其中不变的部分或共同表达,尽量使用原文"
|
||||
f"{length_hint}保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。不要输出其他内容"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
# 检查总结长度是否超过限制
|
||||
if len(summary) <= max_summary_length:
|
||||
return summary
|
||||
else:
|
||||
retry_count += 1
|
||||
logger.debug(
|
||||
f"总结长度 {len(summary)} 超过限制 {max_summary_length} "
|
||||
f"(输入最长项长度: {max_input_length}),重试第 {retry_count} 次"
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达风格失败: {e}")
|
||||
return None
|
||||
|
||||
# 如果重试多次后仍然超过长度,返回 None(不进行总结)
|
||||
logger.warning(
|
||||
f"总结多次后仍超过长度限制,放弃总结。"
|
||||
f"输入最长项长度: {max_input_length}, 最大允许长度: {max_summary_length}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
|
||||
if not situations:
|
||||
return None
|
||||
if not sanitized:
|
||||
return fallback
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
|
||||
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -691,7 +424,64 @@ class ExpressionLearner:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达情境失败: {e}")
|
||||
return None
|
||||
return "/".join(sanitized) if sanitized else fallback
|
||||
|
||||
async def _init_check_model(self) -> None:
|
||||
"""初始化检查用的 LLM 实例"""
|
||||
if self.check_model is None:
|
||||
try:
|
||||
self.check_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression.check"
|
||||
)
|
||||
logger.debug("检查用 LLM 实例初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"创建检查用 LLM 实例失败: {e}")
|
||||
|
||||
async def _check_expression_immediately(self, expr_obj: Expression) -> None:
|
||||
"""
|
||||
立即检查表达方式(在 count 增加后调用)
|
||||
|
||||
Args:
|
||||
expr_obj: 要检查的表达方式对象
|
||||
"""
|
||||
try:
|
||||
# 检查是否启用自动检查
|
||||
if not global_config.expression.expression_self_reflect:
|
||||
logger.debug("表达方式自动检查未启用,跳过立即检查")
|
||||
return
|
||||
|
||||
# 初始化检查用的 LLM
|
||||
await self._init_check_model()
|
||||
if self.check_model is None:
|
||||
logger.warning("检查用 LLM 实例初始化失败,跳过立即检查")
|
||||
return
|
||||
|
||||
# 执行 LLM 评估
|
||||
suitable, reason, error = await single_expression_check(
|
||||
expr_obj.situation,
|
||||
expr_obj.style
|
||||
)
|
||||
|
||||
# 更新数据库
|
||||
expr_obj.checked = True
|
||||
expr_obj.rejected = not suitable # 通过则 rejected=False,不通过则 rejected=True
|
||||
expr_obj.save()
|
||||
|
||||
status = "通过" if suitable else "不通过"
|
||||
logger.info(
|
||||
f"表达方式立即检查完成 [ID: {expr_obj.id}] - {status} | "
|
||||
f"Situation: {expr_obj.situation[:30]}... | "
|
||||
f"Style: {expr_obj.style[:30]}... | "
|
||||
f"Reason: {reason[:50] if reason else '无'}..."
|
||||
)
|
||||
|
||||
if error:
|
||||
logger.warning(f"表达方式立即检查时出现错误 [ID: {expr_obj.id}]: {error}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"立即检查表达方式失败 [ID: {expr_obj.id}]: {e}", exc_info=True)
|
||||
# 检查失败时,保持 checked=False,等待后续自动检查任务处理
|
||||
|
||||
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user