better:优化表达方式学习和分割

This commit is contained in:
SengokuCola
2025-12-18 16:39:16 +08:00
parent dd891c4b18
commit f7a2f2329a
7 changed files with 1400 additions and 35 deletions

View File

@@ -0,0 +1,567 @@
"""
模拟 Expression 合并过程
用法:
python scripts/expression_merge_simulation.py
或指定 chat_id:
python scripts/expression_merge_simulation.py --chat-id <chat_id>
或指定相似度阈值:
python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""
import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id[:8]}...)"
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
else:
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def parse_content_list(stored_list: Optional[str]) -> List[str]:
"""解析 content_list JSON 字符串为列表"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def parse_style_list(stored_list: Optional[str]) -> List[str]:
"""解析 style_list JSON 字符串为列表"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def find_exact_style_match(
expressions: List[Expression],
target_style: str,
chat_id: str,
exclude_ids: set
) -> Optional[Expression]:
"""
查找具有完全匹配 style 的 Expression 记录
检查 style 字段和 style_list 中的每一项
"""
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
# 检查 style 字段
if expr.style == target_style:
return expr
# 检查 style_list 中的每一项
style_list = parse_style_list(expr.style_list)
if target_style in style_list:
return expr
return None
def find_similar_style_expression(
expressions: List[Expression],
target_style: str,
chat_id: str,
similarity_threshold: float,
exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
"""
查找具有相似 style 的 Expression 记录
检查 style 字段和 style_list 中的每一项
Returns:
(Expression, similarity) 或 None
"""
best_match = None
best_similarity = 0.0
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
# 检查 style 字段
similarity = calculate_style_similarity(target_style, expr.style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
# 检查 style_list 中的每一项
style_list = parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(target_style, existing_style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
return (best_match, best_similarity)
return None
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
"""组合 situation 文本,尝试使用 LLM 总结"""
sanitized = [c.strip() for c in content_list if c.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
# 尝试使用 LLM 总结
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 situation 失败: {e}")
# 如果总结失败,返回用 "/" 连接的字符串
return "/".join(sanitized)
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
"""组合 style 文本,尝试使用 LLM 总结"""
sanitized = [s.strip() for s in style_list if s.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
# 尝试使用 LLM 总结
prompt = (
"请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
print(f"Prompt:{prompt} Summary:{summary}")
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 style 失败: {e}")
# 如果总结失败,返回第一个
return sanitized[0]
async def simulate_merge(
expressions: List[Expression],
similarity_threshold: float = 0.75,
use_llm: bool = False,
max_samples: int = 10,
) -> Dict:
"""
模拟合并过程
Args:
expressions: Expression 列表(从数据库读出的原始记录)
similarity_threshold: style 相似度阈值
use_llm: 是否使用 LLM 进行实际总结
max_samples: 最多随机抽取的 Expression 数量(为 0 或 None 时表示不限制)
Returns:
包含合并统计信息的字典
"""
# 如果样本太多,随机抽取一部分进行模拟,避免运行时间过长
if max_samples and len(expressions) > max_samples:
expressions = random.sample(expressions, max_samples)
# 按 chat_id 分组
expressions_by_chat = defaultdict(list)
for expr in expressions:
expressions_by_chat[expr.chat_id].append(expr)
# 初始化 LLM 模型(如果需要)
summary_model = None
if use_llm:
try:
summary_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.summary"
)
print("✅ LLM 模型已初始化,将进行实际总结")
except Exception as e:
print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
use_llm = False
merge_stats = {
"total_expressions": len(expressions),
"total_chats": len(expressions_by_chat),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merge_details": [],
"chat_stats": {},
"use_llm": use_llm
}
# 为每个 chat_id 模拟合并
for chat_id, chat_expressions in expressions_by_chat.items():
chat_name = get_chat_name(chat_id)
chat_stat = {
"chat_id": chat_id,
"chat_name": chat_name,
"total": len(chat_expressions),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merges": []
}
processed_ids = set()
for expr in chat_expressions:
if expr.id in processed_ids:
continue
target_style = expr.style
target_situation = expr.situation
# 第一层:检查完全匹配
exact_match = find_exact_style_match(
chat_expressions,
target_style,
chat_id,
{expr.id}
)
if exact_match:
# 完全匹配(不使用 LLM 总结)
# 模拟合并后的 content_list 和 style_list
target_content_list = parse_content_list(exact_match.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(exact_match.style_list)
if exact_match.style and exact_match.style not in target_style_list:
target_style_list.append(exact_match.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
merge_info = {
"type": "exact",
"source_id": expr.id,
"target_id": exact_match.id,
"source_style": target_style,
"target_style": exact_match.style,
"source_situation": target_situation,
"target_situation": exact_match.situation,
"similarity": 1.0,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": exact_match.situation, # 完全匹配时保持原 situation
"merged_style": exact_match.style # 完全匹配时保持原 style
}
chat_stat["exact_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["exact_matches"] += 1
processed_ids.add(expr.id)
continue
# 第二层:检查相似匹配
similar_match = find_similar_style_expression(
chat_expressions,
target_style,
chat_id,
similarity_threshold,
{expr.id}
)
if similar_match:
match_expr, similarity = similar_match
# 相似匹配(使用 LLM 总结)
# 模拟合并后的 content_list 和 style_list
target_content_list = parse_content_list(match_expr.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(match_expr.style_list)
if match_expr.style and match_expr.style not in target_style_list:
target_style_list.append(match_expr.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
# 使用 LLM 总结(如果启用)
merged_situation = match_expr.situation
merged_style = match_expr.style or target_style
if use_llm and summary_model:
try:
merged_situation = await compose_situation_text(target_content_list, summary_model)
merged_style = await compose_style_text(target_style_list, summary_model)
except Exception as e:
print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
# 如果总结失败,使用 fallback
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
else:
# 不使用 LLM 时,使用简单拼接
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
merge_info = {
"type": "similar",
"source_id": expr.id,
"target_id": match_expr.id,
"source_style": target_style,
"target_style": match_expr.style,
"source_situation": target_situation,
"target_situation": match_expr.situation,
"similarity": similarity,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": merged_situation,
"merged_style": merged_style,
"llm_used": use_llm and summary_model is not None
}
chat_stat["similar_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["similar_matches"] += 1
processed_ids.add(expr.id)
continue
# 没有匹配,作为新记录
chat_stat["new_records"] += 1
merge_stats["new_records"] += 1
processed_ids.add(expr.id)
merge_stats["chat_stats"][chat_id] = chat_stat
merge_stats["merge_details"].extend(chat_stat["merges"])
return merge_stats
def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
"""打印合并结果"""
print("\n" + "=" * 80)
print("Expression 合并模拟结果")
print("=" * 80)
print("\n📊 总体统计:")
print(f" 总 Expression 数: {stats['total_expressions']}")
print(f" 总聊天数: {stats['total_chats']}")
print(f" 完全匹配合并: {stats['exact_matches']}")
print(f" 相似匹配合并: {stats['similar_matches']}")
print(f" 新记录(无匹配): {stats['new_records']}")
if stats.get('use_llm'):
print(" LLM 总结: 已启用")
else:
print(" LLM 总结: 未启用(仅模拟)")
total_merges = stats['exact_matches'] + stats['similar_matches']
if stats['total_expressions'] > 0:
merge_ratio = (total_merges / stats['total_expressions']) * 100
print(f" 合并比例: {merge_ratio:.1f}%")
# 按聊天分组显示
print("\n📋 按聊天分组统计:")
for chat_id, chat_stat in stats['chat_stats'].items():
print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
print(f" 总数: {chat_stat['total']}")
print(f" 完全匹配: {chat_stat['exact_matches']}")
print(f" 相似匹配: {chat_stat['similar_matches']}")
print(f" 新记录: {chat_stat['new_records']}")
# 显示合并详情
if show_details and stats['merge_details']:
print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
print()
for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
print(f" {idx}. {merge_type}")
print(f" 源记录 ID: {merge['source_id']}")
print(f" 目标记录 ID: {merge['target_id']}")
print(f" 源 Style: {merge['source_style'][:50]}")
print(f" 目标 Style: {merge['target_style'][:50]}")
print(f" 源 Situation: {merge['source_situation'][:50]}")
print(f" 目标 Situation: {merge['target_situation'][:50]}")
# 显示合并后的结果
if 'merged_situation' in merge:
print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
if 'merged_style' in merge:
print(f" → 合并后 Style: {merge['merged_style'][:50]}")
if merge.get('llm_used'):
print(" → LLM 总结: 已使用")
elif merge['type'] == 'similar':
print(" → LLM 总结: 未使用(模拟模式)")
# 显示合并后的列表
if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
if len(merge['merged_content_list']) > 3:
print(f" ... 还有 {len(merge['merged_content_list']) - 3}")
if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
if len(merge['merged_style_list']) > 3:
print(f" ... 还有 {len(merge['merged_style_list']) - 3}")
print()
if len(stats['merge_details']) > max_details:
print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
parser.add_argument(
"--chat-id",
type=str,
default=None,
help="指定要分析的 chat_id不指定则分析所有"
)
parser.add_argument(
"--similarity-threshold",
type=float,
default=0.75,
help="相似度阈值 (0-1, 默认: 0.75)"
)
parser.add_argument(
"--no-details",
action="store_true",
help="不显示详细信息,只显示统计"
)
parser.add_argument(
"--max-details",
type=int,
default=50,
help="最多显示的合并详情数 (默认: 50)"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
)
parser.add_argument(
"--use-llm",
action="store_true",
help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM"
)
parser.add_argument(
"--max-samples",
type=int,
default=10,
help="最多随机抽取的 Expression 数量 (默认: 10设置为 0 表示不限制)"
)
args = parser.parse_args()
# 验证阈值
if not 0 <= args.similarity_threshold <= 1:
print("错误: similarity-threshold 必须在 0-1 之间")
return
# 确定输出文件路径
if args.output:
output_file = args.output
else:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")
# 查询 Expression 记录
print("正在从数据库加载Expression数据...")
try:
if args.chat_id:
expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
else:
expressions = list(Expression.select())
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
except Exception as e:
print(f"❌ 加载数据失败: {e}")
return
if not expressions:
print("❌ 数据库中没有找到Expression记录")
return
# 执行合并模拟
print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples}...")
if args.use_llm:
print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
else:
print(" 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")
stats = asyncio.run(
simulate_merge(
expressions,
similarity_threshold=args.similarity_threshold,
use_llm=args.use_llm,
max_samples=args.max_samples,
)
)
# 输出结果
original_stdout = sys.stdout
try:
with open(output_file, "w", encoding="utf-8") as f:
sys.stdout = f
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
sys.stdout = original_stdout
# 同时在控制台输出
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
except Exception as e:
sys.stdout = original_stdout
print(f"❌ 写入文件失败: {e}")
return
print(f"\n✅ 模拟结果已保存到: {output_file}")
if __name__ == "__main__":
main()