better:优化了表达模型的预测,和表达方式的学习逻辑
This commit is contained in:
@@ -136,13 +136,13 @@ class ExpressionSelector:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def get_model_predicted_expressions(self, chat_id: str, chat_info: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
def get_model_predicted_expressions(self, chat_id: str, target_message: str, total_num: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
使用 style_learner 模型预测最合适的表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
chat_info: 聊天内容信息
|
||||
target_message: 目标消息内容
|
||||
total_num: 需要预测的数量
|
||||
|
||||
Returns:
|
||||
@@ -152,10 +152,7 @@ class ExpressionSelector:
|
||||
# 支持多chat_id合并预测
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 从聊天信息中提取关键内容作为预测输入
|
||||
# 这里可以进一步优化,提取更合适的预测输入
|
||||
prediction_input = self._extract_prediction_input(chat_info)
|
||||
|
||||
|
||||
predicted_expressions = []
|
||||
|
||||
# 为每个相关的chat_id进行预测
|
||||
@@ -163,7 +160,7 @@ class ExpressionSelector:
|
||||
try:
|
||||
# 使用 style_learner 预测最合适的风格
|
||||
best_style, scores = style_learner_manager.predict_style(
|
||||
related_chat_id, prediction_input, top_k=total_num
|
||||
related_chat_id, target_message, top_k=total_num
|
||||
)
|
||||
|
||||
if best_style and scores:
|
||||
@@ -175,7 +172,6 @@ class ExpressionSelector:
|
||||
# 从数据库查找对应的表达记录
|
||||
expr_query = Expression.select().where(
|
||||
(Expression.chat_id == related_chat_id) &
|
||||
(Expression.type == "style") &
|
||||
(Expression.situation == situation) &
|
||||
(Expression.style == best_style)
|
||||
)
|
||||
@@ -188,11 +184,12 @@ class ExpressionSelector:
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"prediction_score": scores.get(best_style, 0.0),
|
||||
"prediction_input": prediction_input
|
||||
"prediction_input": target_message
|
||||
})
|
||||
else:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {best_style} 没有找到对应的表达方式")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"为聊天室 {related_chat_id} 预测表达方式失败: {e}")
|
||||
@@ -208,39 +205,11 @@ class ExpressionSelector:
|
||||
except Exception as e:
|
||||
logger.error(f"模型预测表达方式失败: {e}")
|
||||
# 如果预测失败,回退到随机选择
|
||||
return self._fallback_random_expressions(chat_id, total_num)
|
||||
return self._random_expressions(chat_id, total_num)
|
||||
|
||||
def _extract_prediction_input(self, chat_info: str) -> str:
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从聊天信息中提取用于预测的关键内容
|
||||
|
||||
Args:
|
||||
chat_info: 聊天内容信息
|
||||
|
||||
Returns:
|
||||
str: 提取的预测输入
|
||||
"""
|
||||
try:
|
||||
# 简单的提取策略:取最后几句话作为预测输入
|
||||
lines = chat_info.strip().split('\n')
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
# Take only the last line as the prediction input (the slice below is lines[-1:])
|
||||
recent_lines = lines[-1:]
|
||||
prediction_input = ' '.join(recent_lines).strip()
|
||||
|
||||
logger.info(f"提取预测输入: {prediction_input}")
|
||||
|
||||
return prediction_input
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"提取预测输入失败: {e}")
|
||||
return ""
|
||||
|
||||
def _fallback_random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
回退到随机选择表达方式
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
@@ -255,7 +224,7 @@ class ExpressionSelector:
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
|
||||
(Expression.chat_id.in_(related_chat_ids))
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
@@ -265,7 +234,6 @@ class ExpressionSelector:
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"type": "style",
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
for expr in style_query
|
||||
@@ -277,7 +245,7 @@ class ExpressionSelector:
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
logger.info(f"回退到随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
logger.info(f"随机选择,为聊天室 {chat_id} 选择了 {len(selected_style)} 个表达方式")
|
||||
return selected_style
|
||||
|
||||
except Exception as e:
|
||||
@@ -315,7 +283,7 @@ class ExpressionSelector:
|
||||
if expression_mode == "exp_model":
|
||||
# exp_model模式:直接使用模型预测,不经过LLM
|
||||
logger.debug(f"使用exp_model模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_model_only(chat_id, chat_info, max_num)
|
||||
return await self._select_expressions_model_only(chat_id, target_message, max_num)
|
||||
elif expression_mode == "classic":
|
||||
# classic模式:随机选择+LLM选择
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
@@ -327,7 +295,7 @@ class ExpressionSelector:
|
||||
async def _select_expressions_model_only(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
target_message: str,
|
||||
max_num: int = 10,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
@@ -335,7 +303,7 @@ class ExpressionSelector:
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
target_message: 目标消息内容
|
||||
max_num: 最大选择数量
|
||||
|
||||
Returns:
|
||||
@@ -343,11 +311,7 @@ class ExpressionSelector:
|
||||
"""
|
||||
try:
|
||||
# 使用模型预测最合适的表达方式
|
||||
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, max_num * 2)
|
||||
|
||||
|
||||
# 直接取前max_num个预测结果
|
||||
selected_expressions = style_exprs[:max_num]
|
||||
selected_expressions = self.get_model_predicted_expressions(chat_id, target_message, max_num)
|
||||
selected_ids = [expr["id"] for expr in selected_expressions]
|
||||
|
||||
# 更新last_active_time
|
||||
@@ -381,8 +345,8 @@ class ExpressionSelector:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用模型预测最合适的表达方式
|
||||
style_exprs = self.get_model_predicted_expressions(chat_id, chat_info, 20)
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
@@ -460,21 +424,6 @@ class ExpressionSelector:
|
||||
logger.error(f"classic模式处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
async def select_suitable_expressions_llm(
    self,
    chat_id: str,
    chat_info: str,
    max_num: int = 10,
    target_message: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """Deprecated entry point kept for backward compatibility.

    Emits a warning through the module logger and then forwards every
    argument, unchanged and in the same positional order, to
    ``select_suitable_expressions``. New code should call that method
    directly.

    Args:
        chat_id: Chat identifier, forwarded as the first argument.
        chat_info: Chat content text, forwarded as the second argument.
        max_num: Forwarded as the third argument; defaults to 10.
        target_message: Optional target message, forwarded as the fourth
            argument; defaults to None.

    Returns:
        Whatever ``select_suitable_expressions`` returns; annotated as a
        tuple of (list of expression dicts, list of integer IDs).
    """
    # Warn on every call so remaining legacy call sites show up in the logs.
    logger.warning("select_suitable_expressions_llm 方法已过时,请使用 select_suitable_expressions")

    # Delegate positionally, exactly as the legacy signature lists them.
    forwarded_args = (chat_id, chat_info, max_num, target_message)
    outcome = await self.select_suitable_expressions(*forwarded_args)
    return outcome
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
@@ -482,19 +431,17 @@ class ExpressionSelector:
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
expr_type: str = expr.get("type", "style")
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, expr_type, situation, style)
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, expr_type, situation, style in updates_by_key:
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id)
|
||||
& (Expression.type == expr_type)
|
||||
& (Expression.situation == situation)
|
||||
& (Expression.style == style)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user