Ruff Format
This commit is contained in:
@@ -38,13 +38,13 @@ COUNT_ANALYSIS_FILE = os.path.join(TEMP_DIR, "count_analysis_evaluation_results.
|
||||
def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
"""
|
||||
加载已有的评估结果
|
||||
|
||||
|
||||
Returns:
|
||||
(已有结果列表, 已评估的项目(situation, style)元组集合)
|
||||
"""
|
||||
if not os.path.exists(COUNT_ANALYSIS_FILE):
|
||||
return [], set()
|
||||
|
||||
|
||||
try:
|
||||
with open(COUNT_ANALYSIS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
@@ -61,22 +61,22 @@ def load_existing_results() -> tuple[List[Dict], Set[Tuple[str, str]]]:
|
||||
def save_results(evaluation_results: List[Dict]):
|
||||
"""
|
||||
保存评估结果到文件
|
||||
|
||||
|
||||
Args:
|
||||
evaluation_results: 评估结果列表
|
||||
"""
|
||||
try:
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
|
||||
data = {
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"total_count": len(evaluation_results),
|
||||
"evaluation_results": evaluation_results
|
||||
"evaluation_results": evaluation_results,
|
||||
}
|
||||
|
||||
|
||||
with open(COUNT_ANALYSIS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
logger.info(f"评估结果已保存到: {COUNT_ANALYSIS_FILE}")
|
||||
print(f"\n✓ 评估结果已保存(共 {len(evaluation_results)} 条)")
|
||||
except Exception as e:
|
||||
@@ -84,70 +84,70 @@ def save_results(evaluation_results: List[Dict]):
|
||||
print(f"\n✗ 保存评估结果失败: {e}")
|
||||
|
||||
|
||||
def select_expressions_for_evaluation(
|
||||
evaluated_pairs: Set[Tuple[str, str]] = None
|
||||
) -> List[Expression]:
|
||||
def select_expressions_for_evaluation(evaluated_pairs: Set[Tuple[str, str]] = None) -> List[Expression]:
|
||||
"""
|
||||
选择用于评估的表达方式
|
||||
选择所有count>1的项目,然后选择两倍数量的count=1的项目
|
||||
|
||||
|
||||
Args:
|
||||
evaluated_pairs: 已评估的项目集合,用于避免重复
|
||||
|
||||
|
||||
Returns:
|
||||
选中的表达方式列表
|
||||
"""
|
||||
if evaluated_pairs is None:
|
||||
evaluated_pairs = set()
|
||||
|
||||
|
||||
try:
|
||||
# 查询所有表达方式
|
||||
all_expressions = list(Expression.select())
|
||||
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("数据库中没有表达方式记录")
|
||||
return []
|
||||
|
||||
|
||||
# 过滤出未评估的项目
|
||||
unevaluated = [
|
||||
expr for expr in all_expressions
|
||||
if (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
|
||||
unevaluated = [expr for expr in all_expressions if (expr.situation, expr.style) not in evaluated_pairs]
|
||||
|
||||
if not unevaluated:
|
||||
logger.warning("所有项目都已评估完成")
|
||||
return []
|
||||
|
||||
|
||||
# 按count分组
|
||||
count_eq1 = [expr for expr in unevaluated if expr.count == 1]
|
||||
count_gt1 = [expr for expr in unevaluated if expr.count > 1]
|
||||
|
||||
|
||||
logger.info(f"未评估项目中:count=1的有{len(count_eq1)}条,count>1的有{len(count_gt1)}条")
|
||||
|
||||
|
||||
# 选择所有count>1的项目
|
||||
selected_count_gt1 = count_gt1.copy()
|
||||
|
||||
|
||||
# 选择count=1的项目,数量为count>1数量的2倍
|
||||
count_gt1_count = len(selected_count_gt1)
|
||||
count_eq1_needed = count_gt1_count * 2
|
||||
|
||||
|
||||
if len(count_eq1) < count_eq1_needed:
|
||||
logger.warning(f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条")
|
||||
logger.warning(
|
||||
f"count=1的项目只有{len(count_eq1)}条,少于需要的{count_eq1_needed}条,将选择全部{len(count_eq1)}条"
|
||||
)
|
||||
count_eq1_needed = len(count_eq1)
|
||||
|
||||
|
||||
# 随机选择count=1的项目
|
||||
selected_count_eq1 = random.sample(count_eq1, count_eq1_needed) if count_eq1 and count_eq1_needed > 0 else []
|
||||
|
||||
|
||||
selected = selected_count_gt1 + selected_count_eq1
|
||||
random.shuffle(selected) # 打乱顺序
|
||||
|
||||
logger.info(f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"已选择{len(selected)}条表达方式:count>1的有{len(selected_count_gt1)}条(全部),count=1的有{len(selected_count_eq1)}条(2倍)"
|
||||
)
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"选择表达方式失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
@@ -155,11 +155,11 @@ def select_expressions_for_evaluation(
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
"""
|
||||
创建评估提示词
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
|
||||
|
||||
Returns:
|
||||
评估提示词
|
||||
"""
|
||||
@@ -181,34 +181,32 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||
}}
|
||||
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
|
||||
请严格按照JSON格式输出,不要包含其他内容。"""
|
||||
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) -> tuple[bool, str, str | None]:
|
||||
"""
|
||||
执行单次LLM评估
|
||||
|
||||
|
||||
Args:
|
||||
situation: 情境
|
||||
style: 风格
|
||||
llm: LLM请求实例
|
||||
|
||||
|
||||
Returns:
|
||||
(suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息
|
||||
"""
|
||||
try:
|
||||
prompt = create_evaluation_prompt(situation, style)
|
||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||
|
||||
|
||||
response, (reasoning, model_name, _) = await llm.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.6,
|
||||
max_tokens=1024
|
||||
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"LLM响应: {response}")
|
||||
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
evaluation = json.loads(response)
|
||||
@@ -218,13 +216,13 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||
evaluation = json.loads(json_match.group())
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON格式的评估结果") from e
|
||||
|
||||
|
||||
suitable = evaluation.get("suitable", False)
|
||||
reason = evaluation.get("reason", "未提供理由")
|
||||
|
||||
|
||||
logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
|
||||
return suitable, reason, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
|
||||
return False, f"评估过程出错: {str(e)}", str(e)
|
||||
@@ -233,23 +231,25 @@ async def _single_llm_evaluation(situation: str, style: str, llm: LLMRequest) ->
|
||||
async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Dict:
|
||||
"""
|
||||
使用LLM评估单个表达方式
|
||||
|
||||
|
||||
Args:
|
||||
expression: 表达方式对象
|
||||
llm: LLM请求实例
|
||||
|
||||
|
||||
Returns:
|
||||
评估结果字典
|
||||
"""
|
||||
logger.info(f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}")
|
||||
|
||||
logger.info(
|
||||
f"开始评估表达方式: situation={expression.situation}, style={expression.style}, count={expression.count}"
|
||||
)
|
||||
|
||||
suitable, reason, error = await _single_llm_evaluation(expression.situation, expression.style, llm)
|
||||
|
||||
|
||||
if error:
|
||||
suitable = False
|
||||
|
||||
|
||||
logger.info(f"评估完成: {'通过' if suitable else '不通过'}")
|
||||
|
||||
|
||||
return {
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
@@ -258,28 +258,28 @@ async def llm_evaluate_expression(expression: Expression, llm: LLMRequest) -> Di
|
||||
"reason": reason,
|
||||
"error": error,
|
||||
"evaluator": "llm",
|
||||
"evaluated_at": datetime.now().isoformat()
|
||||
"evaluated_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
"""
|
||||
对评估结果进行统计分析
|
||||
|
||||
|
||||
Args:
|
||||
evaluation_results: 评估结果列表
|
||||
"""
|
||||
if not evaluation_results:
|
||||
print("\n没有评估结果可供分析")
|
||||
return
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("统计分析结果")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 按count分组统计
|
||||
count_groups = defaultdict(lambda: {"total": 0, "suitable": 0, "unsuitable": 0})
|
||||
|
||||
|
||||
for result in evaluation_results:
|
||||
count = result.get("count", 1)
|
||||
suitable = result.get("suitable", False)
|
||||
@@ -288,7 +288,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
count_groups[count]["suitable"] += 1
|
||||
else:
|
||||
count_groups[count]["unsuitable"] += 1
|
||||
|
||||
|
||||
# 显示每个count的统计
|
||||
print("\n【按count分组统计】")
|
||||
print("-" * 60)
|
||||
@@ -298,21 +298,21 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
suitable = group["suitable"]
|
||||
unsuitable = group["unsuitable"]
|
||||
pass_rate = (suitable / total * 100) if total > 0 else 0
|
||||
|
||||
|
||||
print(f"Count = {count}:")
|
||||
print(f" 总数: {total}")
|
||||
print(f" 通过: {suitable} ({pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100-pass_rate:.2f}%)")
|
||||
print(f" 不通过: {unsuitable} ({100 - pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
|
||||
# 比较count=1和count>1
|
||||
count_eq1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
|
||||
count_gt1_group = {"total": 0, "suitable": 0, "unsuitable": 0}
|
||||
|
||||
|
||||
for result in evaluation_results:
|
||||
count = result.get("count", 1)
|
||||
suitable = result.get("suitable", False)
|
||||
|
||||
|
||||
if count == 1:
|
||||
count_eq1_group["total"] += 1
|
||||
if suitable:
|
||||
@@ -325,34 +325,34 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
count_gt1_group["suitable"] += 1
|
||||
else:
|
||||
count_gt1_group["unsuitable"] += 1
|
||||
|
||||
|
||||
print("\n【Count=1 vs Count>1 对比】")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
eq1_total = count_eq1_group["total"]
|
||||
eq1_suitable = count_eq1_group["suitable"]
|
||||
eq1_pass_rate = (eq1_suitable / eq1_total * 100) if eq1_total > 0 else 0
|
||||
|
||||
|
||||
gt1_total = count_gt1_group["total"]
|
||||
gt1_suitable = count_gt1_group["suitable"]
|
||||
gt1_pass_rate = (gt1_suitable / gt1_total * 100) if gt1_total > 0 else 0
|
||||
|
||||
|
||||
print("Count = 1:")
|
||||
print(f" 总数: {eq1_total}")
|
||||
print(f" 通过: {eq1_suitable} ({eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100-eq1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {eq1_total - eq1_suitable} ({100 - eq1_pass_rate:.2f}%)")
|
||||
print()
|
||||
print("Count > 1:")
|
||||
print(f" 总数: {gt1_total}")
|
||||
print(f" 通过: {gt1_suitable} ({gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100-gt1_pass_rate:.2f}%)")
|
||||
print(f" 不通过: {gt1_total - gt1_suitable} ({100 - gt1_pass_rate:.2f}%)")
|
||||
print()
|
||||
|
||||
|
||||
# 进行卡方检验(简化版,使用2x2列联表)
|
||||
if eq1_total > 0 and gt1_total > 0:
|
||||
print("【统计显著性检验】")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
# 构建2x2列联表
|
||||
# 通过 不通过
|
||||
# count=1 a b
|
||||
@@ -361,7 +361,7 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
b = eq1_total - eq1_suitable
|
||||
c = gt1_suitable
|
||||
d = gt1_total - gt1_suitable
|
||||
|
||||
|
||||
# 计算卡方统计量(简化版,使用Pearson卡方检验)
|
||||
n = eq1_total + gt1_total
|
||||
if n > 0:
|
||||
@@ -370,13 +370,13 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
e_b = (eq1_total * (b + d)) / n
|
||||
e_c = (gt1_total * (a + c)) / n
|
||||
e_d = (gt1_total * (b + d)) / n
|
||||
|
||||
|
||||
# 检查期望频数是否足够大(卡方检验要求每个期望频数>=5)
|
||||
min_expected = min(e_a, e_b, e_c, e_d)
|
||||
if min_expected < 5:
|
||||
print("警告:期望频数小于5,卡方检验可能不准确")
|
||||
print("建议使用Fisher精确检验")
|
||||
|
||||
|
||||
# 计算卡方值
|
||||
chi_square = 0
|
||||
if e_a > 0:
|
||||
@@ -387,26 +387,26 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
chi_square += ((c - e_c) ** 2) / e_c
|
||||
if e_d > 0:
|
||||
chi_square += ((d - e_d) ** 2) / e_d
|
||||
|
||||
|
||||
# 自由度 = (行数-1) * (列数-1) = 1
|
||||
df = 1
|
||||
|
||||
|
||||
# 临界值(α=0.05)
|
||||
chi_square_critical_005 = 3.841
|
||||
chi_square_critical_001 = 6.635
|
||||
|
||||
|
||||
print(f"卡方统计量: {chi_square:.4f}")
|
||||
print(f"自由度: {df}")
|
||||
print(f"临界值 (α=0.05): {chi_square_critical_005}")
|
||||
print(f"临界值 (α=0.01): {chi_square_critical_001}")
|
||||
|
||||
|
||||
if chi_square >= chi_square_critical_001:
|
||||
print("结论: 在α=0.01水平下,count=1和count>1的合格率存在显著差异(p<0.01)")
|
||||
elif chi_square >= chi_square_critical_005:
|
||||
print("结论: 在α=0.05水平下,count=1和count>1的合格率存在显著差异(p<0.05)")
|
||||
else:
|
||||
print("结论: 在α=0.05水平下,count=1和count>1的合格率不存在显著差异(p≥0.05)")
|
||||
|
||||
|
||||
# 计算差异大小
|
||||
diff = abs(eq1_pass_rate - gt1_pass_rate)
|
||||
print(f"\n合格率差异: {diff:.2f}%")
|
||||
@@ -420,16 +420,16 @@ def perform_statistical_analysis(evaluation_results: List[Dict]):
|
||||
print("数据不足,无法进行统计检验")
|
||||
else:
|
||||
print("数据不足,无法进行count=1和count>1的对比分析")
|
||||
|
||||
|
||||
# 保存统计分析结果
|
||||
analysis_result = {
|
||||
"analysis_time": datetime.now().isoformat(),
|
||||
"count_groups": {str(k): v for k, v in count_groups.items()},
|
||||
"count_eq1": count_eq1_group,
|
||||
"count_gt1": count_gt1_group,
|
||||
"total_evaluated": len(evaluation_results)
|
||||
"total_evaluated": len(evaluation_results),
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
analysis_file = os.path.join(TEMP_DIR, "count_analysis_statistics.json")
|
||||
with open(analysis_file, "w", encoding="utf-8") as f:
|
||||
@@ -444,7 +444,7 @@ async def main():
|
||||
logger.info("=" * 60)
|
||||
logger.info("开始表达方式按count分组的LLM评估和统计分析")
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
# 初始化数据库连接
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
@@ -452,97 +452,95 @@ async def main():
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
# 加载已有评估结果
|
||||
existing_results, evaluated_pairs = load_existing_results()
|
||||
evaluation_results = existing_results.copy()
|
||||
|
||||
|
||||
if evaluated_pairs:
|
||||
print(f"\n已加载 {len(existing_results)} 条已有评估结果")
|
||||
print(f"已评估项目数: {len(evaluated_pairs)}")
|
||||
|
||||
|
||||
# 检查是否需要继续评估(检查是否还有未评估的count>1项目)
|
||||
# 先查询未评估的count>1项目数量
|
||||
try:
|
||||
all_expressions = list(Expression.select())
|
||||
unevaluated_count_gt1 = [
|
||||
expr for expr in all_expressions
|
||||
if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
expr for expr in all_expressions if expr.count > 1 and (expr.situation, expr.style) not in evaluated_pairs
|
||||
]
|
||||
has_unevaluated = len(unevaluated_count_gt1) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"查询未评估项目失败: {e}")
|
||||
has_unevaluated = False
|
||||
|
||||
|
||||
if has_unevaluated:
|
||||
print("\n" + "=" * 60)
|
||||
print("开始LLM评估")
|
||||
print("=" * 60)
|
||||
print("评估结果会自动保存到文件\n")
|
||||
|
||||
|
||||
# 创建LLM实例
|
||||
print("创建LLM实例...")
|
||||
try:
|
||||
llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.tool_use,
|
||||
request_type="expression_evaluator_count_analysis_llm"
|
||||
request_type="expression_evaluator_count_analysis_llm",
|
||||
)
|
||||
print("✓ LLM实例创建成功\n")
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
print(f"\n✗ 创建LLM实例失败: {e}")
|
||||
db.close()
|
||||
return
|
||||
|
||||
|
||||
# 选择需要评估的表达方式(选择所有count>1的项目,然后选择两倍数量的count=1的项目)
|
||||
expressions = select_expressions_for_evaluation(
|
||||
evaluated_pairs=evaluated_pairs
|
||||
)
|
||||
|
||||
expressions = select_expressions_for_evaluation(evaluated_pairs=evaluated_pairs)
|
||||
|
||||
if not expressions:
|
||||
print("\n没有可评估的项目")
|
||||
else:
|
||||
print(f"\n已选择 {len(expressions)} 条表达方式进行评估")
|
||||
print(f"其中 count>1 的有 {sum(1 for e in expressions if e.count > 1)} 条")
|
||||
print(f"其中 count=1 的有 {sum(1 for e in expressions if e.count == 1)} 条\n")
|
||||
|
||||
|
||||
batch_results = []
|
||||
for i, expression in enumerate(expressions, 1):
|
||||
print(f"LLM评估进度: {i}/{len(expressions)}")
|
||||
print(f" Situation: {expression.situation}")
|
||||
print(f" Style: {expression.style}")
|
||||
print(f" Count: {expression.count}")
|
||||
|
||||
|
||||
llm_result = await llm_evaluate_expression(expression, llm)
|
||||
|
||||
|
||||
print(f" 结果: {'通过' if llm_result['suitable'] else '不通过'}")
|
||||
if llm_result.get('error'):
|
||||
if llm_result.get("error"):
|
||||
print(f" 错误: {llm_result['error']}")
|
||||
print()
|
||||
|
||||
|
||||
batch_results.append(llm_result)
|
||||
# 使用 (situation, style) 作为唯一标识
|
||||
evaluated_pairs.add((llm_result["situation"], llm_result["style"]))
|
||||
|
||||
|
||||
# 添加延迟以避免API限流
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
|
||||
# 将当前批次结果添加到总结果中
|
||||
evaluation_results.extend(batch_results)
|
||||
|
||||
|
||||
# 保存结果
|
||||
save_results(evaluation_results)
|
||||
else:
|
||||
print(f"\n所有count>1的项目都已评估完成,已有 {len(evaluation_results)} 条评估结果")
|
||||
|
||||
|
||||
# 进行统计分析
|
||||
if len(evaluation_results) > 0:
|
||||
perform_statistical_analysis(evaluation_results)
|
||||
else:
|
||||
print("\n没有评估结果可供分析")
|
||||
|
||||
|
||||
# 关闭数据库连接
|
||||
try:
|
||||
db.close()
|
||||
@@ -553,4 +551,3 @@ async def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user