better:优化了表达模型的预测,和表达方式的学习逻辑

This commit is contained in:
SengokuCola
2025-10-11 23:44:52 +08:00
parent 4a074ec374
commit d073a215e3
8 changed files with 113 additions and 410 deletions

View File

@@ -136,13 +136,13 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
"""
使用 style_learner 模型预测最合适的表达方式
Args:
chat_id: 聊天室ID
chat_info: 聊天内容信息
target_message: 目标消息内容
total_num: 需要预测的数量
Returns:
@@ -152,10 +152,7 @@ class ExpressionSelector:
# 支持多chat_id合并预测
related_chat_ids = self.get_related_chat_ids(chat_id)
# 从聊天信息中提取关键内容作为预测输入
# 这里可以进一步优化,提取更合适的预测输入
prediction_input = self._extract_prediction_input(chat_info)
predicted_expressions = []
# 为每个相关的chat_id进行预测
@@ -163,7 +160,7 @@ class ExpressionSelector:
try:
# 使用 style_learner 预测最合适的风格
best_style, scores = style_learner_manager.predict_style(
related_chat_id, prediction_input, top_k=total_num
related_chat_id, target_message, top_k=total_num
)
if best_style and scores:
@@ -175,7 +172,6 @@ class ExpressionSelector:
# 从数据库查找对应的表达记录
expr_query = Expression.select().where(
(Expression.chat_id == related_chat_id) &
(Expression.type == "style") &
(Expression.situation == situation) &
(Expression.style == best_style)
)
@@ -188,11 +184,12 @@ class ExpressionSelector:
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"prediction_score": scores.get(best_style, 0.0),
"prediction_input": prediction_input
"prediction_input": target_message
})
else:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
except Exception as e:
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
@@ -208,39 +205,11 @@ class ExpressionSelector:
except Exception as e:
logger.error(f"模型预测表达方式失败: {e}")
# 如果预测失败,回退到随机选择
return self._fallback_random_expressions(chat_id, total_num)
return self._random_expressions(chat_id, total_num)
def _extract_prediction_input(self, chat_info: str) -> str:
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
从聊天信息中提取用于预测的关键内容
Args:
chat_info: 聊天内容信息
Returns:
str: 提取的预测输入
"""
try:
# 简单的提取策略:取最后几句话作为预测输入
lines = chat_info.strip().split('\n')
if not lines:
return ""
# 取最后3行作为预测输入
recent_lines = lines[-1:]
prediction_input = ' '.join(recent_lines).strip()
logger.info(f"提取预测输入: {prediction_input}")
return prediction_input
except Exception as e:
logger.warning(f"提取预测输入失败: {e}")
return ""
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
回退到随机选择表达方式
随机选择表达方式
Args:
chat_id: 聊天室ID
@@ -255,7 +224,7 @@ class ExpressionSelector:
# 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
(Expression.chat_id.in_(related_chat_ids))
)
style_exprs = [
@@ -265,7 +234,6 @@ class ExpressionSelector:
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
}
for expr in style_query
@@ -277,7 +245,7 @@ class ExpressionSelector:
else:
selected_style = []
logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
return selected_style
except Exception as e:
@@ -315,7 +283,7 @@ class ExpressionSelector:
if expression_mode == "exp_model":
# exp_model模式直接使用模型预测不经过LLM
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_model_only(chat_id, chat_info, max_num)
return await self._select_expressions_model_only(chat_id, target_message, max_num)
elif expression_mode == "classic":
# classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
@@ -327,7 +295,7 @@ class ExpressionSelector:
async def _select_expressions_model_only(
self,
chat_id: str,
chat_info: str,
target_message: str,
max_num: int = 10,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
@@ -335,7 +303,7 @@ class ExpressionSelector:
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
target_message: 目标消息内容
max_num: 最大选择数量
Returns:
@@ -343,11 +311,7 @@ class ExpressionSelector:
"""
try:
# 使用模型预测最合适的表达方式
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)
# 直接取前max_num个预测结果
selected_expressions = style_exprs[:max_num]
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
selected_ids = [expr["id"] for expr in selected_expressions]
# 更新last_active_time
@@ -381,8 +345,8 @@ class ExpressionSelector:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 1. 使用模型预测最合适的表达方式
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)
# 1. 使用随机抽样选择表达方式
style_exprs = self._random_expressions(chat_id, 20)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
@@ -460,21 +424,6 @@ class ExpressionSelector:
logger.error(f"classic模式处理表达方式选择时出错: {e}")
return [], []
async def select_suitable_expressions_llm(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
使用LLM选择适合的表达方式保持向后兼容
注意:此方法已被 select_suitable_expressions 替代,建议使用新方法
"""
logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")
return await self.select_suitable_expressions(chat_id, chat_info, max_num, target_message)
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
"""对一批表达方式更新last_active_time"""
if not expressions_to_update:
@@ -482,19 +431,17 @@ class ExpressionSelector:
updates_by_key = {}
for expr in expressions_to_update:
source_id: str = expr.get("source_id") # type: ignore
expr_type: str = expr.get("type", "style")
situation: str = expr.get("situation") # type: ignore
style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
key = (source_id, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
for chat_id, expr_type, situation, style in updates_by_key:
for chat_id, situation, style in updates_by_key:
query = Expression.select().where(
(Expression.chat_id == chat_id)
& (Expression.type == expr_type)
& (Expression.situation == situation)
& (Expression.style == style)
)