Ruff Format
This commit is contained in:
@@ -38,10 +38,10 @@ def parse_datetime(dt_str: str) -> datetime | None:
|
||||
def analyze_single_file(file_path: str) -> Dict:
|
||||
"""
|
||||
分析单个JSON文件的统计信息
|
||||
|
||||
|
||||
Args:
|
||||
file_path: JSON文件路径
|
||||
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
@@ -65,40 +65,40 @@ def analyze_single_file(file_path: str) -> Dict:
|
||||
"has_reason": False,
|
||||
"reason_count": 0,
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
# 基本信息
|
||||
stats["last_updated"] = data.get("last_updated")
|
||||
stats["total_count"] = data.get("total_count", 0)
|
||||
|
||||
|
||||
results = data.get("manual_results", [])
|
||||
stats["actual_count"] = len(results)
|
||||
|
||||
|
||||
if not results:
|
||||
return stats
|
||||
|
||||
|
||||
# 统计通过/不通过
|
||||
suitable_count = sum(1 for r in results if r.get("suitable") is True)
|
||||
unsuitable_count = sum(1 for r in results if r.get("suitable") is False)
|
||||
stats["suitable_count"] = suitable_count
|
||||
stats["unsuitable_count"] = unsuitable_count
|
||||
stats["suitable_rate"] = (suitable_count / len(results) * 100) if results else 0.0
|
||||
|
||||
|
||||
# 统计唯一的(situation, style)对
|
||||
pairs: Set[Tuple[str, str]] = set()
|
||||
for r in results:
|
||||
if "situation" in r and "style" in r:
|
||||
pairs.add((r["situation"], r["style"]))
|
||||
stats["unique_pairs"] = len(pairs)
|
||||
|
||||
|
||||
# 统计评估者
|
||||
for r in results:
|
||||
evaluator = r.get("evaluator", "unknown")
|
||||
stats["evaluators"][evaluator] += 1
|
||||
|
||||
|
||||
# 统计评估时间
|
||||
evaluation_dates = []
|
||||
for r in results:
|
||||
@@ -107,7 +107,7 @@ def analyze_single_file(file_path: str) -> Dict:
|
||||
dt = parse_datetime(evaluated_at)
|
||||
if dt:
|
||||
evaluation_dates.append(dt)
|
||||
|
||||
|
||||
stats["evaluation_dates"] = evaluation_dates
|
||||
if evaluation_dates:
|
||||
min_date = min(evaluation_dates)
|
||||
@@ -115,18 +115,18 @@ def analyze_single_file(file_path: str) -> Dict:
|
||||
stats["date_range"] = {
|
||||
"start": min_date.isoformat(),
|
||||
"end": max_date.isoformat(),
|
||||
"duration_days": (max_date - min_date).days + 1
|
||||
"duration_days": (max_date - min_date).days + 1,
|
||||
}
|
||||
|
||||
|
||||
# 检查字段存在性
|
||||
stats["has_expression_id"] = any("expression_id" in r for r in results)
|
||||
stats["has_reason"] = any(r.get("reason") for r in results)
|
||||
stats["reason_count"] = sum(1 for r in results if r.get("reason"))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
stats["error"] = str(e)
|
||||
logger.error(f"分析文件 {file_name} 时出错: {e}")
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@@ -136,57 +136,57 @@ def print_file_stats(stats: Dict, index: int = None):
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"{prefix}文件: {stats['file_name']}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
|
||||
if stats["error"]:
|
||||
print(f"✗ 错误: {stats['error']}")
|
||||
return
|
||||
|
||||
|
||||
print(f"文件路径: {stats['file_path']}")
|
||||
print(f"文件大小: {stats['file_size']:,} 字节 ({stats['file_size'] / 1024:.2f} KB)")
|
||||
|
||||
|
||||
if stats["last_updated"]:
|
||||
print(f"最后更新: {stats['last_updated']}")
|
||||
|
||||
|
||||
print("\n【记录统计】")
|
||||
print(f" 文件中的 total_count: {stats['total_count']}")
|
||||
print(f" 实际记录数: {stats['actual_count']}")
|
||||
|
||||
if stats['total_count'] != stats['actual_count']:
|
||||
diff = stats['total_count'] - stats['actual_count']
|
||||
|
||||
if stats["total_count"] != stats["actual_count"]:
|
||||
diff = stats["total_count"] - stats["actual_count"]
|
||||
print(f" ⚠️ 数量不一致,差值: {diff:+d}")
|
||||
|
||||
|
||||
print("\n【评估结果统计】")
|
||||
print(f" 通过 (suitable=True): {stats['suitable_count']} 条 ({stats['suitable_rate']:.2f}%)")
|
||||
print(f" 不通过 (suitable=False): {stats['unsuitable_count']} 条 ({100 - stats['suitable_rate']:.2f}%)")
|
||||
|
||||
|
||||
print("\n【唯一性统计】")
|
||||
print(f" 唯一 (situation, style) 对: {stats['unique_pairs']} 条")
|
||||
if stats['actual_count'] > 0:
|
||||
duplicate_count = stats['actual_count'] - stats['unique_pairs']
|
||||
duplicate_rate = (duplicate_count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["actual_count"] > 0:
|
||||
duplicate_count = stats["actual_count"] - stats["unique_pairs"]
|
||||
duplicate_rate = (duplicate_count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" 重复记录: {duplicate_count} 条 ({duplicate_rate:.2f}%)")
|
||||
|
||||
|
||||
print("\n【评估者统计】")
|
||||
if stats['evaluators']:
|
||||
for evaluator, count in stats['evaluators'].most_common():
|
||||
rate = (count / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["evaluators"]:
|
||||
for evaluator, count in stats["evaluators"].most_common():
|
||||
rate = (count / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" {evaluator}: {count} 条 ({rate:.2f}%)")
|
||||
else:
|
||||
print(" 无评估者信息")
|
||||
|
||||
|
||||
print("\n【时间统计】")
|
||||
if stats['date_range']:
|
||||
if stats["date_range"]:
|
||||
print(f" 最早评估时间: {stats['date_range']['start']}")
|
||||
print(f" 最晚评估时间: {stats['date_range']['end']}")
|
||||
print(f" 评估时间跨度: {stats['date_range']['duration_days']} 天")
|
||||
else:
|
||||
print(" 无时间信息")
|
||||
|
||||
|
||||
print("\n【字段统计】")
|
||||
print(f" 包含 expression_id: {'是' if stats['has_expression_id'] else '否'}")
|
||||
print(f" 包含 reason: {'是' if stats['has_reason'] else '否'}")
|
||||
if stats['has_reason']:
|
||||
rate = (stats['reason_count'] / stats['actual_count'] * 100) if stats['actual_count'] > 0 else 0
|
||||
if stats["has_reason"]:
|
||||
rate = (stats["reason_count"] / stats["actual_count"] * 100) if stats["actual_count"] > 0 else 0
|
||||
print(f" 有理由的记录: {stats['reason_count']} 条 ({rate:.2f}%)")
|
||||
|
||||
|
||||
@@ -195,35 +195,35 @@ def print_summary(all_stats: List[Dict]):
|
||||
print(f"\n{'=' * 80}")
|
||||
print("汇总统计")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
|
||||
total_files = len(all_stats)
|
||||
valid_files = [s for s in all_stats if not s.get("error")]
|
||||
error_files = [s for s in all_stats if s.get("error")]
|
||||
|
||||
|
||||
print("\n【文件统计】")
|
||||
print(f" 总文件数: {total_files}")
|
||||
print(f" 成功解析: {len(valid_files)}")
|
||||
print(f" 解析失败: {len(error_files)}")
|
||||
|
||||
|
||||
if error_files:
|
||||
print("\n 失败文件列表:")
|
||||
for stats in error_files:
|
||||
print(f" - {stats['file_name']}: {stats['error']}")
|
||||
|
||||
|
||||
if not valid_files:
|
||||
print("\n没有成功解析的文件")
|
||||
return
|
||||
|
||||
|
||||
# 汇总记录统计
|
||||
total_records = sum(s['actual_count'] for s in valid_files)
|
||||
total_suitable = sum(s['suitable_count'] for s in valid_files)
|
||||
total_unsuitable = sum(s['unsuitable_count'] for s in valid_files)
|
||||
total_records = sum(s["actual_count"] for s in valid_files)
|
||||
total_suitable = sum(s["suitable_count"] for s in valid_files)
|
||||
total_unsuitable = sum(s["unsuitable_count"] for s in valid_files)
|
||||
total_unique_pairs = set()
|
||||
|
||||
|
||||
# 收集所有唯一的(situation, style)对
|
||||
for stats in valid_files:
|
||||
try:
|
||||
with open(stats['file_path'], "r", encoding="utf-8") as f:
|
||||
with open(stats["file_path"], "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
results = data.get("manual_results", [])
|
||||
for r in results:
|
||||
@@ -231,23 +231,31 @@ def print_summary(all_stats: List[Dict]):
|
||||
total_unique_pairs.add((r["situation"], r["style"]))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
print("\n【记录汇总】")
|
||||
print(f" 总记录数: {total_records:,} 条")
|
||||
print(f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)" if total_records > 0 else " 通过: 0 条")
|
||||
print(f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)" if total_records > 0 else " 不通过: 0 条")
|
||||
print(
|
||||
f" 通过: {total_suitable:,} 条 ({total_suitable / total_records * 100:.2f}%)"
|
||||
if total_records > 0
|
||||
else " 通过: 0 条"
|
||||
)
|
||||
print(
|
||||
f" 不通过: {total_unsuitable:,} 条 ({total_unsuitable / total_records * 100:.2f}%)"
|
||||
if total_records > 0
|
||||
else " 不通过: 0 条"
|
||||
)
|
||||
print(f" 唯一 (situation, style) 对: {len(total_unique_pairs):,} 条")
|
||||
|
||||
|
||||
if total_records > 0:
|
||||
duplicate_count = total_records - len(total_unique_pairs)
|
||||
duplicate_rate = (duplicate_count / total_records * 100) if total_records > 0 else 0
|
||||
print(f" 重复记录: {duplicate_count:,} 条 ({duplicate_rate:.2f}%)")
|
||||
|
||||
|
||||
# 汇总评估者统计
|
||||
all_evaluators = Counter()
|
||||
for stats in valid_files:
|
||||
all_evaluators.update(stats['evaluators'])
|
||||
|
||||
all_evaluators.update(stats["evaluators"])
|
||||
|
||||
print("\n【评估者汇总】")
|
||||
if all_evaluators:
|
||||
for evaluator, count in all_evaluators.most_common():
|
||||
@@ -255,12 +263,12 @@ def print_summary(all_stats: List[Dict]):
|
||||
print(f" {evaluator}: {count:,} 条 ({rate:.2f}%)")
|
||||
else:
|
||||
print(" 无评估者信息")
|
||||
|
||||
|
||||
# 汇总时间范围
|
||||
all_dates = []
|
||||
for stats in valid_files:
|
||||
all_dates.extend(stats['evaluation_dates'])
|
||||
|
||||
all_dates.extend(stats["evaluation_dates"])
|
||||
|
||||
if all_dates:
|
||||
min_date = min(all_dates)
|
||||
max_date = max(all_dates)
|
||||
@@ -268,9 +276,9 @@ def print_summary(all_stats: List[Dict]):
|
||||
print(f" 最早评估时间: {min_date.isoformat()}")
|
||||
print(f" 最晚评估时间: {max_date.isoformat()}")
|
||||
print(f" 总时间跨度: {(max_date - min_date).days + 1} 天")
|
||||
|
||||
|
||||
# 文件大小汇总
|
||||
total_size = sum(s['file_size'] for s in valid_files)
|
||||
total_size = sum(s["file_size"] for s in valid_files)
|
||||
avg_size = total_size / len(valid_files) if valid_files else 0
|
||||
print("\n【文件大小汇总】")
|
||||
print(f" 总大小: {total_size:,} 字节 ({total_size / 1024 / 1024:.2f} MB)")
|
||||
@@ -282,35 +290,35 @@ def main():
|
||||
logger.info("=" * 80)
|
||||
logger.info("开始分析评估结果统计信息")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
if not os.path.exists(TEMP_DIR):
|
||||
print(f"\n✗ 错误:未找到temp目录: {TEMP_DIR}")
|
||||
logger.error(f"未找到temp目录: {TEMP_DIR}")
|
||||
return
|
||||
|
||||
|
||||
# 查找所有JSON文件
|
||||
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
|
||||
|
||||
|
||||
if not json_files:
|
||||
print(f"\n✗ 错误:temp目录下未找到JSON文件: {TEMP_DIR}")
|
||||
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
|
||||
return
|
||||
|
||||
|
||||
json_files.sort() # 按文件名排序
|
||||
|
||||
|
||||
print(f"\n找到 {len(json_files)} 个JSON文件")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
# 分析每个文件
|
||||
all_stats = []
|
||||
for i, json_file in enumerate(json_files, 1):
|
||||
stats = analyze_single_file(json_file)
|
||||
all_stats.append(stats)
|
||||
print_file_stats(stats, index=i)
|
||||
|
||||
|
||||
# 打印汇总统计
|
||||
print_summary(all_stats)
|
||||
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print("分析完成")
|
||||
print(f"{'=' * 80}")
|
||||
@@ -318,5 +326,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
@@ -171,7 +171,9 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
if not args.raw_index:
|
||||
logger.info(f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3")
|
||||
logger.info(
|
||||
f"{raw_path} 共解析出 {len(paragraphs)} 个段落,请通过 --raw-index 指定要删除的段落,例如 --raw-index 1,3"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# 解析索引列表(1-based)
|
||||
|
||||
@@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.
|
||||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
"""
|
||||
加载已有的评估结果
|
||||
|
||||
|
||||
Returns:
|
||||
(已有结果列表, 已评估的项目(situation, style)元组集合)
|
||||
"""
|
||||
if not os.path.exists(COUNT_ANALYSIS_FILE):
|
||||
return [], set()
|
||||
|
||||
|
||||
try:
|
||||
with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
@@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
def save_results(evaluation_results: List[Dict]):
|
||||
"""
|
||||
保存评估结果到文件
|
||||
|
||||
|
||||
Args:
|
||||
evaluation_results: 评估结果列表
|
||||
"""
|
||||
try:
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
|
||||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(evaluation_results),
|
||||
"evaluation_results": evaluation_results
|
||||
"evaluation_results": evaluation_results,
|
||||
}
|
||||
|
||||
|
||||
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
|
||||
print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
|
||||
except Exception as e:
|
||||
@@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]):
|
||||
print(f"\n✗ 保存评估结果失败: {e}")
|
||||
|
||||
|
||||
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
    """
    Pick the expressions that should be sent for evaluation.

    Every not-yet-evaluated expression whose ``count`` is greater than 1 is
    taken, plus a random sample of ``count == 1`` expressions twice that size
    (or all of them when fewer are available). The combined list is shuffled.

    Args:
        evaluated_pairs: (situation, style) pairs that were already evaluated
            and must be skipped; ``None`` is treated as an empty set.

    Returns:
        The selected expressions, or an empty list when nothing is left to
        evaluate or when any error occurs while selecting.
    """
    already_done = evaluated_pairs if evaluated_pairs is not None else set()

    try:
        # Load every expression record up front.
        candidates = list(Expression.select())
        if not candidates:
            logger.warning("数据库中没有表达方式记录")
            return []

        # Drop everything whose (situation, style) pair was evaluated before.
        pending = []
        for expr in candidates:
            if (expr.situation, expr.style) in already_done:
                continue
            pending.append(expr)
        if not pending:
            logger.warning("所有项目都已评估完成")
            return []

        # Partition by count; records with any other count value are ignored.
        singles = []
        repeated = []
        for expr in pending:
            if expr.count == 1:
                singles.append(expr)
            elif expr.count > 1:
                repeated.append(expr)

        logger.info(f"未评估项目中:count=1的有{len(singles)}条,count>1的有{len(repeated)}条")

        # Target twice as many count == 1 items as count > 1 items, capped by
        # how many count == 1 items actually exist.
        wanted = 2 * len(repeated)
        if len(singles) < wanted:
            logger.warning(
                f"count=1的项目只有{len(singles)}条,少于需要的{wanted}条,将选择全部{len(singles)}条"
            )
            wanted = len(singles)

        if singles and wanted > 0:
            sampled = random.sample(singles, wanted)
        else:
            sampled = []

        # All count > 1 items first, then the sample; shuffle to mix them.
        chosen = [*repeated, *sampled]
        random.shuffle(chosen)

        logger.info(
            f"已选择{len(chosen)}条表达方式:count>1的有{len(repeated)}条(全部),count=1的有{len(sampled)}条(2倍)"
        )
        return chosen

    except Exception as e:
        # Best effort: log the failure with its traceback and select nothing.
        logger.error(f"选择表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
@@ -155,11 +155,11 @@ def select_expressions_for_evaluation(
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
@@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
llm: LLM请求实例
|
||||
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
try:
|
||||
prompt = create_evaluation_prompt(situation, style)
|
||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
@@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
|
||||
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
|
||||
return False, f"评估过程出错: {str(e)}", str(e)
|
||||
@@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
|
||||
"""
|
||||
使用LLM评估单个表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
llm: LLM请求实例
|
||||
|
||||
|
||||
Returns:
|
||||
评估结果字典
|
||||
"""
|
||||
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
|
||||
|
||||
logger.info(
|
||||
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
|
||||
)
|
||||
|
||||
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
|
||||
|
||||
|
||||
if error:
|
||||
suitable = False
|
||||
|
||||
|
||||
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
|
||||
|
||||
|
||||
return {
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
@@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
||||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
"""
|
||||
对评估结果进行统计分析
|
||||
|
||||
|
||||
Args:
|
||||
evaluation_results: 评估结果列表
|
||||
"""
|
||||
if not evaluation_results:
|
||||
print("\n没有评估结果可供分析")
|
||||
return
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("统计分析结果")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 按count分组统计
|
||||
count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
|
||||
|
||||
|
||||
for result in evaluation_results:
|
||||
count = result.get("count", 1)
|
||||
suitable = result.get("suitable", False)
|
||||
@@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
count_groups[count]["suitable"] += 1
|
||||
else:
|
||||
count_groups[count]["unsuitable"] += 1
|
||||
|
||||
|
||||
# 显示每个count的统计
|
||||
print("\n【按count分组统计】")
|
||||
print("-" * 60)
|
||||
@@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
suitable = group["suitable"]
|
||||
unsuitable = group["unsuitable"]
|
||||
pass_rate = (suitable / total * 100) if total > 0 else 0
|
||||
|
||||
|
||||
print(f"Count = {count}:")
|
||||
print(f" 总数: {total}")
|
||||
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
|
||||
# 比较count=1和count>1
|
||||
count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
|
||||
count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
|
||||
|
||||
|
||||
for result in evaluation_results:
|
||||
count = result.get("count", 1)
|
||||
suitable = result.get("suitable", False)
|
||||
|
||||
|
||||
if count == 1:
|
||||
count_eq1_group["total"] += 1
|
||||
if suitable:
|
||||
@@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
count_gt1_group["suitable"] += 1
|
||||
else:
|
||||
count_gt1_group["unsuitable"] += 1
|
||||
|
||||
|
||||
print("\n【Count=1 vs Count>1 对比】")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
eq1_total = count_eq1_group["total"]
|
||||
eq1_suitable = count_eq1_group["suitable"]
|
||||
eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
|
||||
|
||||
|
||||
gt1_total = count_gt1_group["total"]
|
||||
gt1_suitable = count_gt1_group["suitable"]
|
||||
gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
|
||||
|
||||
|
||||
print("Count = 1:")
|
||||
print(f" 总数: {eq1_total}")
|
||||
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
|
||||
print()
|
||||
print("Count > 1:")
|
||||
print(f" 总数: {gt1_total}")
|
||||
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
|
||||
# 进行卡方检验(简化版,使用2x2列联表)
|
||||
if eq1_total > 0 and gt1_total > 0:
|
||||
print("【统计显著性检验】")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
# 构建2x2列联表
|
||||
# 通过 不通过
|
||||
# count=1 a b
|
||||
@@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
b = eq1_total - eq1_suitable
|
||||
c = gt1_suitable
|
||||
d = gt1_total - gt1_suitable
|
||||
|
||||
|
||||
# 计算卡方统计量(简化版,使用Pearson卡方检验)
|
||||
n = eq1_total + gt1_total
|
||||
if n > 0:
|
||||
@@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
e_b = (eq1_total * (b + d)) / n
|
||||
e_c = (gt1_total * (a + c)) / n
|
||||
e_d = (gt1_total * (b + d)) / n
|
||||
|
||||
|
||||
# 检查期望频数是否足够大(卡方检验要求每个期望频数>=5)
|
||||
min_expected = min(e_a, e_b, e_c, e_d)
|
||||
if min_expected < 5:
|
||||
print("警告:期望频数小于5,卡方检验可能不准确")
|
||||
print("建议使用Fisher精确检验")
|
||||
|
||||
|
||||
# 计算卡方值
|
||||
chi_square = 0
|
||||
if e_a > 0:
|
||||
@@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
chi_square += ((c - e_c) ** 2) / e_c
|
||||
if e_d > 0:
|
||||
chi_square += ((d - e_d) ** 2) / e_d
|
||||
|
||||
|
||||
# 自由度 = (行数-1) * (列数-1) = 1
|
||||
df = 1
|
||||
|
||||
|
||||
# 临界值(α=0.05)
|
||||
chi_square_critical_005 = 3.841
|
||||
chi_square_critical_001 = 6.635
|
||||
|
||||
|
||||
print(f"卡方统计量: {chi_square:.4f}")
|
||||
print(f"自由度: {df}")
|
||||
print(f"临界值 (α=0.05): {chi_square_critical_005}")
|
||||
print(f"临界值 (α=0.01): {chi_square_critical_001}")
|
||||
|
||||
|
||||
if chi_square >= chi_square_critical_001:
|
||||
print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)")
|
||||
elif chi_square >= chi_square_critical_005:
|
||||
print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)")
|
||||
else:
|
||||
print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)")
|
||||
|
||||
|
||||
# 计算差异大小
|
||||
diff = abs(eq1_pass_rate - gt1_pass_rate)
|
||||
print(f"\n合格率差异: {diff:.2f}%")
|
||||
@@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
print("数据不足,无法进行统计检验")
|
||||
else:
|
||||
print("数据不足,无法进行count=1和count>1的对比分析")
|
||||
|
||||
|
||||
# 保存统计分析结果
|
||||
analysis_result = {
|
||||
"analysis_time": datetime.now().isoformat(),
|
||||
"count_groups": {str(k): v for k, v in count_groups.items()},
|
||||
"count_eq1": count_eq1_group,
|
||||
"count_gt1": count_gt1_group,
|
||||
"total_evaluated": len(evaluation_results)
|
||||
"total_evaluated": len(evaluation_results),
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
|
||||
with open(analysis_file, "w", encoding="utf-8") as f:
|
||||
@@ -444,7 +444,7 @@ async def main():
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式按count分组的LLM评估和统计分析")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
@@ -452,97 +452,95 @@ async def main():
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 加载已有评估结果
|
||||
existing_results, evaluated_pairs = load_existing_results()
|
||||
evaluation_results = existing_results.copy()
|
||||
|
||||
|
||||
if evaluated_pairs:
|
||||
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
|
||||
print(f"已评估项目数: {len(evaluated_pairs)}")
|
||||
|
||||
|
||||
# 检查是否需要继续评估(检查是否还有未评估的count>1项目)
|
||||
# 先查询未评估的count>1项目数量
|
||||
try:
|
||||
all_expressions = list(Expression.select())
|
||||
unevaluated_count_gt1 = [
|
||||
expr for expr in all_expressions
|
||||
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
has_unevaluated = len(unevaluated_count_gt1) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"查询未评估项目失败: {e}")
|
||||
has_unevaluated = False
|
||||
|
||||
|
||||
if has_unevaluated:
|
||||
print("\n" + "=" * 60)
|
||||
print("开始LLM评估")
|
||||
print("=" * 60)
|
||||
print("评估结果会自动保存到文件\n")
|
||||
|
||||
|
||||
# 创建LLM实例
|
||||
print("创建LLM实例...")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_count_analysis_llm"
|
||||
request_type="expression_evaluator_count_analysis_llm",
|
||||
)
|
||||
print("✓ LLM实例创建成功\n")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"\n✗ 创建LLM实例失败: {e}")
|
||||
db.close()
|
||||
return
|
||||
|
||||
|
||||
# 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目)
|
||||
expressions = select_expressions_for_evaluation(
|
||||
evaluated_pairs=evaluated_pairs
|
||||
)
|
||||
|
||||
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
|
||||
|
||||
if not expressions:
|
||||
print("\n没有可评估的项目")
|
||||
else:
|
||||
print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
|
||||
print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
|
||||
print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")
|
||||
|
||||
|
||||
batch_results = []
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
print(f"LLM评估进度: {i}/{len(expressions)}")
|
||||
print(f" Situation: {expression.situation}")
|
||||
print(f" Style: {expression.style}")
|
||||
print(f" Count: {expression.count}")
|
||||
|
||||
|
||||
llm_result = await llm_evaluate_expression(expression, llm)
|
||||
|
||||
|
||||
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
|
||||
if llm_result.get('error'):
|
||||
if llm_result.get("error"):
|
||||
print(f" 错误: {llm_result['error']}")
|
||||
print()
|
||||
|
||||
|
||||
batch_results.append(llm_result)
|
||||
# 使用 (situation, style) 作为唯一标识
|
||||
evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
|
||||
|
||||
|
||||
# 添加延迟以避免API限流
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
# 将当前批次结果添加到总结果中
|
||||
evaluation_results.extend(batch_results)
|
||||
|
||||
|
||||
# 保存结果
|
||||
save_results(evaluation_results)
|
||||
else:
|
||||
print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果")
|
||||
|
||||
|
||||
# 进行统计分析
|
||||
if len(evaluation_results) > 0:
|
||||
perform_statistical_analysis(evaluation_results)
|
||||
else:
|
||||
print("\n没有评估结果可供分析")
|
||||
|
||||
|
||||
# 关闭数据库连接
|
||||
try:
|
||||
db.close()
|
||||
@@ -553,4 +551,3 @@ async def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ TEMP_DIR = os.path.join(os.path.dirname(__file__), "temp")
|
||||
def load_manual_results() -> List[Dict]:
|
||||
"""
|
||||
加载人工评估结果(自动读取temp目录下所有JSON文件并合并)
|
||||
|
||||
|
||||
Returns:
|
||||
人工评估结果列表(已去重)
|
||||
"""
|
||||
@@ -42,62 +42,62 @@ def load_manual_results() -> List[Dict]:
|
||||
print("\n✗ 错误:未找到temp目录")
|
||||
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
|
||||
return []
|
||||
|
||||
|
||||
# 查找所有JSON文件
|
||||
json_files = glob.glob(os.path.join(TEMP_DIR, "*.json"))
|
||||
|
||||
|
||||
if not json_files:
|
||||
logger.error(f"temp目录下未找到JSON文件: {TEMP_DIR}")
|
||||
print("\n✗ 错误:temp目录下未找到JSON文件")
|
||||
print(" 请先运行 evaluate_expressions_manual.py 进行人工评估")
|
||||
return []
|
||||
|
||||
|
||||
logger.info(f"找到 {len(json_files)} 个JSON文件")
|
||||
print(f"\n找到 {len(json_files)} 个JSON文件:")
|
||||
for json_file in json_files:
|
||||
print(f" - {os.path.basename(json_file)}")
|
||||
|
||||
|
||||
# 读取并合并所有JSON文件
|
||||
all_results = []
|
||||
seen_pairs: Set[Tuple[str, str]] = set() # 用于去重
|
||||
|
||||
|
||||
for json_file in json_files:
|
||||
try:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
results = data.get("manual_results", [])
|
||||
|
||||
|
||||
# 去重:使用(situation, style)作为唯一标识
|
||||
for result in results:
|
||||
if "situation" not in result or "style" not in result:
|
||||
logger.warning(f"跳过无效数据(缺少必要字段): {result}")
|
||||
continue
|
||||
|
||||
|
||||
pair = (result["situation"], result["style"])
|
||||
if pair not in seen_pairs:
|
||||
seen_pairs.add(pair)
|
||||
all_results.append(result)
|
||||
|
||||
|
||||
logger.info(f"从 {os.path.basename(json_file)} 加载了 {len(results)} 条结果")
|
||||
except Exception as e:
|
||||
logger.error(f"加载文件 {json_file} 失败: {e}")
|
||||
print(f" 警告:加载文件 {os.path.basename(json_file)} 失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"成功合并 {len(all_results)} 条人工评估结果(去重后)")
|
||||
print(f"\n✓ 成功合并 {len(all_results)} 条人工评估结果(已去重)")
|
||||
|
||||
|
||||
return all_results
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
@@ -119,51 +119,50 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
    """
    Run one LLM evaluation for a (situation, style) pair.

    The response is parsed as JSON. If that fails, the first brace-delimited
    fragment containing a "suitable" key is extracted from the response text
    and parsed instead.

    Args:
        situation: The situation text.
        style: The style text.
        llm: The LLM request instance used to generate the response.

    Returns:
        A (suitable, reason, error) tuple. On success ``error`` is ``None``;
        on any failure ``suitable`` is ``False`` and ``error`` holds the
        exception text.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        # Only the response text is used; the accompanying triple is unpacked
        # (so a malformed reply shape still fails here) and then ignored.
        response, (_reasoning, _model_name, _extra) = await llm.generate_response_async(
            prompt=prompt,
            temperature=0.6,
            max_tokens=1024,
        )
        logger.debug(f"LLM响应: {response}")

        # Parse the verdict, falling back to a regex extraction when the
        # response is not pure JSON.
        try:
            verdict = json.loads(response)
        except json.JSONDecodeError as decode_error:
            import re

            fragment = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if fragment is None:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from decode_error
            verdict = json.loads(fragment.group())

        suitable = verdict.get("suitable", False)
        reason = verdict.get("reason", "未提供理由")

        outcome_label = "通过" if suitable else "不通过"
        logger.debug(f"评估结果: {outcome_label}")
        return suitable, reason, None

    except Exception as e:
        # Any failure (request, parsing, unexpected shape) becomes a failed
        # evaluation carrying the error text.
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        message = str(e)
        return False, f"评估过程出错: {message}", message
|
||||
@@ -172,68 +171,68 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||
async def evaluate_expression_llm(situation: str, style: str, llm: LLMRequest) -> Dict:
    """
    Evaluate a single expression with the LLM and package the outcome.

    Args:
        situation: The situation text.
        style: The style text.
        llm: LLM request instance.

    Returns:
        A dict with the keys ``situation``, ``style``, ``suitable``,
        ``reason``, ``error`` and ``evaluator`` (always ``"llm"``).
    """
    logger.info(f"开始评估表达方式: situation={situation}, style={style}")

    suitable, reason, error = await _single_llm_evaluation(situation, style, llm)

    # An evaluation that reported an error never counts as suitable.
    if error:
        suitable = False

    logger.info(f"评估完成: {'通过' if suitable else '不通过'}")

    return {
        "situation": situation,
        "style": style,
        "suitable": suitable,
        "reason": reason,
        "error": error,
        "evaluator": "llm",
    }
|
||||
|
||||
|
||||
def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], method_name: str) -> Dict:
|
||||
"""
|
||||
对比人工评估和LLM评估的结果
|
||||
|
||||
|
||||
Args:
|
||||
manual_results: 人工评估结果列表
|
||||
llm_results: LLM评估结果列表
|
||||
method_name: 评估方法名称(用于标识)
|
||||
|
||||
|
||||
Returns:
|
||||
对比分析结果字典
|
||||
"""
|
||||
# 按(situation, style)建立映射
|
||||
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
|
||||
|
||||
|
||||
total = len(manual_results)
|
||||
matched = 0
|
||||
true_positives = 0
|
||||
true_negatives = 0
|
||||
false_positives = 0
|
||||
false_negatives = 0
|
||||
|
||||
|
||||
for manual_result in manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
|
||||
manual_suitable = manual_result["suitable"]
|
||||
llm_suitable = llm_result["suitable"]
|
||||
|
||||
|
||||
if manual_suitable == llm_suitable:
|
||||
matched += 1
|
||||
|
||||
|
||||
if manual_suitable and llm_suitable:
|
||||
true_positives += 1
|
||||
elif not manual_suitable and not llm_suitable:
|
||||
@@ -242,30 +241,36 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
||||
false_positives += 1
|
||||
elif manual_suitable and not llm_suitable:
|
||||
false_negatives += 1
|
||||
|
||||
|
||||
accuracy = (matched / total * 100) if total > 0 else 0
|
||||
precision = (true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
recall = (true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
precision = (
|
||||
(true_positives / (true_positives + false_positives) * 100) if (true_positives + false_positives) > 0 else 0
|
||||
)
|
||||
recall = (
|
||||
(true_positives / (true_positives + false_negatives) * 100) if (true_positives + false_negatives) > 0 else 0
|
||||
)
|
||||
f1_score = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0
|
||||
specificity = (true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
|
||||
specificity = (
|
||||
(true_negatives / (true_negatives + false_positives) * 100) if (true_negatives + false_positives) > 0 else 0
|
||||
)
|
||||
|
||||
# 计算人工效标的不合适率
|
||||
manual_unsuitable_count = true_negatives + false_positives # 人工评估不合适的总数
|
||||
manual_unsuitable_rate = (manual_unsuitable_count / total * 100) if total > 0 else 0
|
||||
|
||||
|
||||
# 计算经过LLM删除后剩余项目中的不合适率
|
||||
# 在所有项目中,移除LLM判定为不合适的项目后,剩下的项目 = TP + FP(LLM判定为合适的项目)
|
||||
# 在这些剩下的项目中,按人工评定的不合适项目 = FP(人工认为不合适,但LLM认为合适)
|
||||
llm_kept_count = true_positives + false_positives # LLM判定为合适的项目总数(保留的项目)
|
||||
llm_kept_unsuitable_rate = (false_positives / llm_kept_count * 100) if llm_kept_count > 0 else 0
|
||||
|
||||
|
||||
# 两者百分比相减(评估LLM评定修正后的不合适率是否有降低)
|
||||
rate_difference = manual_unsuitable_rate - llm_kept_unsuitable_rate
|
||||
|
||||
|
||||
random_baseline = 50.0
|
||||
accuracy_above_random = accuracy - random_baseline
|
||||
accuracy_improvement_ratio = (accuracy / random_baseline) if random_baseline > 0 else 0
|
||||
|
||||
|
||||
return {
|
||||
"method": method_name,
|
||||
"total": total,
|
||||
@@ -283,29 +288,29 @@ def compare_evaluations(manual_results: List[Dict], llm_results: List[Dict], met
|
||||
"specificity": specificity,
|
||||
"manual_unsuitable_rate": manual_unsuitable_rate,
|
||||
"llm_kept_unsuitable_rate": llm_kept_unsuitable_rate,
|
||||
"rate_difference": rate_difference
|
||||
"rate_difference": rate_difference,
|
||||
}
|
||||
|
||||
|
||||
async def main(count: int | None = None):
|
||||
"""
|
||||
主函数
|
||||
|
||||
|
||||
Args:
|
||||
count: 随机选取的数据条数,如果为None则使用全部数据
|
||||
"""
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式LLM评估")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 1. 加载人工评估结果
|
||||
print("\n步骤1: 加载人工评估结果")
|
||||
manual_results = load_manual_results()
|
||||
if not manual_results:
|
||||
return
|
||||
|
||||
|
||||
print(f"成功加载 {len(manual_results)} 条人工评估结果")
|
||||
|
||||
|
||||
# 如果指定了数量,随机选择指定数量的数据
|
||||
if count is not None:
|
||||
if count <= 0:
|
||||
@@ -317,7 +322,7 @@ async def main(count: int | None = None):
|
||||
random.seed() # 使用系统时间作为随机种子
|
||||
manual_results = random.sample(manual_results, count)
|
||||
print(f"随机选取 {len(manual_results)} 条数据进行评估")
|
||||
|
||||
|
||||
# 验证数据完整性
|
||||
valid_manual_results = []
|
||||
for r in manual_results:
|
||||
@@ -325,62 +330,58 @@ async def main(count: int | None = None):
|
||||
valid_manual_results.append(r)
|
||||
else:
|
||||
logger.warning(f"跳过无效数据: {r}")
|
||||
|
||||
|
||||
if len(valid_manual_results) != len(manual_results):
|
||||
print(f"警告:{len(manual_results) - len(valid_manual_results)} 条数据缺少必要字段,已跳过")
|
||||
|
||||
|
||||
print(f"有效数据: {len(valid_manual_results)} 条")
|
||||
|
||||
|
||||
# 2. 创建LLM实例并评估
|
||||
print("\n步骤2: 创建LLM实例")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_llm"
|
||||
)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_evaluator_llm")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
|
||||
print("\n步骤3: 开始LLM评估")
|
||||
llm_results = []
|
||||
for i, manual_result in enumerate(valid_manual_results, 1):
|
||||
print(f"LLM评估进度: {i}/{len(valid_manual_results)}")
|
||||
llm_results.append(await evaluate_expression_llm(
|
||||
manual_result["situation"],
|
||||
manual_result["style"],
|
||||
llm
|
||||
))
|
||||
llm_results.append(await evaluate_expression_llm(manual_result["situation"], manual_result["style"], llm))
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
# 5. 输出FP和FN项目(在评估结果之前)
|
||||
llm_dict = {(r["situation"], r["style"]): r for r in llm_results}
|
||||
|
||||
|
||||
# 5.1 输出FP项目(人工评估不通过但LLM误判为通过)
|
||||
print("\n" + "=" * 60)
|
||||
print("人工评估不通过但LLM误判为通过的项目(FP - False Positive)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
fp_items = []
|
||||
for manual_result in valid_manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
|
||||
# 人工评估不通过,但LLM评估通过(FP情况)
|
||||
if not manual_result["suitable"] and llm_result["suitable"]:
|
||||
fp_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
|
||||
fp_items.append(
|
||||
{
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
if fp_items:
|
||||
print(f"\n共找到 {len(fp_items)} 条误判项目:\n")
|
||||
for idx, item in enumerate(fp_items, 1):
|
||||
@@ -389,36 +390,38 @@ async def main(count: int | None = None):
|
||||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 不通过 ❌")
|
||||
print("LLM评估: 通过 ✅ (误判)")
|
||||
if item.get('llm_error'):
|
||||
if item.get("llm_error"):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
else:
|
||||
print("\n✓ 没有误判项目(所有人工评估不通过的项目都被LLM正确识别为不通过)")
|
||||
|
||||
|
||||
# 5.2 输出FN项目(人工评估通过但LLM误判为不通过)
|
||||
print("\n" + "=" * 60)
|
||||
print("人工评估通过但LLM误判为不通过的项目(FN - False Negative)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
fn_items = []
|
||||
for manual_result in valid_manual_results:
|
||||
pair = (manual_result["situation"], manual_result["style"])
|
||||
llm_result = llm_dict.get(pair)
|
||||
if llm_result is None:
|
||||
continue
|
||||
|
||||
|
||||
# 人工评估通过,但LLM评估不通过(FN情况)
|
||||
if manual_result["suitable"] and not llm_result["suitable"]:
|
||||
fn_items.append({
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error")
|
||||
})
|
||||
|
||||
fn_items.append(
|
||||
{
|
||||
"situation": manual_result["situation"],
|
||||
"style": manual_result["style"],
|
||||
"manual_suitable": manual_result["suitable"],
|
||||
"llm_suitable": llm_result["suitable"],
|
||||
"llm_reason": llm_result.get("reason", "未提供理由"),
|
||||
"llm_error": llm_result.get("error"),
|
||||
}
|
||||
)
|
||||
|
||||
if fn_items:
|
||||
print(f"\n共找到 {len(fn_items)} 条误删项目:\n")
|
||||
for idx, item in enumerate(fn_items, 1):
|
||||
@@ -427,33 +430,41 @@ async def main(count: int | None = None):
|
||||
print(f"Style: {item['style']}")
|
||||
print("人工评估: 通过 ✅")
|
||||
print("LLM评估: 不通过 ❌ (误删)")
|
||||
if item.get('llm_error'):
|
||||
if item.get("llm_error"):
|
||||
print(f"LLM错误: {item['llm_error']}")
|
||||
print(f"LLM理由: {item['llm_reason']}")
|
||||
print()
|
||||
else:
|
||||
print("\n✓ 没有误删项目(所有人工评估通过的项目都被LLM正确识别为通过)")
|
||||
|
||||
|
||||
# 6. 对比分析并输出结果
|
||||
comparison = compare_evaluations(valid_manual_results, llm_results, "LLM评估")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("评估结果(以人工评估为标准)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 详细评估结果(核心指标优先)
|
||||
print(f"\n--- {comparison['method']} ---")
|
||||
print(f" 总数: {comparison['total']} 条")
|
||||
print()
|
||||
# print(" 【核心能力指标】")
|
||||
print(f" 特定负类召回率: {comparison['specificity']:.2f}% (将不合适项目正确提取出来的能力)")
|
||||
print(f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个")
|
||||
print(
|
||||
f" - 计算: TN / (TN + FP) = {comparison['true_negatives']} / ({comparison['true_negatives']} + {comparison['false_positives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在 {comparison['true_negatives'] + comparison['false_positives']} 个实际不合适的项目中,正确识别出 {comparison['true_negatives']} 个"
|
||||
)
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['specificity'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(f" 召回率: {comparison['recall']:.2f}% (尽可能少的误删合适项目的能力)")
|
||||
print(f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})")
|
||||
print(f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个")
|
||||
print(
|
||||
f" - 计算: TP / (TP + FN) = {comparison['true_positives']} / ({comparison['true_positives']} + {comparison['false_negatives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在 {comparison['true_positives'] + comparison['false_negatives']} 个实际合适的项目中,正确识别出 {comparison['true_positives']} 个"
|
||||
)
|
||||
# print(f" - 随机水平: 50.00% (当前高于随机: {comparison['recall'] - 50.0:+.2f}%)")
|
||||
print()
|
||||
print(" 【其他指标】")
|
||||
@@ -464,12 +475,18 @@ async def main(count: int | None = None):
|
||||
print()
|
||||
print(" 【不合适率分析】")
|
||||
print(f" 人工效标的不合适率: {comparison['manual_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}")
|
||||
print(
|
||||
f" - 计算: (TN + FP) / 总数 = ({comparison['true_negatives']} + {comparison['false_positives']}) / {comparison['total']}"
|
||||
)
|
||||
print(f" - 含义: 在人工评估中,有 {comparison['manual_unsuitable_rate']:.2f}% 的项目被判定为不合适")
|
||||
print()
|
||||
print(f" 经过LLM删除后剩余项目中的不合适率: {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})")
|
||||
print(f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
print(
|
||||
f" - 计算: FP / (TP + FP) = {comparison['false_positives']} / ({comparison['true_positives']} + {comparison['false_positives']})"
|
||||
)
|
||||
print(
|
||||
f" - 含义: 在所有项目中,移除LLM判定为不合适的项目后,在剩下的 {comparison['true_positives'] + comparison['false_positives']} 个项目中,人工认为不合适的项目占 {comparison['llm_kept_unsuitable_rate']:.2f}%"
|
||||
)
|
||||
print()
|
||||
# print(f" 两者百分比差值: {comparison['rate_difference']:+.2f}%")
|
||||
# print(f" - 计算: 人工效标不合适率 - LLM删除后剩余项目不合适率 = {comparison['manual_unsuitable_rate']:.2f}% - {comparison['llm_kept_unsuitable_rate']:.2f}%")
|
||||
@@ -480,21 +497,22 @@ async def main(count: int | None = None):
|
||||
print(f" TN (正确识别为不合适): {comparison['true_negatives']} ⭐")
|
||||
print(f" FP (误判为合适): {comparison['false_positives']} ⚠️")
|
||||
print(f" FN (误删合适项目): {comparison['false_negatives']} ⚠️")
|
||||
|
||||
|
||||
# 7. 保存结果到JSON文件
|
||||
output_file = os.path.join(project_root, "data", "expression_evaluation_llm.json")
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"manual_results": valid_manual_results,
|
||||
"llm_results": llm_results,
|
||||
"comparison": comparison
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
json.dump(
|
||||
{"manual_results": valid_manual_results, "llm_results": llm_results, "comparison": comparison},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
logger.info(f"\n评估结果已保存到: {output_file}")
|
||||
except Exception as e:
|
||||
logger.warning(f"保存结果到文件失败: {e}")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("评估完成")
|
||||
print("=" * 60)
|
||||
@@ -509,15 +527,9 @@ if __name__ == "__main__":
|
||||
python evaluate_expressions_llm_v6.py # 使用全部数据
|
||||
python evaluate_expressions_llm_v6.py -n 50 # 随机选取50条数据
|
||||
python evaluate_expressions_llm_v6.py --count 100 # 随机选取100条数据
|
||||
"""
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--count",
|
||||
type=int,
|
||||
default=None,
|
||||
help="随机选取的数据条数(默认:使用全部数据)"
|
||||
)
|
||||
|
||||
parser.add_argument("-n", "--count", type=int, default=None, help="随机选取的数据条数(默认:使用全部数据)")
|
||||
|
||||
args = parser.parse_args()
|
||||
asyncio.run(main(count=args.count))
|
||||
|
||||
|
||||
@@ -32,13 +32,13 @@ MANUAL_EVAL_FILE = os.path.join(TEMP_DIR, "manual_evaluation_results.json")
|
||||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
"""
|
||||
加载已有的评估结果
|
||||
|
||||
|
||||
Returns:
|
||||
(已有结果列表, 已评估的项目(situation, style)元组集合)
|
||||
"""
|
||||
if not os.path.exists(MANUAL_EVAL_FILE):
|
||||
return [], set()
|
||||
|
||||
|
||||
try:
|
||||
with open(MANUAL_EVAL_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
@@ -55,22 +55,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
def save_results(manual_results: List[Dict]):
|
||||
"""
|
||||
保存评估结果到文件
|
||||
|
||||
|
||||
Args:
|
||||
manual_results: 评估结果列表
|
||||
"""
|
||||
try:
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
|
||||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(manual_results),
|
||||
"manual_results": manual_results
|
||||
"manual_results": manual_results,
|
||||
}
|
||||
|
||||
|
||||
with open(MANUAL_EVAL_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info(f"评估结果已保存到: {MANUAL_EVAL_FILE}")
|
||||
print(f"\n✓ 评估结果已保存(共 {len(manual_results)} 条)")
|
||||
except Exception as e:
|
||||
@@ -81,45 +81,43 @@ def save_results(manual_results: List[Dict]):
|
||||
def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_size: int = 10) -> List[Expression]:
    """
    Fetch expressions that have not been evaluated yet.

    Args:
        evaluated_pairs: Set of ``(situation, style)`` tuples already evaluated.
        batch_size: Maximum number of expressions to return.

    Returns:
        A list of unevaluated expressions: all of them when at most
        ``batch_size`` remain, otherwise a random sample of ``batch_size``.
        An empty list is returned when there are no records, nothing is left
        to evaluate, or the lookup raises.
    """
    try:
        # Load every expression record.
        all_expressions = list(Expression.select())

        if not all_expressions:
            logger.warning("数据库中没有表达方式记录")
            return []

        # An expression counts as evaluated only when both its situation and
        # its style match an already-evaluated pair.
        unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]

        if not unevaluated:
            logger.info("所有项目都已评估完成")
            return []

        # Fewer candidates than requested: return them all.
        if len(unevaluated) <= batch_size:
            logger.info(f"剩余 {len(unevaluated)} 条未评估项目,全部返回")
            return unevaluated

        # Otherwise pick a random batch of the requested size.
        selected = random.sample(unevaluated, batch_size)
        logger.info(f"从 {len(unevaluated)} 条未评估项目中随机选择了 {len(selected)} 条")
        return selected

    except Exception as e:
        # Best-effort: log the failure with its traceback and report "nothing to do".
        logger.error(f"获取未评估表达方式失败: {e}")
        import traceback

        logger.error(traceback.format_exc())
        return []
|
||||
|
||||
@@ -127,12 +125,12 @@ def get_unevaluated_expressions(evaluated_pairs: Set[Tuple[str, str]], batch_siz
|
||||
def manual_evaluate_expression(expression: Expression, index: int, total: int) -> Dict:
|
||||
"""
|
||||
人工评估单个表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
index: 当前索引(从1开始)
|
||||
total: 总数
|
||||
|
||||
|
||||
Returns:
|
||||
评估结果字典,如果用户退出则返回 None
|
||||
"""
|
||||
@@ -146,38 +144,38 @@ def manual_evaluate_expression(expression: Expression, index: int, total: int) -
|
||||
print(" 输入 'n' 或 'no' 或 '0' 表示不合适(不通过)")
|
||||
print(" 输入 'q' 或 'quit' 退出评估")
|
||||
print(" 输入 's' 或 'skip' 跳过当前项目")
|
||||
|
||||
|
||||
while True:
|
||||
user_input = input("\n您的评估 (y/n/q/s): ").strip().lower()
|
||||
|
||||
if user_input in ['q', 'quit']:
|
||||
|
||||
if user_input in ["q", "quit"]:
|
||||
print("退出评估")
|
||||
return None
|
||||
|
||||
if user_input in ['s', 'skip']:
|
||||
|
||||
if user_input in ["s", "skip"]:
|
||||
print("跳过当前项目")
|
||||
return "skip"
|
||||
|
||||
if user_input in ['y', 'yes', '1', '是', '通过']:
|
||||
|
||||
if user_input in ["y", "yes", "1", "是", "通过"]:
|
||||
suitable = True
|
||||
break
|
||||
elif user_input in ['n', 'no', '0', '否', '不通过']:
|
||||
elif user_input in ["n", "no", "0", "否", "不通过"]:
|
||||
suitable = False
|
||||
break
|
||||
else:
|
||||
print("输入无效,请重新输入 (y/n/q/s)")
|
||||
|
||||
|
||||
result = {
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"suitable": suitable,
|
||||
"reason": None,
|
||||
"evaluator": "manual",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
print(f"\n✓ 已记录:{'通过' if suitable else '不通过'}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -186,7 +184,7 @@ def main():
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式人工评估")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
@@ -194,41 +192,41 @@ def main():
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 加载已有评估结果
|
||||
existing_results, evaluated_pairs = load_existing_results()
|
||||
manual_results = existing_results.copy()
|
||||
|
||||
|
||||
if evaluated_pairs:
|
||||
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
|
||||
print(f"已评估项目数: {len(evaluated_pairs)}")
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("开始人工评估")
|
||||
print("=" * 60)
|
||||
print("提示:可以随时输入 'q' 退出,输入 's' 跳过当前项目")
|
||||
print("评估结果会自动保存到文件\n")
|
||||
|
||||
|
||||
batch_size = 10
|
||||
batch_count = 0
|
||||
|
||||
|
||||
while True:
|
||||
# 获取未评估的项目
|
||||
expressions = get_unevaluated_expressions(evaluated_pairs, batch_size)
|
||||
|
||||
|
||||
if not expressions:
|
||||
print("\n" + "=" * 60)
|
||||
print("所有项目都已评估完成!")
|
||||
print("=" * 60)
|
||||
break
|
||||
|
||||
|
||||
batch_count += 1
|
||||
print(f"\n--- 批次 {batch_count}:评估 {len(expressions)} 条项目 ---")
|
||||
|
||||
|
||||
batch_results = []
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
manual_result = manual_evaluate_expression(expression, i, len(expressions))
|
||||
|
||||
|
||||
if manual_result is None:
|
||||
# 用户退出
|
||||
print("\n评估已中断")
|
||||
@@ -237,34 +235,34 @@ def main():
|
||||
manual_results.extend(batch_results)
|
||||
save_results(manual_results)
|
||||
return
|
||||
|
||||
|
||||
if manual_result == "skip":
|
||||
# 跳过当前项目
|
||||
continue
|
||||
|
||||
|
||||
batch_results.append(manual_result)
|
||||
# 使用 (situation, style) 作为唯一标识
|
||||
evaluated_pairs.add((manual_result["situation"], manual_result["style"]))
|
||||
|
||||
|
||||
# 将当前批次结果添加到总结果中
|
||||
manual_results.extend(batch_results)
|
||||
|
||||
|
||||
# 保存结果
|
||||
save_results(manual_results)
|
||||
|
||||
|
||||
print(f"\n当前批次完成,已评估总数: {len(manual_results)} 条")
|
||||
|
||||
|
||||
# 询问是否继续
|
||||
while True:
|
||||
continue_input = input("\n是否继续评估下一批?(y/n): ").strip().lower()
|
||||
if continue_input in ['y', 'yes', '1', '是', '继续']:
|
||||
if continue_input in ["y", "yes", "1", "是", "继续"]:
|
||||
break
|
||||
elif continue_input in ['n', 'no', '0', '否', '退出']:
|
||||
elif continue_input in ["n", "no", "0", "否", "退出"]:
|
||||
print("\n评估结束")
|
||||
return
|
||||
else:
|
||||
print("输入无效,请重新输入 (y/n)")
|
||||
|
||||
|
||||
# 关闭数据库连接
|
||||
try:
|
||||
db.close()
|
||||
@@ -275,4 +273,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -134,9 +134,7 @@ def handle_import_openie(
|
||||
# 在非交互模式下,不再询问用户,而是直接报错终止
|
||||
logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。"
|
||||
)
|
||||
logger.error("检测到非法文段且当前处于非交互模式,无法询问是否删除非法文段,导入终止。")
|
||||
sys.exit(1)
|
||||
logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
|
||||
user_choice = input().strip().lower()
|
||||
@@ -189,9 +187,7 @@ def handle_import_openie(
|
||||
async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: dict-comprehension
|
||||
# 新增确认提示
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。"
|
||||
)
|
||||
logger.warning("当前处于非交互模式,将跳过导入开销确认提示,直接开始执行 OpenIE 导入。")
|
||||
else:
|
||||
print("=== 重要操作确认 ===")
|
||||
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
|
||||
@@ -261,10 +257,7 @@ async def main_async(non_interactive: bool = False) -> bool: # sourcery skip: d
|
||||
def main(argv: Optional[list[str]] = None) -> None:
|
||||
"""主函数 - 解析参数并运行异步主流程。"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,"
|
||||
"将其导入到 LPMM 的向量库与知识图中。"
|
||||
)
|
||||
description=("OpenIE 导入脚本:读取 data/openie 中的 OpenIE JSON 批次,将其导入到 LPMM 的向量库与知识图中。")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non-interactive",
|
||||
|
||||
@@ -123,9 +123,7 @@ def _run(non_interactive: bool = False) -> None: # sourcery skip: comprehension
|
||||
ensure_dirs() # 确保目录存在
|
||||
# 新增用户确认提示
|
||||
if non_interactive:
|
||||
logger.warning(
|
||||
"当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。"
|
||||
)
|
||||
logger.warning("当前处于非交互模式,将跳过费用与时长确认提示,直接开始进行实体提取操作。")
|
||||
else:
|
||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||
|
||||
@@ -68,4 +68,3 @@ def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ except ImportError as e:
|
||||
|
||||
logger = get_logger("lpmm_interactive_manager")
|
||||
|
||||
|
||||
async def interactive_add():
|
||||
"""交互式导入知识"""
|
||||
print("\n" + "=" * 40)
|
||||
@@ -38,7 +39,7 @@ async def interactive_add():
|
||||
print(" - 支持多段落,段落间请保留空行。")
|
||||
print(" - 输入完成后,在新起的一行输入 'EOF' 并回车结束输入。")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
lines = []
|
||||
while True:
|
||||
try:
|
||||
@@ -48,7 +49,7 @@ async def interactive_add():
|
||||
lines.append(line)
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
text = "\n".join(lines).strip()
|
||||
if not text:
|
||||
print("\n[!] 内容为空,操作已取消。")
|
||||
@@ -58,7 +59,7 @@ async def interactive_add():
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.add_content(text)
|
||||
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
print(f" 实际新增段落数: {result.get('count', 0)}")
|
||||
@@ -68,6 +69,7 @@ async def interactive_add():
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"add_content 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_delete():
|
||||
"""交互式删除知识"""
|
||||
print("\n" + "=" * 40)
|
||||
@@ -77,10 +79,10 @@ async def interactive_delete():
|
||||
print(" 1. 关键词模糊匹配(删除包含关键词的所有段落)")
|
||||
print(" 2. 完整文段匹配(删除完全匹配的段落)")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
mode = input("请选择删除模式 (1/2): ").strip()
|
||||
exact_match = False
|
||||
|
||||
|
||||
if mode == "2":
|
||||
exact_match = True
|
||||
print("\n[完整文段匹配模式]")
|
||||
@@ -102,14 +104,18 @@ async def interactive_delete():
|
||||
print("\n[!] 无效选择,默认使用关键词模糊匹配模式。")
|
||||
print("\n[关键词模糊匹配模式]")
|
||||
keyword = input("请输入匹配关键词: ").strip()
|
||||
|
||||
|
||||
if not keyword:
|
||||
print("\n[!] 输入为空,操作已取消。")
|
||||
return
|
||||
|
||||
|
||||
print("-" * 40)
|
||||
confirm = input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ").strip().lower()
|
||||
if confirm != 'y':
|
||||
confirm = (
|
||||
input(f"危险确认:确定要删除所有匹配 '{keyword[:50]}{'...' if len(keyword) > 50 else ''}' 的知识吗?(y/N): ")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
if confirm != "y":
|
||||
print("\n[!] 已取消删除操作。")
|
||||
return
|
||||
|
||||
@@ -117,7 +123,7 @@ async def interactive_delete():
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.delete(keyword, exact_match=exact_match)
|
||||
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
print(f" 删除条数: {result.get('deleted_count', 0)}")
|
||||
@@ -129,6 +135,7 @@ async def interactive_delete():
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"delete 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_clear():
|
||||
"""交互式清空知识库"""
|
||||
print("\n" + "=" * 40)
|
||||
@@ -141,40 +148,45 @@ async def interactive_clear():
|
||||
print(" - 整个知识图谱")
|
||||
print(" - 此操作不可恢复!")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
# 双重确认
|
||||
confirm1 = input("⚠️ 第一次确认:确定要清空整个知识库吗?(输入 'YES' 继续): ").strip()
|
||||
if confirm1 != "YES":
|
||||
print("\n[!] 已取消清空操作。")
|
||||
return
|
||||
|
||||
|
||||
print("\n" + "=" * 40)
|
||||
confirm2 = input("⚠️ 第二次确认:此操作不可恢复,请再次输入 'CLEAR' 确认: ").strip()
|
||||
if confirm2 != "CLEAR":
|
||||
print("\n[!] 已取消清空操作。")
|
||||
return
|
||||
|
||||
|
||||
print("\n[进度] 正在清空知识库...")
|
||||
try:
|
||||
# 使用 lpmm_ops.py 中的接口
|
||||
result = await lpmm_ops.clear_all()
|
||||
|
||||
|
||||
if result["status"] == "success":
|
||||
print(f"\n[√] 成功:{result['message']}")
|
||||
stats = result.get("stats", {})
|
||||
before = stats.get("before", {})
|
||||
after = stats.get("after", {})
|
||||
print("\n[统计信息]")
|
||||
print(f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}")
|
||||
print(f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}")
|
||||
print(
|
||||
f" 清空前: 段落={before.get('paragraphs', 0)}, 实体={before.get('entities', 0)}, "
|
||||
f"关系={before.get('relations', 0)}, KG节点={before.get('kg_nodes', 0)}, KG边={before.get('kg_edges', 0)}"
|
||||
)
|
||||
print(
|
||||
f" 清空后: 段落={after.get('paragraphs', 0)}, 实体={after.get('entities', 0)}, "
|
||||
f"关系={after.get('relations', 0)}, KG节点={after.get('kg_nodes', 0)}, KG边={after.get('kg_edges', 0)}"
|
||||
)
|
||||
else:
|
||||
print(f"\n[×] 失败:{result['message']}")
|
||||
except Exception as e:
|
||||
print(f"\n[×] 发生异常: {e}")
|
||||
logger.error(f"clear_all 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def interactive_search():
|
||||
"""交互式查询知识"""
|
||||
print("\n" + "=" * 40)
|
||||
@@ -182,25 +194,25 @@ async def interactive_search():
|
||||
print("=" * 40)
|
||||
print("说明:输入查询问题或关键词,系统会返回相关的知识段落。")
|
||||
print("-" * 40)
|
||||
|
||||
|
||||
# 确保 LPMM 已初始化
|
||||
if not global_config.lpmm_knowledge.enable:
|
||||
print("\n[!] 警告:LPMM 知识库在配置中未启用。")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
lpmm_start_up()
|
||||
except Exception as e:
|
||||
print(f"\n[!] LPMM 初始化失败: {e}")
|
||||
logger.error(f"LPMM 初始化失败: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
|
||||
query = input("请输入查询问题或关键词: ").strip()
|
||||
|
||||
|
||||
if not query:
|
||||
print("\n[!] 查询内容为空,操作已取消。")
|
||||
return
|
||||
|
||||
|
||||
# 询问返回条数
|
||||
print("-" * 40)
|
||||
limit_str = input("希望返回的相关知识条数(默认3,直接回车使用默认值): ").strip()
|
||||
@@ -210,11 +222,11 @@ async def interactive_search():
|
||||
except ValueError:
|
||||
limit = 3
|
||||
print("[!] 输入无效,使用默认值 3。")
|
||||
|
||||
|
||||
print("\n[进度] 正在查询知识库...")
|
||||
try:
|
||||
result = await query_lpmm_knowledge(query, limit=limit)
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("[查询结果]")
|
||||
print("=" * 60)
|
||||
@@ -224,6 +236,7 @@ async def interactive_search():
|
||||
print(f"\n[×] 查询失败: {e}")
|
||||
logger.error(f"查询异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
async def main():
|
||||
"""主循环"""
|
||||
while True:
|
||||
@@ -236,9 +249,9 @@ async def main():
|
||||
print("║ 4. 清空知识库 (Clear All) ⚠️ ║")
|
||||
print("║ 0. 退出 (Exit) ║")
|
||||
print("╚" + "═" * 38 + "╝")
|
||||
|
||||
|
||||
choice = input("请选择操作编号: ").strip()
|
||||
|
||||
|
||||
if choice == "1":
|
||||
await interactive_add()
|
||||
elif choice == "2":
|
||||
@@ -253,6 +266,7 @@ async def main():
|
||||
else:
|
||||
print("\n[!] 无效的选择,请输入 0, 1, 2, 3 或 4。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 运行主循环
|
||||
@@ -262,4 +276,3 @@ if __name__ == "__main__":
|
||||
except Exception as e:
|
||||
print(f"\n[!] 程序运行出错: {e}")
|
||||
logger.error(f"Main loop 异常: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -69,15 +69,10 @@ def _check_before_info_extract(non_interactive: bool = False) -> bool:
|
||||
raw_dir = Path(PROJECT_ROOT) / "data" / "lpmm_raw_data"
|
||||
txt_files = list(raw_dir.glob("*.txt"))
|
||||
if not txt_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,"
|
||||
"info_extraction 可能立即退出或无数据可处理。"
|
||||
)
|
||||
msg = f"[WARN] 未在 {raw_dir} 下找到任何 .txt 原始语料文件,info_extraction 可能立即退出或无数据可处理。"
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。"
|
||||
)
|
||||
logger.error("非交互模式下要求原始语料目录中已存在可用的 .txt 文件,请先准备好数据再重试。")
|
||||
return False
|
||||
cont = input("仍然继续执行信息提取吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
@@ -89,15 +84,10 @@ def _check_before_import_openie(non_interactive: bool = False) -> bool:
|
||||
openie_dir = Path(PROJECT_ROOT) / "data" / "openie"
|
||||
json_files = list(openie_dir.glob("*.json"))
|
||||
if not json_files:
|
||||
msg = (
|
||||
f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,"
|
||||
"import_openie 可能会因为找不到批次而失败。"
|
||||
)
|
||||
msg = f"[WARN] 未在 {openie_dir} 下找到任何 OpenIE JSON 文件,import_openie 可能会因为找不到批次而失败。"
|
||||
print(msg)
|
||||
if non_interactive:
|
||||
logger.error(
|
||||
"非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。"
|
||||
)
|
||||
logger.error("非交互模式下要求 data/openie 目录中已存在可用的 OpenIE JSON 文件,请先执行信息提取脚本。")
|
||||
return False
|
||||
cont = input("仍然继续执行导入吗?(y/n): ").strip().lower()
|
||||
return cont == "y"
|
||||
@@ -108,10 +98,7 @@ def _warn_if_lpmm_disabled() -> None:
|
||||
"""在部分操作前提醒 lpmm_knowledge.enable 状态。"""
|
||||
try:
|
||||
if not getattr(global_config.lpmm_knowledge, "enable", False):
|
||||
print(
|
||||
"[WARN] 当前配置 lpmm_knowledge.enable = false,"
|
||||
"刷新或检索测试可能无法在聊天侧真正启用 LPMM。"
|
||||
)
|
||||
print("[WARN] 当前配置 lpmm_knowledge.enable = false,刷新或检索测试可能无法在聊天侧真正启用 LPMM。")
|
||||
except Exception:
|
||||
# 配置异常时不阻断主流程,仅忽略提示
|
||||
pass
|
||||
@@ -131,10 +118,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||
if action == "prepare_raw":
|
||||
logger.info("开始预处理原始语料 (data/lpmm_raw_data/*.txt)...")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
print(f"\n[PREPARE_RAW] 完成原始语料预处理:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||
elif action == "info_extract":
|
||||
if not _check_before_info_extract("--non-interactive" in extra_args):
|
||||
print("已根据用户选择,取消执行信息提取。")
|
||||
@@ -164,10 +148,7 @@ def run_action(action: str, extra_args: Optional[List[str]] = None) -> None:
|
||||
# 一键流水线:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新
|
||||
logger.info("开始 full_import:预处理原始语料 -> 信息抽取 -> 导入 -> 刷新")
|
||||
sha_list, raw_data = load_raw_data()
|
||||
print(
|
||||
f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,"
|
||||
f"去重后哈希数 {len(sha_list)}。"
|
||||
)
|
||||
print(f"\n[FULL_IMPORT] 原始语料预处理完成:共 {len(raw_data)} 条段落,去重后哈希数 {len(sha_list)}。")
|
||||
non_interactive = "--non-interactive" in extra_args
|
||||
if not _check_before_info_extract(non_interactive):
|
||||
print("已根据用户选择,取消 full_import(信息提取阶段被取消)。")
|
||||
@@ -345,9 +326,9 @@ def _interactive_build_delete_args() -> List[str]:
|
||||
)
|
||||
|
||||
# 快速选项:按推荐方式清理所有相关实体/关系
|
||||
quick_all = input(
|
||||
"是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): "
|
||||
).strip().lower()
|
||||
quick_all = (
|
||||
input("是否使用推荐策略:同时删除关联的实体向量/节点、关系向量,并清理孤立实体?(Y/n): ").strip().lower()
|
||||
)
|
||||
if quick_all in ("", "y", "yes"):
|
||||
args.extend(["--delete-entities", "--delete-relations", "--remove-orphan-entities"])
|
||||
else:
|
||||
@@ -375,9 +356,7 @@ def _interactive_build_delete_args() -> List[str]:
|
||||
|
||||
def _interactive_build_batch_inspect_args() -> List[str]:
|
||||
"""为 inspect_lpmm_batch 构造 --openie-file 参数。"""
|
||||
path = _interactive_choose_openie_file(
|
||||
"请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):"
|
||||
)
|
||||
path = _interactive_choose_openie_file("请输入要检查的 OpenIE JSON 文件路径(回车跳过,由子脚本自行交互):")
|
||||
if not path:
|
||||
return []
|
||||
return ["--openie-file", path]
|
||||
@@ -385,11 +364,7 @@ def _interactive_build_batch_inspect_args() -> List[str]:
|
||||
|
||||
def _interactive_build_test_args() -> List[str]:
|
||||
"""为 test_lpmm_retrieval 构造自定义测试用例参数。"""
|
||||
print(
|
||||
"\n[TEST] 你可以:\n"
|
||||
"- 直接回车使用内置的默认测试用例;\n"
|
||||
"- 或者输入一条自定义问题,并指定期望命中的关键字。"
|
||||
)
|
||||
print("\n[TEST] 你可以:\n- 直接回车使用内置的默认测试用例;\n- 或者输入一条自定义问题,并指定期望命中的关键字。")
|
||||
query = input("请输入自定义测试问题(回车则使用默认用例):").strip()
|
||||
if not query:
|
||||
return []
|
||||
@@ -422,9 +397,7 @@ def _run_embedding_helper() -> None:
|
||||
print(f"- 当前配置中的嵌入维度 (lpmm_knowledge.embedding_dimension): {current_dim}")
|
||||
print(f"- 测试文件路径: {EMBEDDING_TEST_FILE}")
|
||||
|
||||
new_dim = input(
|
||||
"\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):"
|
||||
).strip()
|
||||
new_dim = input("\n如果你计划更换嵌入模型,请在此输入“新的嵌入维度”(仅用于记录与提示,回车则跳过):").strip()
|
||||
if new_dim and not new_dim.isdigit():
|
||||
print("输入的维度不是纯数字,已取消操作。")
|
||||
return
|
||||
@@ -537,5 +510,3 @@ def main(argv: Optional[list[str]] = None) -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
|
||||
@@ -28,53 +28,55 @@ from maim_message import UserInfo, GroupInfo
|
||||
|
||||
logger = get_logger("test_memory_retrieval")
|
||||
|
||||
|
||||
# 使用 importlib 动态导入,避免循环导入问题
|
||||
def _import_memory_retrieval():
|
||||
"""使用 importlib 动态导入 memory_retrieval 模块,避免循环导入"""
|
||||
try:
|
||||
# 先导入 prompt_builder,检查 prompt 是否已经初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
|
||||
# 检查 memory_retrieval 相关的 prompt 是否已经注册
|
||||
# 如果已经注册,说明模块可能已经通过其他路径初始化过了
|
||||
prompt_already_init = "memory_retrieval_question_prompt" in global_prompt_manager._prompts
|
||||
|
||||
|
||||
module_name = "src.memory_system.memory_retrieval"
|
||||
|
||||
|
||||
# 如果 prompt 已经初始化,尝试直接使用已加载的模块
|
||||
if prompt_already_init and module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
||||
if hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
return (
|
||||
existing_module.init_memory_retrieval_prompt,
|
||||
existing_module._react_agent_solve_question,
|
||||
existing_module._process_single_question,
|
||||
)
|
||||
|
||||
|
||||
# 如果模块已经在 sys.modules 中但部分初始化,先移除它
|
||||
if module_name in sys.modules:
|
||||
existing_module = sys.modules[module_name]
|
||||
if not hasattr(existing_module, 'init_memory_retrieval_prompt'):
|
||||
if not hasattr(existing_module, "init_memory_retrieval_prompt"):
|
||||
# 模块部分初始化,移除它
|
||||
logger.warning(f"检测到部分初始化的模块 {module_name},尝试重新导入")
|
||||
del sys.modules[module_name]
|
||||
# 清理可能相关的部分初始化模块
|
||||
keys_to_remove = []
|
||||
for key in sys.modules.keys():
|
||||
if key.startswith('src.memory_system.') and key != 'src.memory_system':
|
||||
if key.startswith("src.memory_system.") and key != "src.memory_system":
|
||||
keys_to_remove.append(key)
|
||||
for key in keys_to_remove:
|
||||
try:
|
||||
del sys.modules[key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
# 在导入 memory_retrieval 之前,先确保所有可能触发循环导入的模块都已完全加载
|
||||
# 这些模块在导入时可能会触发 memory_retrieval 的导入,所以我们需要先加载它们
|
||||
try:
|
||||
# 先导入可能触发循环导入的模块,让它们完成初始化
|
||||
import src.config.config
|
||||
import src.chat.utils.prompt_builder
|
||||
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
# 如果它们已经导入,就确保它们完全初始化
|
||||
# 尝试导入可能触发循环导入的模块(这些模块可能在模块级别导入了 memory_retrieval)
|
||||
@@ -89,11 +91,11 @@ def _import_memory_retrieval():
|
||||
pass # 如果导入失败,继续
|
||||
except Exception as e:
|
||||
logger.warning(f"预加载依赖模块时出现警告: {e}")
|
||||
|
||||
|
||||
# 现在尝试导入 memory_retrieval
|
||||
# 如果此时仍然触发循环导入,说明有其他模块在模块级别导入了 memory_retrieval
|
||||
memory_retrieval_module = importlib.import_module(module_name)
|
||||
|
||||
|
||||
return (
|
||||
memory_retrieval_module.init_memory_retrieval_prompt,
|
||||
memory_retrieval_module._react_agent_solve_question,
|
||||
@@ -126,16 +128,16 @@ def create_test_chat_stream(chat_id: str = "test_memory_retrieval") -> ChatStrea
|
||||
|
||||
def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
||||
"""获取从指定时间开始的token使用情况
|
||||
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
|
||||
|
||||
Returns:
|
||||
包含token使用统计的字典
|
||||
"""
|
||||
try:
|
||||
start_datetime = datetime.fromtimestamp(start_time)
|
||||
|
||||
|
||||
# 查询从开始时间到现在的所有memory相关的token使用记录
|
||||
records = (
|
||||
LLMUsage.select()
|
||||
@@ -150,21 +152,21 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
||||
)
|
||||
.order_by(LLMUsage.timestamp.asc())
|
||||
)
|
||||
|
||||
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
total_cost = 0.0
|
||||
request_count = 0
|
||||
model_usage = {} # 按模型统计
|
||||
|
||||
|
||||
for record in records:
|
||||
total_prompt_tokens += record.prompt_tokens or 0
|
||||
total_completion_tokens += record.completion_tokens or 0
|
||||
total_tokens += record.total_tokens or 0
|
||||
total_cost += record.cost or 0.0
|
||||
request_count += 1
|
||||
|
||||
|
||||
# 按模型统计
|
||||
model_name = record.model_name or "unknown"
|
||||
if model_name not in model_usage:
|
||||
@@ -180,7 +182,7 @@ def get_token_usage_since(start_time: float) -> Dict[str, Any]:
|
||||
model_usage[model_name]["total_tokens"] += record.total_tokens or 0
|
||||
model_usage[model_name]["cost"] += record.cost or 0.0
|
||||
model_usage[model_name]["request_count"] += 1
|
||||
|
||||
|
||||
return {
|
||||
"total_prompt_tokens": total_prompt_tokens,
|
||||
"total_completion_tokens": total_completion_tokens,
|
||||
@@ -205,25 +207,25 @@ def format_thinking_steps(thinking_steps: list) -> str:
|
||||
"""格式化思考步骤为可读字符串"""
|
||||
if not thinking_steps:
|
||||
return "无思考步骤"
|
||||
|
||||
|
||||
lines = []
|
||||
for step in thinking_steps:
|
||||
iteration = step.get("iteration", "?")
|
||||
thought = step.get("thought", "")
|
||||
actions = step.get("actions", [])
|
||||
observations = step.get("observations", [])
|
||||
|
||||
|
||||
lines.append(f"\n--- 迭代 {iteration} ---")
|
||||
if thought:
|
||||
lines.append(f"思考: {thought[:200]}...")
|
||||
|
||||
|
||||
if actions:
|
||||
lines.append("行动:")
|
||||
for action in actions:
|
||||
action_type = action.get("action_type", "unknown")
|
||||
action_params = action.get("action_params", {})
|
||||
lines.append(f" - {action_type}: {json.dumps(action_params, ensure_ascii=False)}")
|
||||
|
||||
|
||||
if observations:
|
||||
lines.append("观察:")
|
||||
for obs in observations:
|
||||
@@ -231,7 +233,7 @@ def format_thinking_steps(thinking_steps: list) -> str:
|
||||
if len(str(obs)) > 200:
|
||||
obs_str += "..."
|
||||
lines.append(f" - {obs_str}")
|
||||
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@@ -242,31 +244,32 @@ async def test_memory_retrieval(
|
||||
max_iterations: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""测试记忆检索功能
|
||||
|
||||
|
||||
Args:
|
||||
question: 要查询的问题
|
||||
chat_id: 聊天ID
|
||||
context: 上下文信息
|
||||
max_iterations: 最大迭代次数
|
||||
|
||||
|
||||
Returns:
|
||||
包含测试结果的字典
|
||||
"""
|
||||
print("\n" + "=" * 80)
|
||||
print(f"[测试] 记忆检索测试")
|
||||
print("[测试] 记忆检索测试")
|
||||
print(f"[问题] {question}")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
# 记录开始时间
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 延迟导入并初始化记忆检索prompt(这会自动加载 global_config)
|
||||
# 注意:必须在函数内部调用,避免在模块级别触发循环导入
|
||||
try:
|
||||
init_memory_retrieval_prompt, _react_agent_solve_question, _ = _import_memory_retrieval()
|
||||
|
||||
|
||||
# 检查 prompt 是否已经初始化,避免重复初始化
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
|
||||
if "memory_retrieval_question_prompt" not in global_prompt_manager._prompts:
|
||||
init_memory_retrieval_prompt()
|
||||
else:
|
||||
@@ -274,24 +277,24 @@ async def test_memory_retrieval(
|
||||
except Exception as e:
|
||||
logger.error(f"初始化记忆检索模块失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# 获取 global_config(此时应该已经加载)
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
# 直接调用 _react_agent_solve_question 来获取详细的迭代信息
|
||||
if max_iterations is None:
|
||||
max_iterations = global_config.memory.max_agent_iterations
|
||||
|
||||
|
||||
timeout = global_config.memory.agent_timeout_seconds
|
||||
|
||||
print(f"\n[配置]")
|
||||
|
||||
print("\n[配置]")
|
||||
print(f" 最大迭代次数: {max_iterations}")
|
||||
print(f" 超时时间: {timeout}秒")
|
||||
print(f" 聊天ID: {chat_id}")
|
||||
|
||||
|
||||
# 执行检索
|
||||
print(f"\n[开始检索] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
|
||||
|
||||
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
|
||||
question=question,
|
||||
chat_id=chat_id,
|
||||
@@ -299,14 +302,14 @@ async def test_memory_retrieval(
|
||||
timeout=timeout,
|
||||
initial_info="",
|
||||
)
|
||||
|
||||
|
||||
# 记录结束时间
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
|
||||
# 获取token使用情况
|
||||
token_usage = get_token_usage_since(start_time)
|
||||
|
||||
|
||||
# 构建结果
|
||||
result = {
|
||||
"question": question,
|
||||
@@ -318,41 +321,41 @@ async def test_memory_retrieval(
|
||||
"iteration_count": len(thinking_steps),
|
||||
"token_usage": token_usage,
|
||||
}
|
||||
|
||||
|
||||
# 输出结果
|
||||
print(f"\n[检索完成] {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}")
|
||||
print(f"\n[结果]")
|
||||
print("\n[结果]")
|
||||
print(f" 是否找到答案: {'是' if found_answer else '否'}")
|
||||
if found_answer and answer:
|
||||
print(f" 答案: {answer}")
|
||||
else:
|
||||
print(f" 答案: (未找到答案)")
|
||||
print(" 答案: (未找到答案)")
|
||||
print(f" 是否超时: {'是' if is_timeout else '否'}")
|
||||
print(f" 迭代次数: {len(thinking_steps)}")
|
||||
print(f" 总耗时: {elapsed_time:.2f}秒")
|
||||
|
||||
print(f"\n[Token使用情况]")
|
||||
|
||||
print("\n[Token使用情况]")
|
||||
print(f" 总请求数: {token_usage['request_count']}")
|
||||
print(f" 总Prompt Tokens: {token_usage['total_prompt_tokens']:,}")
|
||||
print(f" 总Completion Tokens: {token_usage['total_completion_tokens']:,}")
|
||||
print(f" 总Tokens: {token_usage['total_tokens']:,}")
|
||||
print(f" 总成本: ${token_usage['total_cost']:.6f}")
|
||||
|
||||
if token_usage['model_usage']:
|
||||
print(f"\n[按模型统计]")
|
||||
for model_name, usage in token_usage['model_usage'].items():
|
||||
|
||||
if token_usage["model_usage"]:
|
||||
print("\n[按模型统计]")
|
||||
for model_name, usage in token_usage["model_usage"].items():
|
||||
print(f" {model_name}:")
|
||||
print(f" 请求数: {usage['request_count']}")
|
||||
print(f" Prompt Tokens: {usage['prompt_tokens']:,}")
|
||||
print(f" Completion Tokens: {usage['completion_tokens']:,}")
|
||||
print(f" 总Tokens: {usage['total_tokens']:,}")
|
||||
print(f" 成本: ${usage['cost']:.6f}")
|
||||
|
||||
print(f"\n[迭代详情]")
|
||||
|
||||
print("\n[迭代详情]")
|
||||
print(format_thinking_steps(thinking_steps))
|
||||
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -375,12 +378,12 @@ def main() -> None:
|
||||
"-o",
|
||||
help="将结果保存到JSON文件(可选)",
|
||||
)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# 初始化日志(使用较低的详细程度,避免输出过多日志)
|
||||
initialize_logging(verbose=False)
|
||||
|
||||
|
||||
# 交互式输入问题
|
||||
print("\n" + "=" * 80)
|
||||
print("记忆检索测试工具")
|
||||
@@ -389,7 +392,7 @@ def main() -> None:
|
||||
if not question:
|
||||
print("错误: 问题不能为空")
|
||||
return
|
||||
|
||||
|
||||
# 交互式输入最大迭代次数
|
||||
max_iterations_input = input("\n请输入最大迭代次数(直接回车使用配置默认值): ").strip()
|
||||
max_iterations = None
|
||||
@@ -402,7 +405,7 @@ def main() -> None:
|
||||
except ValueError:
|
||||
print("警告: 无效的迭代次数,将使用配置默认值")
|
||||
max_iterations = None
|
||||
|
||||
|
||||
# 连接数据库
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
@@ -410,7 +413,7 @@ def main() -> None:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
print(f"错误: 数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 运行测试
|
||||
try:
|
||||
result = asyncio.run(
|
||||
@@ -421,7 +424,7 @@ def main() -> None:
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 如果指定了输出文件,保存结果
|
||||
if args.output:
|
||||
# 将thinking_steps转换为可序列化的格式
|
||||
@@ -429,7 +432,7 @@ def main() -> None:
|
||||
with open(args.output, "w", encoding="utf-8") as f:
|
||||
json.dump(output_result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n[结果已保存] {args.output}")
|
||||
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n[中断] 用户中断测试")
|
||||
except Exception as e:
|
||||
@@ -444,4 +447,3 @@ def main() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user