Ruff Format

This commit is contained in:
DrSmoothl
2026-02-21 16:24:24 +08:00
parent 2cb512120b
commit eaef7f0e98
82 changed files with 1881 additions and 1900 deletions

View File

@@ -38,10 +38,10 @@ def parse_datetime(dt_str: str) -> datetime | None:
def analyze_single_file(file_path: str) -> Dict:
"""
分析单个JSON文件的统计信息
Args:
file_path: JSON文件路径
Returns:
统计信息字典
"""
@@ -65,40 +65,40 @@ def analyze_single_file(file_path: str) -> Dict:
"has_reason": False,
"reason_count": 0,
}
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
# 基本信息
stats["last_updated"] = data.get("last_updated")
stats["total_count"] = data.get("total_count", 0)
results = data.get("manual_results", [])
stats["actual_count"] = len(results)
if not results:
return stats
# 统计通过/不通过
suitable_count = sum(1 for r in results if r.get("suitable") is True)
unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
stats["suitable_count"] = suitable_count
stats["unsuitable_count"] = unsuitable_count
stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0
# 统计唯一的(situation, style)对
pairs: Set[Tuple[str, str]] = set()
for r in results:
if "situation" in r and "style" in r:
pairs.add((r["situation"], r["style"]))
stats["unique_pairs"] = len(pairs)
# 统计评估者
for r in results:
evaluator = r.get("evaluator", "unknown")
stats["evaluators"][evaluator] += 1
# 统计评估时间
evaluation_dates = []
for r in results:
@@ -107,7 +107,7 @@ def analyze_single_file(file_path: str) -> Dict:
dt = parse_datetime(evaluated_at)
if dt:
evaluation_dates.append(dt)
stats["evaluation_dates"] = evaluation_dates
if evaluation_dates:
min_date = min(evaluation_dates)
@@ -115,18 +115,18 @@ def analyze_single_file(file_path: str) -> Dict:
stats["date_range"] = {
"start": min_date.isoformat(),
"end": max_date.isoformat(),
"duration_days": (max_date - min_date).days + 1
"duration_days": (max_date - min_date).days + 1,
}
# 检查字段存在性
stats["has_expression_id"] = any("expression_id" in r for r in results)
stats["has_reason"] = any(r.get("reason") for r in results)
stats["reason_count"] = sum(1 for r in results if r.get("reason"))
except Exception as e:
stats["error"] = str(e)
logger.error(f"分析文件 {file_name} 时出错: {e}")
return stats
@@ -136,57 +136,57 @@ def print_file_stats(stats: Dict, index: int = None):
print(f"\n{'=' * 80}")
print(f"{prefix}文件: {stats['file_name']}")
print(f"{'=' * 80}")
if stats["error"]:
print(f"✗ 错误: {stats['error']}")
return
print(f"文件路径: {stats['file_path']}")
print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")
if stats["last_updated"]:
print(f"最后更新: {stats['last_updated']}")
print("\n【记录统计】")
print(f" 文件中的 total_count: {stats['total_count']}")
print(f" 实际记录数: {stats['actual_count']}")
if stats['total_count'] != stats['actual_count']:
diff = stats['total_count'] - stats['actual_count']
if stats["total_count"] != stats["actual_count"]:
diff = stats["total_count"] - stats["actual_count"]
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
print("\n【评估结果统计】")
print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")
print("\n【唯一性统计】")
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']}")
if stats['actual_count'] > 0:
duplicate_count = stats['actual_count'] - stats['unique_pairs']
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["actual_count"] > 0:
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
print("\n【评估者统计】")
if stats['evaluators']:
for evaluator, count in stats['evaluators'].most_common():
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["evaluators"]:
for evaluator, count in stats["evaluators"].most_common():
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
else:
print(" 无评估者信息")
print("\n【时间统计】")
if stats['date_range']:
if stats["date_range"]:
print(f" 最早评估时间: {stats['date_range']['start']}")
print(f" 最晚评估时间: {stats['date_range']['end']}")
print(f" 评估时间跨度: {stats['date_range']['duration_days']}")
else:
print(" 无时间信息")
print("\n【字段统计】")
print(f" 包含 expression_id: {'' if stats['has_expression_id'] else ''}")
print(f" 包含 reason: {'' if stats['has_reason'] else ''}")
if stats['has_reason']:
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
if stats["has_reason"]:
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
@@ -195,35 +195,35 @@ def print_summary(all_stats: List[Dict]):
print(f"\n{'=' * 80}")
print("汇总统计")
print(f"{'=' * 80}")
total_files = len(all_stats)
valid_files = [s for s in all_stats if not s.get("error")]
error_files = [s for s in all_stats if s.get("error")]
print("\n【文件统计】")
print(f" 总文件数: {total_files}")
print(f" 成功解析: {len(valid_files)}")
print(f" 解析失败: {len(error_files)}")
if error_files:
print("\n 失败文件列表:")
for stats in error_files:
print(f" - {stats['file_name']}: {stats['error']}")
if not valid_files:
print("\n没有成功解析的文件")
return
# 汇总记录统计
total_records = sum(s['actual_count'] for s in valid_files)
total_suitable = sum(s['suitable_count'] for s in valid_files)
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
total_records = sum(s["actual_count"] for s in valid_files)
total_suitable = sum(s["suitable_count"] for s in valid_files)
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
total_unique_pairs = set()
# 收集所有唯一的(situation, style)对
for stats in valid_files:
try:
with open(stats['file_path'], "r", encoding="utf-8") as f:
with open(stats["file_path"], "r", encoding="utf-8") as f:
data = json.load(f)
results = data.get("manual_results", [])
for r in results:
@@ -231,23 +231,31 @@ def print_summary(all_stats: List[Dict]):
total_unique_pairs.add((r["situation"], r["style"]))
except Exception:
pass
print("\n【记录汇总】")
print(f" 总记录数: {total_records:,}")
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
print(f" 通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
print(
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
if total_records > 0
else " 通过: 0 条"
)
print(
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
if total_records > 0
else " 不通过: 0 条"
)
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,}")
if total_records > 0:
duplicate_count = total_records - len(total_unique_pairs)
duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0
print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")
# 汇总评估者统计
all_evaluators = Counter()
for stats in valid_files:
all_evaluators.update(stats['evaluators'])
all_evaluators.update(stats["evaluators"])
print("\n【评估者汇总】")
if all_evaluators:
for evaluator, count in all_evaluators.most_common():
@@ -255,12 +263,12 @@ def print_summary(all_stats: List[Dict]):
print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
else:
print(" 无评估者信息")
# 汇总时间范围
all_dates = []
for stats in valid_files:
all_dates.extend(stats['evaluation_dates'])
all_dates.extend(stats["evaluation_dates"])
if all_dates:
min_date = min(all_dates)
max_date = max(all_dates)
@@ -268,9 +276,9 @@ def print_summary(all_stats: List[Dict]):
print(f" 最早评估时间: {min_date.isoformat()}")
print(f" 最晚评估时间: {max_date.isoformat()}")
print(f" 总时间跨度: {(max_date - min_date).days + 1}")
# 文件大小汇总
total_size = sum(s['file_size'] for s in valid_files)
total_size = sum(s["file_size"] for s in valid_files)
avg_size = total_size / len(valid_files) if valid_files else 0
print("\n【文件大小汇总】")
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
@@ -282,35 +290,35 @@ def main():
logger.info("=" * 80)
logger.info("开始分析评估结果统计信息")
logger.info("=" * 80)
if not os.path.exists(TEMP_DIR):
print(f"\n✗ 错误未找到temp目录: {TEMP_DIR}")
logger.error(f"未找到temp目录: {TEMP_DIR}")
return
# 查找所有JSON文件
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
if not json_files:
print(f"\n✗ 错误temp目录下未找到JSON文件: {TEMP_DIR}")
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
return
json_files.sort() # 按文件名排序
print(f"\n找到 {len(json_files)} 个JSON文件")
print("=" * 80)
# 分析每个文件
all_stats = []
for i, json_file in enumerate(json_files, 1):
stats = analyze_single_file(json_file)
all_stats.append(stats)
print_file_stats(stats, index=i)
# 打印汇总统计
print_summary(all_stats)
print(f"\n{'=' * 80}")
print("分析完成")
print(f"{'=' * 80}")
@@ -318,5 +326,3 @@ def main():
if __name__ == "__main__":
main()

View File

@@ -171,7 +171,9 @@ def main():
sys.exit(1)
if not args.raw_index:
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
logger.info(
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
)
sys.exit(1)
# 解析索引列表1-based

View File

@@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
"""
加载已有的评估结果
Returns:
(已有结果列表, 已评估的项目(situation, style)元组集合)
"""
if not os.path.exists(COUNT_ANALYSIS_FILE):
return [], set()
try:
with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
@@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
def save_results(evaluation_results: List[Dict]):
"""
保存评估结果到文件
Args:
evaluation_results: 评估结果列表
"""
try:
os.makedirs(TEMP_DIR, exist_ok=True)
data = {
"last_updated": datetime.now().isoformat(),
"total_count": len(evaluation_results),
"evaluation_results": evaluation_results
"evaluation_results": evaluation_results,
}
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
except Exception as e:
@@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]):
print(f"\n✗ 保存评估结果失败: {e}")
def select_expressions_for_evaluation(
evaluated_pairs: Set[Tuple[str, str]] = None
) -> List[Expression]:
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
"""
选择用于评估的表达方式
选择所有count>1的项目然后选择两倍数量的count=1的项目
Args:
evaluated_pairs: 已评估的项目集合,用于避免重复
Returns:
选中的表达方式列表
"""
if evaluated_pairs is None:
evaluated_pairs = set()
try:
# 查询所有表达方式
all_expressions = list(Expression.select())
if not all_expressions:
logger.warning("数据库中没有表达方式记录")
return []
# 过滤出未评估的项目
unevaluated = [
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
if not unevaluated:
logger.warning("所有项目都已评估完成")
return []
# 按count分组
count_eq1 = [expr for expr in unevaluated if expr.count == 1]
count_gt1 = [expr for expr in unevaluated if expr.count > 1]
logger.info(f"未评估项目中count=1的有{len(count_eq1)}count>1的有{len(count_gt1)}")
# 选择所有count>1的项目
selected_count_gt1 = count_gt1.copy()
# 选择count=1的项目数量为count>1数量的2倍
count_gt1_count = len(selected_count_gt1)
count_eq1_needed = count_gt1_count * 2
if len(count_eq1) < count_eq1_needed:
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}")
logger.warning(
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}"
)
count_eq1_needed = len(count_eq1)
# 随机选择count=1的项目
selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []
selected = selected_count_gt1 + selected_count_eq1
random.shuffle(selected) # 打乱顺序
logger.info(f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍")
logger.info(
f"已选择{len(selected)}条表达方式count>1的有{len(selected_count_gt1)}全部count=1的有{len(selected_count_eq1)}2倍"
)
return selected
except Exception as e:
logger.error(f"选择表达方式失败: {e}")
import traceback
logger.error(traceback.format_exc())
return []
@@ -155,11 +155,11 @@ def select_expressions_for_evaluation(
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
Args:
situation: 情境
style: 风格
Returns:
评估提示词
"""
@@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因。
请严格按照JSON格式输出不要包含其他内容。"""
return prompt
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
"""
执行单次LLM评估
Args:
situation: 情境
style: 风格
llm: LLM请求实例
Returns:
(suitable, reason, error) 元组,如果出错则 suitable 为 Falseerror 包含错误信息
"""
try:
prompt = create_evaluation_prompt(situation, style)
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt,
temperature=0.6,
max_tokens=1024
prompt=prompt, temperature=0.6, max_tokens=1024
)
logger.debug(f"LLM响应: {response}")
# 解析JSON响应
try:
evaluation = json.loads(response)
@@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
evaluation = json.loads(json_match.group())
else:
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None
except Exception as e:
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
return False, f"评估过程出错: {str(e)}", str(e)
@@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
"""
使用LLM评估单个表达方式
Args:
expression: 表达方式对象
llm: LLM请求实例
Returns:
评估结果字典
"""
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
logger.info(
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
)
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
if error:
suitable = False
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
return {
"situation": expression.situation,
"style": expression.style,
@@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
"reason": reason,
"error": error,
"evaluator": "llm",
"evaluated_at": datetime.now().isoformat()
"evaluated_at": datetime.now().isoformat(),
}
def perform_statistical_analysis(evaluation_results: List[Dict]):
"""
对评估结果进行统计分析
Args:
evaluation_results: 评估结果列表
"""
if not evaluation_results:
print("\n没有评估结果可供分析")
return
print("\n" + "=" * 60)
print("统计分析结果")
print("=" * 60)
# 按count分组统计
count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
for result in evaluation_results:
count = result.get("count", 1)
suitable = result.get("suitable", False)
@@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
count_groups[count]["suitable"] += 1
else:
count_groups[count]["unsuitable"] += 1
# 显示每个count的统计
print("\n【按count分组统计】")
print("-" * 60)
@@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
suitable = group["suitable"]
unsuitable = group["unsuitable"]
pass_rate = (suitable / total * 100) if total > 0 else 0
print(f"Count = {count}:")
print(f" 总数: {total}")
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
print()
# 比较count=1和count>1
count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
for result in evaluation_results:
count = result.get("count", 1)
suitable = result.get("suitable", False)
if count == 1:
count_eq1_group["total"] += 1
if suitable:
@@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
count_gt1_group["suitable"] += 1
else:
count_gt1_group["unsuitable"] += 1
print("\n【Count=1 vs Count>1 对比】")
print("-" * 60)
eq1_total = count_eq1_group["total"]
eq1_suitable = count_eq1_group["suitable"]
eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
gt1_total = count_gt1_group["total"]
gt1_suitable = count_gt1_group["suitable"]
gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
print("Count = 1:")
print(f" 总数: {eq1_total}")
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
print()
print("Count > 1:")
print(f" 总数: {gt1_total}")
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
print()
# 进行卡方检验简化版使用2x2列联表
if eq1_total > 0 and gt1_total > 0:
print("【统计显著性检验】")
print("-" * 60)
# 构建2x2列联表
# 通过 不通过
# count=1 a b
@@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
b = eq1_total - eq1_suitable
c = gt1_suitable
d = gt1_total - gt1_suitable
# 计算卡方统计量简化版使用Pearson卡方检验
n = eq1_total + gt1_total
if n > 0:
@@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
e_b = (eq1_total * (b + d)) / n
e_c = (gt1_total * (a + c)) / n
e_d = (gt1_total * (b + d)) / n
# 检查期望频数是否足够大(卡方检验要求每个期望频数>=5
min_expected = min(e_a, e_b, e_c, e_d)
if min_expected < 5:
print("警告期望频数小于5卡方检验可能不准确")
print("建议使用Fisher精确检验")
# 计算卡方值
chi_square = 0
if e_a > 0:
@@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
chi_square += ((c - e_c) ** 2) / e_c
if e_d > 0:
chi_square += ((d - e_d) ** 2) / e_d
# 自由度 = (行数-1) * (列数-1) = 1
df = 1
# 临界值(α=0.05
chi_square_critical_005 = 3.841
chi_square_critical_001 = 6.635
print(f"卡方统计量: {chi_square:.4f}")
print(f"自由度: {df}")
print(f"临界值 (α=0.05): {chi_square_critical_005}")
print(f"临界值 (α=0.01): {chi_square_critical_001}")
if chi_square >= chi_square_critical_001:
print("结论: 在α=0.01水平下count=1和count>1的合格率存在显著差异p<0.01")
elif chi_square >= chi_square_critical_005:
print("结论: 在α=0.05水平下count=1和count>1的合格率存在显著差异p<0.05")
else:
print("结论: 在α=0.05水平下count=1和count>1的合格率不存在显著差异p≥0.05")
# 计算差异大小
diff = abs(eq1_pass_rate - gt1_pass_rate)
print(f"\n合格率差异: {diff:.2f}%")
@@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
print("数据不足,无法进行统计检验")
else:
print("数据不足无法进行count=1和count>1的对比分析")
# 保存统计分析结果
analysis_result = {
"analysis_time": datetime.now().isoformat(),
"count_groups": {str(k): v for k, v in count_groups.items()},
"count_eq1": count_eq1_group,
"count_gt1": count_gt1_group,
"total_evaluated": len(evaluation_results)
"total_evaluated": len(evaluation_results),
}
try:
analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
with open(analysis_file, "w", encoding="utf-8") as f:
@@ -444,7 +444,7 @@ async def main():
logger.info("=" * 60)
logger.info("开始表达方式按count分组的LLM评估和统计分析")
logger.info("=" * 60)
# 初始化数据库连接
try:
db.connect(reuse_if_open=True)
@@ -452,97 +452,95 @@ async def main():
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return
# 加载已有评估结果
existing_results, evaluated_pairs = load_existing_results()
evaluation_results = existing_results.copy()
if evaluated_pairs:
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
print(f"已评估项目数: {len(evaluated_pairs)}")
# 检查是否需要继续评估检查是否还有未评估的count>1项目
# 先查询未评估的count>1项目数量
try:
all_expressions = list(Expression.select())
unevaluated_count_gt1 = [
expr for expr in all_expressions
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
]
has_unevaluated = len(unevaluated_count_gt1) > 0
except Exception as e:
logger.error(f"查询未评估项目失败: {e}")
has_unevaluated = False
if has_unevaluated:
print("\n" + "=" * 60)
print("开始LLM评估")
print("=" * 60)
print("评估结果会自动保存到文件\n")
# 创建LLM实例
print("创建LLM实例...")
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_count_analysis_llm"
request_type="expression_evaluator_count_analysis_llm",
)
print("✓ LLM实例创建成功\n")
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
print(f"\n✗ 创建LLM实例失败: {e}")
db.close()
return
# 选择需要评估的表达方式选择所有count>1的项目然后选择两倍数量的count=1的项目
expressions = select_expressions_for_evaluation(
evaluated_pairs=evaluated_pairs
)
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
if not expressions:
print("\n没有可评估的项目")
else:
print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)}")
print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)}\n")
batch_results = []
for i, expression in enumerate(expressions, 1):
print(f"LLM评估进度: {i}/{len(expressions)}")
print(f" Situation: {expression.situation}")
print(f" Style: {expression.style}")
print(f" Count: {expression.count}")
llm_result = await llm_evaluate_expression(expression, llm)
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
if llm_result.get('error'):
if llm_result.get("error"):
print(f" 错误: {llm_result['error']}")
print()
batch_results.append(llm_result)
# 使用 (situation, style) 作为唯一标识
evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
# 添加延迟以避免API限流
await asyncio.sleep(0.3)
# 将当前批次结果添加到总结果中
evaluation_results.extend(batch_results)
# 保存结果
save_results(evaluation_results)
else:
print(f"\n所有count>1的项目都已评估完成已有 {len(evaluation_results)} 条评估结果")
# 进行统计分析
if len(evaluation_results) > 0:
perform_statistical_analysis(evaluation_results)
else:
print("\n没有评估结果可供分析")
# 关闭数据库连接
try:
db.close()
@@ -553,4 +551,3 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -33,7 +33,7 @@ TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
def load_manual_results() -> List[Dict]:
"""
加载人工评估结果自动读取temp目录下所有JSON文件并合并
Returns:
人工评估结果列表(已去重)
"""
@@ -42,62 +42,62 @@ def load_manual_results() -> List[Dict]:
print("\n✗ 错误未找到temp目录")
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
return []
# 查找所有JSON文件
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
if not json_files:
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
print("\n✗ 错误temp目录下未找到JSON文件")
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
return []
logger.info(f"找到 {len(json_files)} 个JSON文件")
print(f"\n找到 {len(json_files)} 个JSON文件:")
for json_file in json_files:
print(f" - {os.path.basename(json_file)}")
# 读取并合并所有JSON文件
all_results = []
seen_pairs: Set[Tuple[str, str]] = set() # 用于去重
for json_file in json_files:
try:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
results = data.get("manual_results", [])
# 去重:使用(situation, style)作为唯一标识
for result in results:
if "situation" not in result or "style" not in result:
logger.warning(f"跳过无效数据(缺少必要字段): {result}")
continue
pair = (result["situation"], result["style"])
if pair not in seen_pairs:
seen_pairs.add(pair)
all_results.append(result)
logger.info(f"{os.path.basename(json_file)} 加载了 {len(results)} 条结果")
except Exception as e:
logger.error(f"加载文件 {json_file} 失败: {e}")
print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
continue
logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
return all_results
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
Args:
situation: 情境
style: 风格
Returns:
评估提示词
"""
@@ -119,51 +119,50 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因。
请严格按照JSON格式输出不要包含其他内容。"""
return prompt
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
"""
执行单次LLM评估
Args:
situation: 情境
style: 风格
llm: LLM请求实例
Returns:
(suitable, reason, error) 元组,如果出错则 suitable 为 Falseerror 包含错误信息
"""
try:
prompt = create_evaluation_prompt(situation, style)
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
response, (reasoning, model_name, _) = await llm.generate_response_async(
prompt=prompt,
temperature=0.6,
max_tokens=1024
prompt=prompt, temperature=0.6, max_tokens=1024
)
logger.debug(f"LLM响应: {response}")
# 解析JSON响应
try:
evaluation = json.loads(response)
except json.JSONDecodeError as e:
import re
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
if json_match:
evaluation = json.loads(json_match.group())
else:
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
suitable = evaluation.get("suitable", False)
reason = evaluation.get("reason", "未提供理由")
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
return suitable, reason, None
except Exception as e:
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
return False, f"评估过程出错: {str(e)}", str(e)
@@ -172,68 +171,68 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
"""
使用LLM评估单个表达方式
Args:
situation: 情境
style: 风格
llm: LLM请求实例
Returns:
评估结果字典
"""
logger.info(f"开始评估表达方式: situation={situation}, style={style}")
suitable, reason, error = await _single_llm_evaluation(situation, style, llm)
if error:
suitable = False
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
return {
"situation": situation,
"style": style,
"suitable": suitable,
"reason": reason,
"error": error,
"evaluator": "llm"
"evaluator": "llm",
}
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
"""
对比人工评估和LLM评估的结果
Args:
manual_results: 人工评估结果列表
llm_results: LLM评估结果列表
method_name: 评估方法名称(用于标识)
Returns:
对比分析结果字典
"""
# 按(situation, style)建立映射
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
total = len(manual_results)
matched = 0
true_positives = 0
true_negatives = 0
false_positives = 0
false_negatives = 0
for manual_result in manual_results:
pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair)
if llm_result is None:
continue
manual_suitable = manual_result["suitable"]
llm_suitable = llm_result["suitable"]
if manual_suitable == llm_suitable:
matched += 1
if manual_suitable and llm_suitable:
true_positives += 1
elif not manual_suitable and not llm_suitable:
@@ -242,30 +241,36 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
false_positives += 1
elif manual_suitable and not llm_suitable:
false_negatives += 1
accuracy = (matched / total * 100) if total > 0 else 0
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
precision = (
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
)
recall = (
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
)
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
specificity = (
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
)
# 计算人工效标的不合适率
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0
# 计算经过LLM删除后剩余项目中的不合适率
# 在所有项目中移除LLM判定为不合适的项目后剩下的项目 = TP + FPLLM判定为合适的项目
# 在这些剩下的项目中,按人工评定的不合适项目 = FP人工认为不合适但LLM认为合适
llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数保留的项目
llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0
# 两者百分比相减评估LLM评定修正后的不合适率是否有降低
rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
random_baseline = 50.0
accuracy_above_random = accuracy - random_baseline
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
return {
"method": method_name,
"total": total,
@@ -283,29 +288,29 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
"specificity": specificity,
"manual_unsuitable_rate": manual_unsuitable_rate,
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
"rate_difference": rate_difference
"rate_difference": rate_difference,
}
async def main(count: int | None = None):
"""
主函数
Args:
count: 随机选取的数据条数如果为None则使用全部数据
"""
logger.info("=" * 60)
logger.info("开始表达方式LLM评估")
logger.info("=" * 60)
# 1. 加载人工评估结果
print("\n步骤1: 加载人工评估结果")
manual_results = load_manual_results()
if not manual_results:
return
print(f"成功加载 {len(manual_results)} 条人工评估结果")
# 如果指定了数量,随机选择指定数量的数据
if count is not None:
if count <= 0:
@@ -317,7 +322,7 @@ async def main(count: int | None = None):
random.seed() # 使用系统时间作为随机种子
manual_results = random.sample(manual_results, count)
print(f"随机选取 {len(manual_results)} 条数据进行评估")
# 验证数据完整性
valid_manual_results = []
for r in manual_results:
@@ -325,62 +330,58 @@ async def main(count: int | None = None):
valid_manual_results.append(r)
else:
logger.warning(f"跳过无效数据: {r}")
if len(valid_manual_results) != len(manual_results):
print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
print(f"有效数据: {len(valid_manual_results)}")
# 2. 创建LLM实例并评估
print("\n步骤2: 创建LLM实例")
try:
llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_evaluator_llm"
)
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
import traceback
logger.error(traceback.format_exc())
return
print("\n步骤3: 开始LLM评估")
llm_results = []
for i, manual_result in enumerate(valid_manual_results, 1):
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
llm_results.append(await evaluate_expression_llm(
manual_result["situation"],
manual_result["style"],
llm
))
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
await asyncio.sleep(0.3)
# 5. 输出FP和FN项目在评估结果之前
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
# 5.1 输出FP项目人工评估不通过但LLM误判为通过
print("\n" + "=" * 60)
print("人工评估不通过但LLM误判为通过的项目FP - False Positive")
print("=" * 60)
fp_items = []
for manual_result in valid_manual_results:
pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair)
if llm_result is None:
continue
# 人工评估不通过但LLM评估通过FP情况
if not manual_result["suitable"] and llm_result["suitable"]:
fp_items.append({
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
fp_items.append(
{
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error"),
}
)
if fp_items:
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
for idx, item in enumerate(fp_items, 1):
@@ -389,36 +390,38 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}")
print("人工评估: 不通过 ❌")
print("LLM评估: 通过 ✅ (误判)")
if item.get('llm_error'):
if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
else:
print("\n✓ 没有误判项目所有人工评估不通过的项目都被LLM正确识别为不通过")
# 5.2 输出FN项目人工评估通过但LLM误判为不通过
print("\n" + "=" * 60)
print("人工评估通过但LLM误判为不通过的项目FN - False Negative")
print("=" * 60)
fn_items = []
for manual_result in valid_manual_results:
pair = (manual_result["situation"], manual_result["style"])
llm_result = llm_dict.get(pair)
if llm_result is None:
continue
# 人工评估通过但LLM评估不通过FN情况
if manual_result["suitable"] and not llm_result["suitable"]:
fn_items.append({
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error")
})
fn_items.append(
{
"situation": manual_result["situation"],
"style": manual_result["style"],
"manual_suitable": manual_result["suitable"],
"llm_suitable": llm_result["suitable"],
"llm_reason": llm_result.get("reason", "未提供理由"),
"llm_error": llm_result.get("error"),
}
)
if fn_items:
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
for idx, item in enumerate(fn_items, 1):
@@ -427,33 +430,41 @@ async def main(count: int | None = None):
print(f"Style: {item['style']}")
print("人工评估: 通过 ✅")
print("LLM评估: 不通过 ❌ (误删)")
if item.get('llm_error'):
if item.get("llm_error"):
print(f"LLM错误: {item['llm_error']}")
print(f"LLM理由: {item['llm_reason']}")
print()
else:
print("\n✓ 没有误删项目所有人工评估通过的项目都被LLM正确识别为通过")
# 6. 对比分析并输出结果
comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
print("\n" + "=" * 60)
print("评估结果(以人工评估为标准)")
print("=" * 60)
# 详细评估结果(核心指标优先)
print(f"\n--- {comparison['method']} ---")
print(f" 总数: {comparison['total']}")
print()
# print(" 【核心能力指标】")
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}")
print(
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
print()
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}")
print(
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
)
print(
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']}"
)
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
print()
print(" 【其他指标】")
@@ -464,12 +475,18 @@ async def main(count: int | None = None):
print()
print(" 【不合适率分析】")
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
print(
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
)
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
print()
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
print(f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
print(
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
)
print(
f" - 含义: 在所有项目中移除LLM判定为不合适的项目后在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
)
print()
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
@@ -480,21 +497,22 @@ async def main(count: int | None = None):
print(f" TN (正确识别为不合适): {comparison['true_negatives']}")
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
# 7. 保存结果到JSON文件
output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
try:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w", encoding="utf-8") as f:
json.dump({
"manual_results": valid_manual_results,
"llm_results": llm_results,
"comparison": comparison
}, f, ensure_ascii=False, indent=2)
json.dump(
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
f,
ensure_ascii=False,
indent=2,
)
logger.info(f"\n评估结果已保存到: {output_file}")
except Exception as e:
logger.warning(f"保存结果到文件失败: {e}")
print("\n" + "=" * 60)
print("评估完成")
print("=" * 60)
@@ -509,15 +527,9 @@ if __name__ == "__main__":
python evaluate_expressions_llm_v6.py # 使用全部数据
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
"""
""",
)
parser.add_argument(
"-n", "--count",
type=int,
default=None,
help="随机选取的数据条数(默认:使用全部数据)"
)
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
args = parser.parse_args()
asyncio.run(main(count=args.count))

View File

@@ -32,13 +32,13 @@ MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
"""
加载已有的评估结果
Returns:
(已有结果列表, 已评估的项目(situation, style)元组集合)
"""
if not os.path.exists(MANUAL_EVAL_FILE):
return [], set()
try:
with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
data = json.load(f)
@@ -55,22 +55,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
def save_results(manual_results: List[Dict]):
"""
保存评估结果到文件
Args:
manual_results: 评估结果列表
"""
try:
os.makedirs(TEMP_DIR, exist_ok=True)
data = {
"last_updated": datetime.now().isoformat(),
"total_count": len(manual_results),
"manual_results": manual_results
"manual_results": manual_results,
}
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
except Exception as e:
@@ -81,45 +81,43 @@ def save_results(manual_results: List[Dict]):
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
"""
获取未评估的表达方式
Args:
evaluated_pairs: 已评估的项目(situation, style)元组集合
batch_size: 每次获取的数量
Returns:
未评估的表达方式列表
"""
try:
# 查询所有表达方式
all_expressions = list(Expression.select())
if not all_expressions:
logger.warning("数据库中没有表达方式记录")
return []
# 过滤出未评估的项目:匹配 situation 和 style 均一致
unevaluated = [
expr for expr in all_expressions
if (expr.situation, expr.style) not in evaluated_pairs
]
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
if not unevaluated:
logger.info("所有项目都已评估完成")
return []
# 如果未评估数量少于请求数量,返回所有
if len(unevaluated) <= batch_size:
logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回")
return unevaluated
# 随机选择指定数量
selected = random.sample(unevaluated, batch_size)
logger.info(f"{len(unevaluated)} 条未评估项目中随机选择了 {len(selected)}")
return selected
except Exception as e:
logger.error(f"获取未评估表达方式失败: {e}")
import traceback
logger.error(traceback.format_exc())
return []
@@ -127,12 +125,12 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
"""
人工评估单个表达方式
Args:
expression: 表达方式对象
index: 当前索引从1开始
total: 总数
Returns:
评估结果字典,如果用户退出则返回 None
"""
@@ -146,38 +144,38 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
print(" 输入 'n''no''0' 表示不合适(不通过)")
print(" 输入 'q''quit' 退出评估")
print(" 输入 's''skip' 跳过当前项目")
while True:
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
if user_input in ['q', 'quit']:
if user_input in ["q", "quit"]:
print("退出评估")
return None
if user_input in ['s', 'skip']:
if user_input in ["s", "skip"]:
print("跳过当前项目")
return "skip"
if user_input in ['y', 'yes', '1', '', '通过']:
if user_input in ["y", "yes", "1", "", "通过"]:
suitable = True
break
elif user_input in ['n', 'no', '0', '', '不通过']:
elif user_input in ["n", "no", "0", "", "不通过"]:
suitable = False
break
else:
print("输入无效,请重新输入 (y/n/q/s)")
result = {
"situation": expression.situation,
"style": expression.style,
"suitable": suitable,
"reason": None,
"evaluator": "manual",
"evaluated_at": datetime.now().isoformat()
"evaluated_at": datetime.now().isoformat(),
}
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
return result
@@ -186,7 +184,7 @@ def main():
logger.info("=" * 60)
logger.info("开始表达方式人工评估")
logger.info("=" * 60)
# 初始化数据库连接
try:
db.connect(reuse_if_open=True)
@@ -194,41 +192,41 @@ def main():
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return
# 加载已有评估结果
existing_results, evaluated_pairs = load_existing_results()
manual_results = existing_results.copy()
if evaluated_pairs:
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
print(f"已评估项目数: {len(evaluated_pairs)}")
print("\n" + "=" * 60)
print("开始人工评估")
print("=" * 60)
print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
print("评估结果会自动保存到文件\n")
batch_size = 10
batch_count = 0
while True:
# 获取未评估的项目
expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
if not expressions:
print("\n" + "=" * 60)
print("所有项目都已评估完成!")
print("=" * 60)
break
batch_count += 1
print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
batch_results = []
for i, expression in enumerate(expressions, 1):
manual_result = manual_evaluate_expression(expression, i, len(expressions))
if manual_result is None:
# 用户退出
print("\n评估已中断")
@@ -237,34 +235,34 @@ def main():
manual_results.extend(batch_results)
save_results(manual_results)
return
if manual_result == "skip":
# 跳过当前项目
continue
batch_results.append(manual_result)
# 使用 (situation, style) 作为唯一标识
evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
# 将当前批次结果添加到总结果中
manual_results.extend(batch_results)
# 保存结果
save_results(manual_results)
print(f"\n当前批次完成,已评估总数: {len(manual_results)}")
# 询问是否继续
while True:
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
if continue_input in ['y', 'yes', '1', '', '继续']:
if continue_input in ["y", "yes", "1", "", "继续"]:
break
elif continue_input in ['n', 'no', '0', '', '退出']:
elif continue_input in ["n", "no", "0", "", "退出"]:
print("\n评估结束")
return
else:
print("输入无效,请重新输入 (y/n)")
# 关闭数据库连接
try:
db.close()
@@ -275,4 +273,3 @@ def main():
if __name__ == "__main__":
main()

View File

@@ -134,9 +134,7 @@ def handle_import_openie(
# 在非交互模式下,不再询问用户,而是直接报错终止
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
if non_interactive:
logger.error(
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
)
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
sys.exit(1)
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
user_choice = input().strip().lower()
@@ -189,9 +187,7 @@ def handle_import_openie(
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
# 新增确认提示
if non_interactive:
logger.warning(
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
)
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
else:
print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
@@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
def main(argv: Optional[list[str]] = None) -> None:
"""主函数 - 解析参数并运行异步主流程。"""
parser = argparse.ArgumentParser(
description=(
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
"将其导入到 LPMM 的向量库与知识图中。"
)
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
)
parser.add_argument(
"--non-interactive",

View File

@@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
ensure_dirs() # 确保目录存在
# 新增用户确认提示
if non_interactive:
logger.warning(
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
)
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
else:
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@@ -68,4 +68,3 @@ def main() -> None:
if __name__ == "__main__":
main()

View File

@@ -29,6 +29,7 @@ except ImportError as e:
logger = get_logger("lpmm_interactive_manager")
async def interactive_add():
"""交互式导入知识"""
print("\n" + "=" * 40)
@@ -38,7 +39,7 @@ async def interactive_add():
print(" - 支持多段落,段落间请保留空行。")
print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。")
print("-" * 40)
lines = []
while True:
try:
@@ -48,7 +49,7 @@ async def interactive_add():
lines.append(line)
except EOFError:
break
text = "\n".join(lines).strip()
if not text:
print("\n[!] 内容为空,操作已取消。")
@@ -58,7 +59,7 @@ async def interactive_add():
try:
# 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.add_content(text)
if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}")
print(f" 实际新增段落数: {result.get('count', 0)}")
@@ -68,6 +69,7 @@ async def interactive_add():
print(f"\n[×] 发生异常: {e}")
logger.error(f"add_content 异常: {e}", exc_info=True)
async def interactive_delete():
"""交互式删除知识"""
print("\n" + "=" * 40)
@@ -77,10 +79,10 @@ async def interactive_delete():
print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)")
print(" 2. 完整文段匹配(删除完全匹配的段落)")
print("-" * 40)
mode = input("请选择删除模式 (1/2): ").strip()
exact_match = False
if mode == "2":
exact_match = True
print("\n[完整文段匹配模式]")
@@ -102,14 +104,18 @@ async def interactive_delete():
print("\n[!] 无效选择,默认使用关键词模糊匹配模式。")
print("\n[关键词模糊匹配模式]")
keyword = input("请输入匹配关键词: ").strip()
if not keyword:
print("\n[!] 输入为空,操作已取消。")
return
print("-" * 40)
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
if confirm != 'y':
confirm = (
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
.strip()
.lower()
)
if confirm != "y":
print("\n[!] 已取消删除操作。")
return
@@ -117,7 +123,7 @@ async def interactive_delete():
try:
# 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.delete(keyword, exact_match=exact_match)
if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}")
print(f" 删除条数: {result.get('deleted_count', 0)}")
@@ -129,6 +135,7 @@ async def interactive_delete():
print(f"\n[×] 发生异常: {e}")
logger.error(f"delete 异常: {e}", exc_info=True)
async def interactive_clear():
"""交互式清空知识库"""
print("\n" + "=" * 40)
@@ -141,40 +148,45 @@ async def interactive_clear():
print(" - 整个知识图谱")
print(" - 此操作不可恢复!")
print("-" * 40)
# 双重确认
confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip()
if confirm1 != "YES":
print("\n[!] 已取消清空操作。")
return
print("\n" + "=" * 40)
confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip()
if confirm2 != "CLEAR":
print("\n[!] 已取消清空操作。")
return
print("\n[进度] 正在清空知识库...")
try:
# 使用 lpmm_ops.py 中的接口
result = await lpmm_ops.clear_all()
if result["status"] == "success":
print(f"\n[√] 成功:{result['message']}")
stats = result.get("stats", {})
before = stats.get("before", {})
after = stats.get("after", {})
print("\n[统计信息]")
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
print(
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
)
print(
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
)
else:
print(f"\n[×] 失败:{result['message']}")
except Exception as e:
print(f"\n[×] 发生异常: {e}")
logger.error(f"clear_all 异常: {e}", exc_info=True)
async def interactive_search():
"""交互式查询知识"""
print("\n" + "=" * 40)
@@ -182,25 +194,25 @@ async def interactive_search():
print("=" * 40)
print("说明:输入查询问题或关键词,系统会返回相关的知识段落。")
print("-" * 40)
# 确保 LPMM 已初始化
if not global_config.lpmm_knowledge.enable:
print("\n[!] 警告LPMM 知识库在配置中未启用。")
return
try:
lpmm_start_up()
except Exception as e:
print(f"\n[!] LPMM 初始化失败: {e}")
logger.error(f"LPMM 初始化失败: {e}", exc_info=True)
return
query = input("请输入查询问题或关键词: ").strip()
if not query:
print("\n[!] 查询内容为空,操作已取消。")
return
# 询问返回条数
print("-" * 40)
limit_str = input("希望返回的相关知识条数默认3直接回车使用默认值: ").strip()
@@ -210,11 +222,11 @@ async def interactive_search():
except ValueError:
limit = 3
print("[!] 输入无效,使用默认值 3。")
print("\n[进度] 正在查询知识库...")
try:
result = await query_lpmm_knowledge(query, limit=limit)
print("\n" + "=" * 60)
print("[查询结果]")
print("=" * 60)
@@ -224,6 +236,7 @@ async def interactive_search():
print(f"\n[×] 查询失败: {e}")
logger.error(f"查询异常: {e}", exc_info=True)
async def main():
"""主循环"""
while True:
@@ -236,9 +249,9 @@ async def main():
print("║ 4. 清空知识库 (Clear All) ⚠️ ║")
print("║ 0. 退出 (Exit) ║")
print("" + "" * 38 + "")
choice = input("请选择操作编号: ").strip()
if choice == "1":
await interactive_add()
elif choice == "2":
@@ -253,6 +266,7 @@ async def main():
else:
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
if __name__ == "__main__":
try:
# 运行主循环
@@ -262,4 +276,3 @@ if __name__ == "__main__":
except Exception as e:
print(f"\n[!] 程序运行出错: {e}")
logger.error(f"Main loop 异常: {e}", exc_info=True)

View File

@@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
txt_files = list(raw_dir.glob("*.txt"))
if not txt_files:
msg = (
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
"info_extraction 可能立即退出或无数据可处理。"
)
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件info_extraction 可能立即退出或无数据可处理。"
print(msg)
if non_interactive:
logger.error(
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
)
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
return False
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
return cont == "y"
@@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
json_files = list(openie_dir.glob("*.json"))
if not json_files:
msg = (
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
"import_openie 可能会因为找不到批次而失败。"
)
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件import_openie 可能会因为找不到批次而失败。"
print(msg)
if non_interactive:
logger.error(
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
)
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
return False
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
return cont == "y"
@@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
try:
if not getattr(global_config.lpmm_knowledge, "enable", False):
print(
"[WARN] 当前配置 lpmm_knowledge.enable = false"
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
)
print("[WARN] 当前配置 lpmm_knowledge.enable = false刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
except Exception:
# 配置异常时不阻断主流程,仅忽略提示
pass
@@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
if action == "prepare_raw":
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
sha_list, raw_data = load_raw_data()
print(
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
elif action == "info_extract":
if not _check_before_info_extract("--non-interactive" in extra_args):
print("已根据用户选择,取消执行信息提取。")
@@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
logger.info("开始 full_import预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
sha_list, raw_data = load_raw_data()
print(
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
f"去重后哈希数 {len(sha_list)}"
)
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}")
non_interactive = "--non-interactive" in extra_args
if not _check_before_info_extract(non_interactive):
print("已根据用户选择,取消 full_import信息提取阶段被取消")
@@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
)
# 快速选项:按推荐方式清理所有相关实体/关系
quick_all = input(
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
).strip().lower()
quick_all = (
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
)
if quick_all in ("", "y", "yes"):
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
else:
@@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
def _interactive_build_batch_inspect_args() -> List[str]:
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
path = _interactive_choose_openie_file(
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
)
path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
if not path:
return []
return ["--openie-file", path]
@@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
def _interactive_build_test_args() -> List[str]:
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
print(
"\n[TEST] 你可以:\n"
"- 直接回车使用内置的默认测试用例;\n"
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
)
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
if not query:
return []
@@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
new_dim = input(
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
).strip()
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
if new_dim and not new_dim.isdigit():
print("输入的维度不是纯数字,已取消操作。")
return
@@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
if __name__ == "__main__":
main()

View File

@@ -28,53 +28,55 @@ from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval")
# 使用 importlib 动态导入,避免循环导入问题
def _import_memory_retrieval():
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
try:
# 先导入 prompt_builder检查 prompt 是否已经初始化
from src.chat.utils.prompt_builder import global_prompt_manager
# 检查 memory_retrieval 相关的 prompt 是否已经注册
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
module_name = "src.memory_system.memory_retrieval"
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
if prompt_already_init and module_name in sys.modules:
existing_module = sys.modules[module_name]
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
if hasattr(existing_module, "init_memory_retrieval_prompt"):
return (
existing_module.init_memory_retrieval_prompt,
existing_module._react_agent_solve_question,
existing_module._process_single_question,
)
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
if module_name in sys.modules:
existing_module = sys.modules[module_name]
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
# 模块部分初始化,移除它
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
del sys.modules[module_name]
# 清理可能相关的部分初始化模块
keys_to_remove = []
for key in sys.modules.keys():
if key.startswith('src.memory_system.') and key != 'src.memory_system':
if key.startswith("src.memory_system.") and key != "src.memory_system":
keys_to_remove.append(key)
for key in keys_to_remove:
try:
del sys.modules[key]
except KeyError:
pass
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
try:
# 先导入可能触发循环导入的模块,让它们完成初始化
import src.config.config
import src.chat.utils.prompt_builder
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
# 如果它们已经导入,就确保它们完全初始化
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval
@@ -89,11 +91,11 @@ def _import_memory_retrieval():
pass # 如果导入失败,继续
except Exception as e:
logger.warning(f"预加载依赖模块时出现警告: {e}")
# 现在尝试导入 memory_retrieval
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
memory_retrieval_module = importlib.import_module(module_name)
return (
memory_retrieval_module.init_memory_retrieval_prompt,
memory_retrieval_module._react_agent_solve_question,
@@ -126,16 +128,16 @@ def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStrea
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
"""获取从指定时间开始的token使用情况
Args:
start_time: 开始时间戳
Returns:
包含token使用统计的字典
"""
try:
start_datetime = datetime.fromtimestamp(start_time)
# 查询从开始时间到现在的所有memory相关的token使用记录
records = (
LLMUsage.select()
@@ -150,21 +152,21 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
)
.order_by(LLMUsage.timestamp.asc())
)
total_prompt_tokens = 0
total_completion_tokens = 0
total_tokens = 0
total_cost = 0.0
request_count = 0
model_usage = {} # 按模型统计
for record in records:
total_prompt_tokens += record.prompt_tokens or 0
total_completion_tokens += record.completion_tokens or 0
total_tokens += record.total_tokens or 0
total_cost += record.cost or 0.0
request_count += 1
# 按模型统计
model_name = record.model_name or "unknown"
if model_name not in model_usage:
@@ -180,7 +182,7 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
model_usage[model_name]["cost"] += record.cost or 0.0
model_usage[model_name]["request_count"] += 1
return {
"total_prompt_tokens": total_prompt_tokens,
"total_completion_tokens": total_completion_tokens,
@@ -205,25 +207,25 @@ def format_thinking_steps(thinking_steps: list) -> str:
"""格式化思考步骤为可读字符串"""
if not thinking_steps:
return "无思考步骤"
lines = []
for step in thinking_steps:
iteration = step.get("iteration", "?")
thought = step.get("thought", "")
actions = step.get("actions", [])
observations = step.get("observations", [])
lines.append(f"\n--- 迭代 {iteration} ---")
if thought:
lines.append(f"思考: {thought[:200]}...")
if actions:
lines.append("行动:")
for action in actions:
action_type = action.get("action_type", "unknown")
action_params = action.get("action_params", {})
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
if observations:
lines.append("观察:")
for obs in observations:
@@ -231,7 +233,7 @@ def format_thinking_steps(thinking_steps: list) -> str:
if len(str(obs)) > 200:
obs_str += "..."
lines.append(f" - {obs_str}")
return "\n".join(lines)
@@ -242,31 +244,32 @@ async def test_memory_retrieval(
max_iterations: Optional[int] = None,
) -> Dict[str, Any]:
"""测试记忆检索功能
Args:
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
max_iterations: 最大迭代次数
Returns:
包含测试结果的字典
"""
print("\n" + "=" * 80)
print(f"[测试] 记忆检索测试")
print("[测试] 记忆检索测试")
print(f"[问题] {question}")
print("=" * 80)
# 记录开始时间
start_time = time.time()
# 延迟导入并初始化记忆检索prompt这会自动加载 global_config
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
try:
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
# 检查 prompt 是否已经初始化,避免重复初始化
from src.chat.utils.prompt_builder import global_prompt_manager
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
init_memory_retrieval_prompt()
else:
@@ -274,24 +277,24 @@ async def test_memory_retrieval(
except Exception as e:
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
raise
# 获取 global_config此时应该已经加载
from src.config.config import global_config
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
timeout = global_config.memory.agent_timeout_seconds
print(f"\n[配置]")
print("\n[配置]")
print(f" 最大迭代次数: {max_iterations}")
print(f" 超时时间: {timeout}")
print(f" 聊天ID: {chat_id}")
# 执行检索
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
@@ -299,14 +302,14 @@ async def test_memory_retrieval(
timeout=timeout,
initial_info="",
)
# 记录结束时间
end_time = time.time()
elapsed_time = end_time - start_time
# 获取token使用情况
token_usage = get_token_usage_since(start_time)
# 构建结果
result = {
"question": question,
@@ -318,41 +321,41 @@ async def test_memory_retrieval(
"iteration_count": len(thinking_steps),
"token_usage": token_usage,
}
# 输出结果
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
print(f"\n[结果]")
print("\n[结果]")
print(f" 是否找到答案: {'' if found_answer else ''}")
if found_answer and answer:
print(f" 答案: {answer}")
else:
print(f" 答案: (未找到答案)")
print(" 答案: (未找到答案)")
print(f" 是否超时: {'' if is_timeout else ''}")
print(f" 迭代次数: {len(thinking_steps)}")
print(f" 总耗时: {elapsed_time:.2f}")
print(f"\n[Token使用情况]")
print("\n[Token使用情况]")
print(f" 总请求数: {token_usage['request_count']}")
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
print(f" 总Tokens: {token_usage['total_tokens']:,}")
print(f" 总成本: ${token_usage['total_cost']:.6f}")
if token_usage['model_usage']:
print(f"\n[按模型统计]")
for model_name, usage in token_usage['model_usage'].items():
if token_usage["model_usage"]:
print("\n[按模型统计]")
for model_name, usage in token_usage["model_usage"].items():
print(f" {model_name}:")
print(f" 请求数: {usage['request_count']}")
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
print(f" Completion Tokens: {usage['completion_tokens']:,}")
print(f" 总Tokens: {usage['total_tokens']:,}")
print(f" 成本: ${usage['cost']:.6f}")
print(f"\n[迭代详情]")
print("\n[迭代详情]")
print(format_thinking_steps(thinking_steps))
print("\n" + "=" * 80)
return result
@@ -375,12 +378,12 @@ def main() -> None:
"-o",
help="将结果保存到JSON文件可选",
)
args = parser.parse_args()
# 初始化日志(使用较低的详细程度,避免输出过多日志)
initialize_logging(verbose=False)
# 交互式输入问题
print("\n" + "=" * 80)
print("记忆检索测试工具")
@@ -389,7 +392,7 @@ def main() -> None:
if not question:
print("错误: 问题不能为空")
return
# 交互式输入最大迭代次数
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
max_iterations = None
@@ -402,7 +405,7 @@ def main() -> None:
except ValueError:
print("警告: 无效的迭代次数,将使用配置默认值")
max_iterations = None
# 连接数据库
try:
db.connect(reuse_if_open=True)
@@ -410,7 +413,7 @@ def main() -> None:
logger.error(f"数据库连接失败: {e}")
print(f"错误: 数据库连接失败: {e}")
return
# 运行测试
try:
result = asyncio.run(
@@ -421,7 +424,7 @@ def main() -> None:
max_iterations=max_iterations,
)
)
# 如果指定了输出文件,保存结果
if args.output:
# 将thinking_steps转换为可序列化的格式
@@ -429,7 +432,7 @@ def main() -> None:
with open(args.output, "w", encoding="utf-8") as f:
json.dump(output_result, f, ensure_ascii=False, indent=2)
print(f"\n[结果已保存] {args.output}")
except KeyboardInterrupt:
print("\n\n[中断] 用户中断测试")
except Exception as e:
@@ -444,4 +447,3 @@ def main() -> None:
if __name__ == "__main__":
main()