This commit is contained in:
墨梓柒
2025-11-13 13:24:55 +08:00
parent e78a070fbd
commit 7839acd25d
52 changed files with 1322 additions and 1408 deletions

View File

@@ -1,8 +1,6 @@
import json
import time
import random
import hashlib
import re
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
@@ -115,30 +113,31 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
def get_model_predicted_expressions(
self, chat_id: str, target_message: str, total_num: int = 10
) -> List[Dict[str, Any]]:
"""
使用 style_learner 模型预测最合适的表达方式
Args:
chat_id: 聊天室ID
target_message: 目标消息内容
total_num: 需要预测的数量
Returns:
List[Dict[str, Any]]: 预测的表达方式列表
"""
try:
# 过滤目标消息内容,移除回复、表情包等特殊格式
filtered_target_message = filter_message_content(target_message)
logger.info(f"{chat_id} 预测表达方式,过滤后的目标消息内容: {filtered_target_message}")
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
predicted_expressions = []
# 为每个相关的chat_id进行预测
for related_chat_id in related_chat_ids:
try:
@@ -146,59 +145,65 @@ class ExpressionSelector:
best_style, scores = style_learner_manager.predict_style(
related_chat_id, filtered_target_message, top_k=total_num
)
if best_style and scores:
# 获取预测风格的完整信息
learner = style_learner_manager.get_learner(related_chat_id)
style_id, situation = learner.get_style_info(best_style)
if style_id and situation:
# 从数据库查找对应的表达记录
expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) &
(Expression.situation == situation) &
(Expression.style == best_style)
(Expression.chat_id == related_chat_id)
& (Expression.situation == situation)
& (Expression.style == best_style)
)
if expr_query.exists():
expr = expr_query.get()
predicted_expressions.append({
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message
})
predicted_expressions.append(
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date
if expr.create_date is not None
else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": filtered_target_message,
}
)
else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
logger.warning(
f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式"
)
except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
continue
# 按预测分数排序,取前 total_num 个
predicted_expressions.sort(key=lambda x: x.get("prediction_score", 0.0), reverse=True)
selected_expressions = predicted_expressions[:total_num]
logger.info(f"{chat_id} 预测到 {len(selected_expressions)} 个表达方式")
return selected_expressions
except Exception as e:
logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择
return self._random_expressions(chat_id, total_num)
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
随机选择表达方式
Args:
chat_id: 聊天室ID
total_num: 需要选择的数量
Returns:
List[Dict[str, Any]]: 随机选择的表达方式列表
"""
@@ -207,9 +212,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids))
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)))
style_exprs = [
{
@@ -228,15 +231,14 @@ class ExpressionSelector:
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")
return []
async def select_suitable_expressions(
self,
chat_id: str,
@@ -246,13 +248,13 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
根据配置模式选择适合的表达方式
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
@@ -263,7 +265,7 @@ class ExpressionSelector:
# 获取配置模式
expression_mode = global_config.expression.mode
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
@@ -284,12 +286,12 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
exp_model模式直接使用模型预测不经过LLM
Args:
chat_id: 聊天流ID
target_message: 目标消息内容
max_num: 最大选择数量
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
@@ -297,14 +299,14 @@ class ExpressionSelector:
# 使用模型预测最合适的表达方式
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
if selected_expressions:
self.update_expressions_last_active_time(selected_expressions)
logger.info(f"exp_model模式为聊天流 {chat_id} 选择了 {len(selected_expressions)} 个表达方式")
return selected_expressions, selected_ids
except Exception as e:
logger.error(f"exp_model模式选择表达方式失败: {e}")
return [], []
@@ -318,13 +320,13 @@ class ExpressionSelector:
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
classic模式随机选择+LLM选择
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
@@ -425,17 +427,13 @@ class ExpressionSelector:
updates_by_key[key] = expr
for chat_id, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.situation == situation)
& (Expression.style == style)
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
expr_obj.last_active_time = time.time()
expr_obj.save()
logger.debug(
"表达方式激活: 更新last_active_time in db"
)
logger.debug("表达方式激活: 更新last_active_time in db")
init_prompt()