Merge branch 'dev' into feat-lpmm知识库加强

This commit is contained in:
Dawn ARC
2025-12-18 18:58:10 +08:00
committed by GitHub
166 changed files with 22730 additions and 3447 deletions

View File

@@ -0,0 +1,567 @@
"""
模拟 Expression 合并过程
用法:
python scripts/expression_merge_simulation.py
或指定 chat_id:
python scripts/expression_merge_simulation.py --chat-id <chat_id>
或指定相似度阈值:
python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""
import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id[:8]}...)"
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
else:
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def parse_content_list(stored_list: Optional[str]) -> List[str]:
"""解析 content_list JSON 字符串为列表"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def parse_style_list(stored_list: Optional[str]) -> List[str]:
"""解析 style_list JSON 字符串为列表"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def find_exact_style_match(
expressions: List[Expression],
target_style: str,
chat_id: str,
exclude_ids: set
) -> Optional[Expression]:
"""
查找具有完全匹配 style 的 Expression 记录
检查 style 字段和 style_list 中的每一项
"""
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
# 检查 style 字段
if expr.style == target_style:
return expr
# 检查 style_list 中的每一项
style_list = parse_style_list(expr.style_list)
if target_style in style_list:
return expr
return None
def find_similar_style_expression(
expressions: List[Expression],
target_style: str,
chat_id: str,
similarity_threshold: float,
exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
"""
查找具有相似 style 的 Expression 记录
检查 style 字段和 style_list 中的每一项
Returns:
(Expression, similarity) 或 None
"""
best_match = None
best_similarity = 0.0
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
# 检查 style 字段
similarity = calculate_style_similarity(target_style, expr.style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
# 检查 style_list 中的每一项
style_list = parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(target_style, existing_style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
return (best_match, best_similarity)
return None
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
"""组合 situation 文本,尝试使用 LLM 总结"""
sanitized = [c.strip() for c in content_list if c.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
# 尝试使用 LLM 总结
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 situation 失败: {e}")
# 如果总结失败,返回用 "/" 连接的字符串
return "/".join(sanitized)
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
"""组合 style 文本,尝试使用 LLM 总结"""
sanitized = [s.strip() for s in style_list if s.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
# 尝试使用 LLM 总结
prompt = (
"请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
print(f"Prompt:{prompt} Summary:{summary}")
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 style 失败: {e}")
# 如果总结失败,返回第一个
return sanitized[0]
async def simulate_merge(
expressions: List[Expression],
similarity_threshold: float = 0.75,
use_llm: bool = False,
max_samples: int = 10,
) -> Dict:
"""
模拟合并过程
Args:
expressions: Expression 列表(从数据库读出的原始记录)
similarity_threshold: style 相似度阈值
use_llm: 是否使用 LLM 进行实际总结
max_samples: 最多随机抽取的 Expression 数量(为 0 或 None 时表示不限制)
Returns:
包含合并统计信息的字典
"""
# 如果样本太多,随机抽取一部分进行模拟,避免运行时间过长
if max_samples and len(expressions) > max_samples:
expressions = random.sample(expressions, max_samples)
# 按 chat_id 分组
expressions_by_chat = defaultdict(list)
for expr in expressions:
expressions_by_chat[expr.chat_id].append(expr)
# 初始化 LLM 模型(如果需要)
summary_model = None
if use_llm:
try:
summary_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.summary"
)
print("✅ LLM 模型已初始化,将进行实际总结")
except Exception as e:
print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
use_llm = False
merge_stats = {
"total_expressions": len(expressions),
"total_chats": len(expressions_by_chat),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merge_details": [],
"chat_stats": {},
"use_llm": use_llm
}
# 为每个 chat_id 模拟合并
for chat_id, chat_expressions in expressions_by_chat.items():
chat_name = get_chat_name(chat_id)
chat_stat = {
"chat_id": chat_id,
"chat_name": chat_name,
"total": len(chat_expressions),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merges": []
}
processed_ids = set()
for expr in chat_expressions:
if expr.id in processed_ids:
continue
target_style = expr.style
target_situation = expr.situation
# 第一层:检查完全匹配
exact_match = find_exact_style_match(
chat_expressions,
target_style,
chat_id,
{expr.id}
)
if exact_match:
# 完全匹配(不使用 LLM 总结)
# 模拟合并后的 content_list 和 style_list
target_content_list = parse_content_list(exact_match.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(exact_match.style_list)
if exact_match.style and exact_match.style not in target_style_list:
target_style_list.append(exact_match.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
merge_info = {
"type": "exact",
"source_id": expr.id,
"target_id": exact_match.id,
"source_style": target_style,
"target_style": exact_match.style,
"source_situation": target_situation,
"target_situation": exact_match.situation,
"similarity": 1.0,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": exact_match.situation, # 完全匹配时保持原 situation
"merged_style": exact_match.style # 完全匹配时保持原 style
}
chat_stat["exact_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["exact_matches"] += 1
processed_ids.add(expr.id)
continue
# 第二层:检查相似匹配
similar_match = find_similar_style_expression(
chat_expressions,
target_style,
chat_id,
similarity_threshold,
{expr.id}
)
if similar_match:
match_expr, similarity = similar_match
# 相似匹配(使用 LLM 总结)
# 模拟合并后的 content_list 和 style_list
target_content_list = parse_content_list(match_expr.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(match_expr.style_list)
if match_expr.style and match_expr.style not in target_style_list:
target_style_list.append(match_expr.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
# 使用 LLM 总结(如果启用)
merged_situation = match_expr.situation
merged_style = match_expr.style or target_style
if use_llm and summary_model:
try:
merged_situation = await compose_situation_text(target_content_list, summary_model)
merged_style = await compose_style_text(target_style_list, summary_model)
except Exception as e:
print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
# 如果总结失败,使用 fallback
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
else:
# 不使用 LLM 时,使用简单拼接
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
merge_info = {
"type": "similar",
"source_id": expr.id,
"target_id": match_expr.id,
"source_style": target_style,
"target_style": match_expr.style,
"source_situation": target_situation,
"target_situation": match_expr.situation,
"similarity": similarity,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": merged_situation,
"merged_style": merged_style,
"llm_used": use_llm and summary_model is not None
}
chat_stat["similar_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["similar_matches"] += 1
processed_ids.add(expr.id)
continue
# 没有匹配,作为新记录
chat_stat["new_records"] += 1
merge_stats["new_records"] += 1
processed_ids.add(expr.id)
merge_stats["chat_stats"][chat_id] = chat_stat
merge_stats["merge_details"].extend(chat_stat["merges"])
return merge_stats
def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
"""打印合并结果"""
print("\n" + "=" * 80)
print("Expression 合并模拟结果")
print("=" * 80)
print("\n📊 总体统计:")
print(f" 总 Expression 数: {stats['total_expressions']}")
print(f" 总聊天数: {stats['total_chats']}")
print(f" 完全匹配合并: {stats['exact_matches']}")
print(f" 相似匹配合并: {stats['similar_matches']}")
print(f" 新记录(无匹配): {stats['new_records']}")
if stats.get('use_llm'):
print(" LLM 总结: 已启用")
else:
print(" LLM 总结: 未启用(仅模拟)")
total_merges = stats['exact_matches'] + stats['similar_matches']
if stats['total_expressions'] > 0:
merge_ratio = (total_merges / stats['total_expressions']) * 100
print(f" 合并比例: {merge_ratio:.1f}%")
# 按聊天分组显示
print("\n📋 按聊天分组统计:")
for chat_id, chat_stat in stats['chat_stats'].items():
print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
print(f" 总数: {chat_stat['total']}")
print(f" 完全匹配: {chat_stat['exact_matches']}")
print(f" 相似匹配: {chat_stat['similar_matches']}")
print(f" 新记录: {chat_stat['new_records']}")
# 显示合并详情
if show_details and stats['merge_details']:
print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
print()
for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
print(f" {idx}. {merge_type}")
print(f" 源记录 ID: {merge['source_id']}")
print(f" 目标记录 ID: {merge['target_id']}")
print(f" 源 Style: {merge['source_style'][:50]}")
print(f" 目标 Style: {merge['target_style'][:50]}")
print(f" 源 Situation: {merge['source_situation'][:50]}")
print(f" 目标 Situation: {merge['target_situation'][:50]}")
# 显示合并后的结果
if 'merged_situation' in merge:
print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
if 'merged_style' in merge:
print(f" → 合并后 Style: {merge['merged_style'][:50]}")
if merge.get('llm_used'):
print(" → LLM 总结: 已使用")
elif merge['type'] == 'similar':
print(" → LLM 总结: 未使用(模拟模式)")
# 显示合并后的列表
if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
if len(merge['merged_content_list']) > 3:
print(f" ... 还有 {len(merge['merged_content_list']) - 3}")
if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
if len(merge['merged_style_list']) > 3:
print(f" ... 还有 {len(merge['merged_style_list']) - 3}")
print()
if len(stats['merge_details']) > max_details:
print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
parser.add_argument(
"--chat-id",
type=str,
default=None,
help="指定要分析的 chat_id不指定则分析所有"
)
parser.add_argument(
"--similarity-threshold",
type=float,
default=0.75,
help="相似度阈值 (0-1, 默认: 0.75)"
)
parser.add_argument(
"--no-details",
action="store_true",
help="不显示详细信息,只显示统计"
)
parser.add_argument(
"--max-details",
type=int,
default=50,
help="最多显示的合并详情数 (默认: 50)"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
)
parser.add_argument(
"--use-llm",
action="store_true",
help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM"
)
parser.add_argument(
"--max-samples",
type=int,
default=10,
help="最多随机抽取的 Expression 数量 (默认: 10设置为 0 表示不限制)"
)
args = parser.parse_args()
# 验证阈值
if not 0 <= args.similarity_threshold <= 1:
print("错误: similarity-threshold 必须在 0-1 之间")
return
# 确定输出文件路径
if args.output:
output_file = args.output
else:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")
# 查询 Expression 记录
print("正在从数据库加载Expression数据...")
try:
if args.chat_id:
expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
else:
expressions = list(Expression.select())
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
except Exception as e:
print(f"❌ 加载数据失败: {e}")
return
if not expressions:
print("❌ 数据库中没有找到Expression记录")
return
# 执行合并模拟
print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples}...")
if args.use_llm:
print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
else:
print(" 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")
stats = asyncio.run(
simulate_merge(
expressions,
similarity_threshold=args.similarity_threshold,
use_llm=args.use_llm,
max_samples=args.max_samples,
)
)
# 输出结果
original_stdout = sys.stdout
try:
with open(output_file, "w", encoding="utf-8") as f:
sys.stdout = f
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
sys.stdout = original_stdout
# 同时在控制台输出
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
except Exception as e:
sys.stdout = original_stdout
print(f"❌ 写入文件失败: {e}")
return
print(f"\n✅ 模拟结果已保存到: {output_file}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,564 @@
"""
分析expression库中situation和style的相似度
用法:
python scripts/expression_similarity_analysis.py
或指定阈值:
python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7
"""
import sys
import os
import argparse
from typing import List, Tuple
from collections import defaultdict
from difflib import SequenceMatcher
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.config.config import global_config # noqa: E402
import hashlib # noqa: E402
class TeeOutput:
"""同时输出到控制台和文件的类"""
def __init__(self, file_path: str):
self.file = open(file_path, "w", encoding="utf-8")
self.console = sys.stdout
def write(self, text: str):
"""写入文本到控制台和文件"""
self.console.write(text)
self.file.write(text)
self.file.flush() # 立即刷新到文件
def flush(self):
"""刷新输出"""
self.console.flush()
self.file.flush()
def close(self):
"""关闭文件"""
if self.file:
self.file.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""
解析'platform:id:type'为chat_id与ExpressionSelector中的逻辑一致
"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
def build_chat_id_groups() -> dict[str, set[str]]:
"""
根据expression_groups配置构建chat_id到相关chat_id集合的映射
Returns:
dict: {chat_id: set of related chat_ids (including itself)}
"""
groups = global_config.expression.expression_groups
chat_id_groups: dict[str, set[str]] = {}
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组收集所有配置中的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if stream_config_str == "*":
continue
if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
# 所有chat_id都互相相关
for chat_id in all_chat_ids:
chat_id_groups[chat_id] = all_chat_ids.copy()
else:
# 处理普通组
for group in groups:
group_chat_ids = set()
for stream_config_str in group:
if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.add(chat_id_candidate)
# 组内的所有chat_id都互相相关
for chat_id in group_chat_ids:
if chat_id not in chat_id_groups:
chat_id_groups[chat_id] = set()
chat_id_groups[chat_id].update(group_chat_ids)
# 确保每个chat_id至少包含自身
for chat_id in chat_id_groups:
chat_id_groups[chat_id].add(chat_id)
return chat_id_groups
def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool:
"""
判断两个chat_id是否相关相同或同组
Args:
chat_id1: 第一个chat_id
chat_id2: 第二个chat_id
chat_id_groups: chat_id到相关chat_id集合的映射
Returns:
bool: 如果两个chat_id相同或同组返回True
"""
if chat_id1 == chat_id2:
return True
# 如果chat_id1在映射中检查chat_id2是否在其相关集合中
if chat_id1 in chat_id_groups:
return chat_id2 in chat_id_groups[chat_id1]
# 如果chat_id1不在映射中说明它不在任何组中只与自己相关
return False
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id[:8]}...)"
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
else:
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def text_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度
使用SequenceMatcher计算相似度返回0-1之间的值
在计算前会移除"使用""句式"这两个词
"""
if not text1 or not text2:
return 0.0
# 移除"使用"和"句式"这两个词
def remove_ignored_words(text: str) -> str:
"""移除需要忽略的词"""
text = text.replace("使用", "")
text = text.replace("句式", "")
return text.strip()
cleaned_text1 = remove_ignored_words(text1)
cleaned_text2 = remove_ignored_words(text2)
# 如果清理后文本为空返回0
if not cleaned_text1 or not cleaned_text2:
return 0.0
return SequenceMatcher(None, cleaned_text1, cleaned_text2).ratio()
def find_similar_pairs(
expressions: List[Expression],
field_name: str,
threshold: float,
max_pairs: int = None
) -> List[Tuple[int, int, float, str, str]]:
"""
找出相似的expression对
Args:
expressions: Expression对象列表
field_name: 要比较的字段名 ('situation''style')
threshold: 相似度阈值 (0-1)
max_pairs: 最多返回的对数None表示返回所有
Returns:
List of (index1, index2, similarity, text1, text2) tuples
"""
similar_pairs = []
n = len(expressions)
print(f"正在分析 {field_name} 字段的相似度...")
print(f"总共需要比较 {n * (n - 1) // 2} 对...")
for i in range(n):
if (i + 1) % 100 == 0:
print(f" 已处理 {i + 1}/{n} 个项目...")
expr1 = expressions[i]
text1 = getattr(expr1, field_name, "")
for j in range(i + 1, n):
expr2 = expressions[j]
text2 = getattr(expr2, field_name, "")
similarity = text_similarity(text1, text2)
if similarity >= threshold:
similar_pairs.append((i, j, similarity, text1, text2))
# 按相似度降序排序
similar_pairs.sort(key=lambda x: x[2], reverse=True)
if max_pairs:
similar_pairs = similar_pairs[:max_pairs]
return similar_pairs
def group_similar_items(
expressions: List[Expression],
field_name: str,
threshold: float,
chat_id_groups: dict[str, set[str]]
) -> List[List[int]]:
"""
将相似的expression分组仅比较相同chat_id或同组的项目
Args:
expressions: Expression对象列表
field_name: 要比较的字段名 ('situation''style')
threshold: 相似度阈值 (0-1)
chat_id_groups: chat_id到相关chat_id集合的映射
Returns:
List of groups, each group is a list of indices
"""
n = len(expressions)
# 使用并查集的思想来分组
parent = list(range(n))
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
px, py = find(x), find(y)
if px != py:
parent[px] = py
print(f"正在对 {field_name} 字段进行分组仅比较相同chat_id或同组的项目...")
# 统计需要比较的对数
total_pairs = 0
for i in range(n):
for j in range(i + 1, n):
if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups):
total_pairs += 1
print(f"总共需要比较 {total_pairs}已过滤不同chat_id且不同组的项目...")
compared_pairs = 0
for i in range(n):
if (i + 1) % 100 == 0:
print(f" 已处理 {i + 1}/{n} 个项目...")
expr1 = expressions[i]
text1 = getattr(expr1, field_name, "")
for j in range(i + 1, n):
expr2 = expressions[j]
# 只比较相同chat_id或同组的项目
if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups):
continue
compared_pairs += 1
text2 = getattr(expr2, field_name, "")
similarity = text_similarity(text1, text2)
if similarity >= threshold:
union(i, j)
# 收集分组
groups = defaultdict(list)
for i in range(n):
root = find(i)
groups[root].append(i)
# 只返回包含多个项目的组
result = [group for group in groups.values() if len(group) > 1]
result.sort(key=len, reverse=True)
return result
def print_similarity_analysis(
expressions: List[Expression],
field_name: str,
threshold: float,
chat_id_groups: dict[str, set[str]],
show_details: bool = True,
max_groups: int = 20
):
"""打印相似度分析结果"""
print("\n" + "=" * 80)
print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})")
print("=" * 80)
# 分组分析
groups = group_similar_items(expressions, field_name, threshold, chat_id_groups)
total_items = len(expressions)
similar_items_count = sum(len(group) for group in groups)
unique_groups = len(groups)
print("\n📊 统计信息:")
print(f" 总项目数: {total_items}")
print(f" 相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)")
print(f" 相似组数: {unique_groups}")
print(f" 平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else " 平均每组项目数: 0")
if not groups:
print(f"\n未找到相似度 >= {threshold} 的项目组")
return
print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):")
print()
for group_idx, group in enumerate(groups[:max_groups], 1):
print(f"{group_idx} (共 {len(group)} 个项目):")
if show_details:
# 显示组内所有项目的详细信息
for idx in group:
expr = expressions[idx]
text = getattr(expr, field_name, "")
chat_name = get_chat_name(expr.chat_id)
# 截断过长的文本
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [{expr.id}] {display_text}")
print(f" 聊天: {chat_name}, Count: {expr.count}")
# 计算组内平均相似度
if len(group) > 1:
similarities = []
above_threshold_pairs = [] # 存储满足阈值的相似对
above_threshold_count = 0
for i in range(len(group)):
for j in range(i + 1, len(group)):
text1 = getattr(expressions[group[i]], field_name, "")
text2 = getattr(expressions[group[j]], field_name, "")
sim = text_similarity(text1, text2)
similarities.append(sim)
if sim >= threshold:
above_threshold_count += 1
# 存储满足阈值的对的信息
expr1 = expressions[group[i]]
expr2 = expressions[group[j]]
display_text1 = text1[:40] + "..." if len(text1) > 40 else text1
display_text2 = text2[:40] + "..." if len(text2) > 40 else text2
above_threshold_pairs.append((
expr1.id, display_text1,
expr2.id, display_text2,
sim
))
if similarities:
avg_sim = sum(similarities) / len(similarities)
min_sim = min(similarities)
max_sim = max(similarities)
above_threshold_ratio = above_threshold_count / len(similarities) * 100
print(f" 平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})")
print(f" 满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})")
# 显示满足阈值的相似对(这些是直接连接,导致它们被分到一组)
if above_threshold_pairs:
print(" ⚠️ 直接相似的对 (这些对导致它们被分到一组):")
# 按相似度降序排序
above_threshold_pairs.sort(key=lambda x: x[4], reverse=True)
for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]: # 最多显示10对
print(f" [{idx1}] ↔ [{idx2}]: {sim:.3f}")
print(f" \"{text1}\"\"{text2}\"")
if len(above_threshold_pairs) > 10:
print(f" ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值")
else:
print(f" ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接")
else:
# 只显示组内第一个项目作为示例
expr = expressions[group[0]]
text = getattr(expr, field_name, "")
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" 示例: {display_text}")
print(f" ... 还有 {len(group) - 1} 个相似项目")
print()
if len(groups) > max_groups:
print(f"... 还有 {len(groups) - max_groups} 组未显示")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度")
parser.add_argument(
"--situation-threshold",
type=float,
default=0.7,
help="situation相似度阈值 (0-1, 默认: 0.7)"
)
parser.add_argument(
"--style-threshold",
type=float,
default=0.7,
help="style相似度阈值 (0-1, 默认: 0.7)"
)
parser.add_argument(
"--no-details",
action="store_true",
help="不显示详细信息,只显示统计"
)
parser.add_argument(
"--max-groups",
type=int,
default=20,
help="最多显示的组数 (默认: 20)"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
)
args = parser.parse_args()
# 验证阈值
if not 0 <= args.situation_threshold <= 1:
print("错误: situation-threshold 必须在 0-1 之间")
return
if not 0 <= args.style_threshold <= 1:
print("错误: style-threshold 必须在 0-1 之间")
return
# 确定输出文件路径
if args.output:
output_file = args.output
else:
# 自动生成带时间戳的输出文件
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt")
# 使用TeeOutput同时输出到控制台和文件
with TeeOutput(output_file) as tee:
# 临时替换sys.stdout
original_stdout = sys.stdout
sys.stdout = tee
try:
print("=" * 80)
print("Expression 相似度分析工具")
print("=" * 80)
print(f"输出文件: {output_file}")
print()
_run_analysis(args)
finally:
# 恢复原始stdout
sys.stdout = original_stdout
print(f"\n✅ 分析结果已保存到: {output_file}")
def _run_analysis(args):
"""执行分析的主逻辑"""
# 查询所有Expression记录
print("正在从数据库加载Expression数据...")
try:
expressions = list(Expression.select())
except Exception as e:
print(f"❌ 加载数据失败: {e}")
return
if not expressions:
print("❌ 数据库中没有找到Expression记录")
return
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
print()
# 构建chat_id分组映射
print("正在构建chat_id分组映射根据expression_groups配置...")
try:
chat_id_groups = build_chat_id_groups()
print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射")
if chat_id_groups:
# 统计分组信息
total_related = sum(len(related) for related in chat_id_groups.values())
avg_related = total_related / len(chat_id_groups)
print(f" 平均每个chat_id与 {avg_related:.1f} 个chat_id相关包括自身")
print()
except Exception as e:
print(f"⚠️ 构建chat_id分组映射失败: {e}")
print(" 将使用默认行为只比较相同chat_id的项目")
chat_id_groups = {}
# 分析situation相似度
print_similarity_analysis(
expressions,
"situation",
args.situation_threshold,
chat_id_groups,
show_details=not args.no_details,
max_groups=args.max_groups
)
# 分析style相似度
print_similarity_analysis(
expressions,
"style",
args.style_threshold,
chat_id_groups,
show_details=not args.no_details,
max_groups=args.max_groups
)
print("\n" + "=" * 80)
print("分析完成!")
print("=" * 80)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,303 @@
"""
统计和展示 replyer 动作选择记录
用法:
python scripts/replyer_action_stats.py
"""
import json
import os
import sys
from collections import Counter, defaultdict
from datetime import datetime
from typing import Dict, List, Any
from pathlib import Path
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
try:
from src.common.database.database_model import ChatStreams
from src.chat.message_receive.chat_stream import get_chat_manager
except ImportError:
ChatStreams = None
get_chat_manager = None
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
if ChatStreams:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream:
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
if get_chat_manager:
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(chat_id)
if stream_name:
return stream_name
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]:
"""加载所有 replyer 动作记录"""
records = []
temp_path = Path(temp_dir)
if not temp_path.exists():
print(f"目录不存在: {temp_dir}")
return records
# 查找所有 replyer_action_*.json 文件
pattern = "replyer_action_*.json"
for file_path in temp_path.glob(pattern):
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
records.append(data)
except Exception as e:
print(f"读取文件失败 {file_path}: {e}")
# 按时间戳排序
records.sort(key=lambda x: x.get("timestamp", ""))
return records
def format_timestamp(ts: str) -> str:
"""格式化时间戳"""
try:
dt = datetime.fromisoformat(ts)
return dt.strftime("%Y-%m-%d %H:%M:%S")
except Exception:
return ts
def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]:
"""计算时间分布"""
now = datetime.now()
distribution = {
"今天": 0,
"昨天": 0,
"3天内": 0,
"7天内": 0,
"30天内": 0,
"更早": 0,
}
for record in records:
try:
ts = record.get("timestamp", "")
if not ts:
continue
dt = datetime.fromisoformat(ts)
diff = (now - dt).days
if diff == 0:
distribution["今天"] += 1
elif diff == 1:
distribution["昨天"] += 1
elif diff < 3:
distribution["3天内"] += 1
elif diff < 7:
distribution["7天内"] += 1
elif diff < 30:
distribution["30天内"] += 1
else:
distribution["更早"] += 1
except Exception:
pass
return distribution
def print_statistics(records: List[Dict[str, Any]]):
"""打印统计信息"""
if not records:
print("没有找到任何记录")
return
print("=" * 80)
print("Replyer 动作选择记录统计")
print("=" * 80)
print()
# 总记录数
total_count = len(records)
print(f"📊 总记录数: {total_count}")
print()
# 时间范围
timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")]
if timestamps:
first_time = format_timestamp(min(timestamps))
last_time = format_timestamp(max(timestamps))
print(f"📅 时间范围: {first_time} ~ {last_time}")
print()
# 按 think_level 统计
think_levels = [r.get("think_level", 0) for r in records]
think_level_counter = Counter(think_levels)
print("🧠 思考深度分布:")
for level in sorted(think_level_counter.keys()):
count = think_level_counter[level]
percentage = (count / total_count) * 100
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)")
print()
# 按 chat_id 统计(总体)
chat_counter = Counter([r.get("chat_id", "未知") for r in records])
print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):")
# 只显示前10个
for chat_id, count in chat_counter.most_common(10):
chat_name = get_chat_name(chat_id)
percentage = (count / total_count) * 100
print(f" {chat_name}: {count} 次 ({percentage:.1f}%)")
if len(chat_counter) > 10:
print(f" ... 还有 {len(chat_counter) - 10} 个聊天")
print()
# 每个 chat_id 的详细统计
print("=" * 80)
print("每个聊天的详细统计")
print("=" * 80)
print()
# 按 chat_id 分组记录
records_by_chat = defaultdict(list)
for record in records:
chat_id = record.get("chat_id", "未知")
records_by_chat[chat_id].append(record)
# 按记录数排序
sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True)
for chat_id, chat_records in sorted_chats:
chat_name = get_chat_name(chat_id)
chat_count = len(chat_records)
chat_percentage = (chat_count / total_count) * 100
print(f"📱 {chat_name} ({chat_id[:8]}...)")
print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)")
# 该聊天的 think_level 分布
chat_think_levels = [r.get("think_level", 0) for r in chat_records]
chat_think_counter = Counter(chat_think_levels)
print(" 思考深度分布:")
for level in sorted(chat_think_counter.keys()):
level_count = chat_think_counter[level]
level_percentage = (level_count / chat_count) * 100
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)")
# 该聊天的时间范围
chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")]
if chat_timestamps:
first_time = format_timestamp(min(chat_timestamps))
last_time = format_timestamp(max(chat_timestamps))
print(f" 时间范围: {first_time} ~ {last_time}")
# 该聊天的时间分布
chat_time_dist = calculate_time_distribution(chat_records)
print(" 时间分布:")
for period, count in chat_time_dist.items():
if count > 0:
period_percentage = (count / chat_count) * 100
print(f" {period}: {count} 次 ({period_percentage:.1f}%)")
# 显示该聊天最近的一条理由示例
if chat_records:
latest_record = chat_records[-1]
reason = latest_record.get("reason", "无理由")
if len(reason) > 120:
reason = reason[:120] + "..."
timestamp = format_timestamp(latest_record.get("timestamp", ""))
think_level = latest_record.get("think_level", 0)
print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}")
print()
# 时间分布
time_dist = calculate_time_distribution(records)
print("⏰ 时间分布:")
for period, count in time_dist.items():
if count > 0:
percentage = (count / total_count) * 100
print(f" {period}: {count} 次 ({percentage:.1f}%)")
print()
# 显示一些示例理由
print("📝 示例理由 (最近5条):")
recent_records = records[-5:]
for i, record in enumerate(recent_records, 1):
reason = record.get("reason", "无理由")
think_level = record.get("think_level", 0)
timestamp = format_timestamp(record.get("timestamp", ""))
chat_id = record.get("chat_id", "未知")
chat_name = get_chat_name(chat_id)
# 截断过长的理由
if len(reason) > 100:
reason = reason[:100] + "..."
print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})")
print(f" {reason}")
print()
# 按 think_level 分组显示理由示例
print("=" * 80)
print("按思考深度分类的示例理由")
print("=" * 80)
print()
for level in [0, 1, 2]:
level_records = [r for r in records if r.get("think_level") == level]
if not level_records:
continue
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:")
# 显示3个示例选择最近的
examples = level_records[-3:] if len(level_records) >= 3 else level_records
for i, record in enumerate(examples, 1):
reason = record.get("reason", "无理由")
if len(reason) > 150:
reason = reason[:150] + "..."
timestamp = format_timestamp(record.get("timestamp", ""))
chat_id = record.get("chat_id", "未知")
chat_name = get_chat_name(chat_id)
print(f" {i}. [{timestamp}] {chat_name}")
print(f" {reason}")
print()
# 统计信息汇总
print("=" * 80)
print("统计汇总")
print("=" * 80)
print(f"总记录数: {total_count}")
print(f"涉及聊天数: {len(chat_counter)}")
if chat_counter:
avg_count = total_count / len(chat_counter)
print(f"平均每个聊天记录数: {avg_count:.1f}")
else:
print("平均每个聊天记录数: N/A")
print()
def main():
"""主函数"""
records = load_records()
print_statistics(records)
if __name__ == "__main__":
main()