better:优化了表达方式采样

This commit is contained in:
SengokuCola
2025-10-14 12:36:23 +08:00
parent d5f17b1f89
commit cb500e069a
5 changed files with 128 additions and 77 deletions

View File

@@ -2,6 +2,7 @@ import json
import time
import random
import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
@@ -12,6 +13,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.style_learner import style_learner_manager
from src.express.express_utils import filter_message_content, weighted_sample
logger = get_logger("expression_selector")
@@ -44,29 +46,6 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """Pick up to ``k`` items from ``population`` uniformly at random, without replacement.

    Despite the name, no weights are applied: every element is equally likely
    to be chosen.

    Args:
        population: Candidate items. The list is never mutated.
        k: Maximum number of items to return.

    Returns:
        An empty list when ``population`` is empty or ``k <= 0``; a shallow
        copy of ``population`` (original order) when it holds at most ``k``
        items; otherwise a new list with ``k`` elements taken from distinct
        positions of ``population``, in random order.
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()
    # random.sample already does uniform selection without replacement and
    # leaves the input untouched, so no manual copy/pop loop is needed.
    return random.sample(population, k)
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
@@ -149,6 +128,9 @@ class ExpressionSelector:
List[Dict[str, Any]]: 预测的表达方式列表
"""
try:
# 过滤目标消息内容,移除回复、表情包等特殊格式
filtered_target_message = filter_message_content(target_message)
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
@@ -160,7 +142,7 @@ class ExpressionSelector:
try:
# 使用 style_learner 预测最合适的风格
best_style, scores = style_learner_manager.predict_style(
related_chat_id, target_message, top_k=total_num
related_chat_id, filtered_target_message, top_k=total_num
)
if best_style and scores:
@@ -186,7 +168,7 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": target_message
"prediction_input": filtered_target_message
})
else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")