better:优化了表达方式采样
This commit is contained in:
@@ -2,6 +2,7 @@ import json
|
||||
import time
|
||||
import random
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
@@ -12,6 +13,7 @@ from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.express.style_learner import style_learner_manager
|
||||
from src.express.express_utils import filter_message_content, weighted_sample
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
@@ -44,29 +46,6 @@ def init_prompt():
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """Draw up to ``k`` distinct items from ``population`` uniformly at random.

    Despite the name, no weights are consulted: every element is equally
    likely to be chosen, and no element is chosen twice.

    Args:
        population: Candidate items to sample from. It is never mutated.
        k: Maximum number of items to return.

    Returns:
        An empty list when ``population`` is empty or ``k <= 0``; a shallow
        copy of ``population`` (original order kept) when it holds at most
        ``k`` items; otherwise a new list of ``k`` items sampled without
        replacement.
    """
    if not population or k <= 0:
        return []

    # Nothing to choose between: hand back everything, but as a fresh list so
    # callers can mutate the result without touching the input.
    if len(population) <= k:
        return population.copy()

    # random.sample performs uniform sampling without replacement, replacing
    # the previous manual randint + pop loop.
    return random.sample(population, k)
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
@@ -149,6 +128,9 @@ class ExpressionSelector:
|
||||
List[Dict[str, Any]]: 预测的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 过滤目标消息内容,移除回复、表情包等特殊格式
|
||||
filtered_target_message = filter_message_content(target_message)
|
||||
|
||||
# 支持多chat_id合并预测
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
@@ -160,7 +142,7 @@ class ExpressionSelector:
|
||||
try:
|
||||
# 使用 style_learner 预测最合适的风格
|
||||
best_style, scores = style_learner_manager.predict_style(
|
||||
related_chat_id, target_message, top_k=total_num
|
||||
related_chat_id, filtered_target_message, top_k=total_num
|
||||
)
|
||||
|
||||
if best_style and scores:
|
||||
@@ -186,7 +168,7 @@ class ExpressionSelector:
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"prediction_score": scores.get(best_style, 0.0),
|
||||
"prediction_input": target_message
|
||||
"prediction_input": filtered_target_message
|
||||
})
|
||||
else:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||
|
||||
Reference in New Issue
Block a user