优化对上下文的压缩,新增表达方式快速版本
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
import json
|
||||
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
@@ -30,7 +31,7 @@ class MaisakaExpressionSelectionResult:
|
||||
|
||||
|
||||
class MaisakaExpressionSelector:
|
||||
"""负责在 replyer 侧完成表达方式筛选与子代理选择。"""
|
||||
"""负责在 replyer 侧完成表达方式筛选与子代理二次选择。"""
|
||||
|
||||
def _can_use_expressions(self, session_id: str) -> bool:
|
||||
try:
|
||||
@@ -40,6 +41,13 @@ class MaisakaExpressionSelector:
|
||||
logger.error(f"检查表达方式使用开关失败: {exc}")
|
||||
return False
|
||||
|
||||
def _can_use_advanced_chosen(self, session_id: str) -> bool:
    """Tell whether secondary (advanced) expression selection is enabled for a session.

    Args:
        session_id: Session identifier forwarded to
            ``ExpressionConfigUtils.get_expression_advanced_chosen_for_chat``.

    Returns:
        The value reported by the config lookup, or ``False`` when the
        lookup raises.
    """
    try:
        advanced_enabled = ExpressionConfigUtils.get_expression_advanced_chosen_for_chat(session_id)
    except Exception as exc:
        # Best effort: a failing lookup is logged and treated as "disabled".
        logger.error(f"检查表达方式二次选择开关失败: {exc}")
        return False
    return advanced_enabled
|
||||
|
||||
@staticmethod
|
||||
def _is_global_expression_group_marker(platform: str, item_id: str) -> bool:
|
||||
return platform == "*" and item_id == "*"
|
||||
@@ -101,7 +109,7 @@ class MaisakaExpressionSelector:
|
||||
"id": expression.id,
|
||||
"situation": expression.situation,
|
||||
"style": expression.style,
|
||||
"count": expression.count if getattr(expression, "count", None) is not None else 1,
|
||||
"count": expression.count if expression.count is not None else 1,
|
||||
}
|
||||
for expression in expressions
|
||||
if expression.id is not None and expression.situation and expression.style
|
||||
@@ -185,7 +193,7 @@ class MaisakaExpressionSelector:
|
||||
"你只负责根据最近聊天上下文,为这一次可见回复挑选最合适的表达方式。\n"
|
||||
"请只从下面候选中选择 0 到 3 条最适合当前语境的表达方式。\n"
|
||||
"优先考虑自然、贴合上下文、不生硬、不模板化。\n"
|
||||
"如果没有明显合适的,就返回空列表。\n"
|
||||
"如果没有明显合适的,就返回空数组。\n"
|
||||
'严格只输出 JSON,对象格式为 {"selected_ids":[123,456]}。\n\n'
|
||||
f"最近上下文:\n{history_block}\n\n"
|
||||
f"目标消息:{target_text or '无'}\n"
|
||||
@@ -222,6 +230,32 @@ class MaisakaExpressionSelector:
|
||||
break
|
||||
return selected_ids
|
||||
|
||||
def _build_direct_selection_result(
    self,
    *,
    session_id: str,
    candidates: List[dict[str, Any]],
) -> MaisakaExpressionSelectionResult:
    """Build a selection result that injects the candidates directly.

    Every candidate whose ``"id"`` is an ``int`` is selected as-is; no
    sub-agent is consulted on this path.

    Args:
        session_id: Session identifier, used only in the log line.
        candidates: Candidate expression dicts; entries without an integer
            ``"id"`` are ignored.

    Returns:
        A ``MaisakaExpressionSelectionResult`` carrying the habits block built
        from the kept candidates and the list of their ids.
    """
    # Filter once and derive the ids from the kept candidates. This keeps the
    # two lists aligned and avoids a list-membership scan per candidate.
    selected_expressions = [
        candidate
        for candidate in candidates
        if isinstance(candidate.get("id"), int)
    ]
    selected_ids = [candidate["id"] for candidate in selected_expressions]

    self._update_last_active_time(selected_ids)
    logger.info(
        f"表达方式直接注入:session_id={session_id} 已选数={len(selected_ids)} "
        f"selected_ids={selected_ids!r} 已选预览={self._format_candidate_preview(selected_expressions)}"
    )
    return MaisakaExpressionSelectionResult(
        expression_habits=self._build_expression_habits_block(selected_expressions),
        selected_expression_ids=selected_ids,
    )
|
||||
|
||||
def _update_last_active_time(self, selected_ids: List[int]) -> None:
|
||||
if not selected_ids:
|
||||
return
|
||||
@@ -247,15 +281,22 @@ class MaisakaExpressionSelector:
|
||||
if not self._can_use_expressions(session_id):
|
||||
logger.info(f"表达方式选择已跳过:当前会话未启用表达方式,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
if sub_agent_runner is None:
|
||||
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runner,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
candidates = self._load_expression_candidates(session_id)
|
||||
if not candidates:
|
||||
logger.info(f"表达方式选择已跳过:本地候选不足,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
if not self._can_use_advanced_chosen(session_id):
|
||||
return self._build_direct_selection_result(
|
||||
session_id=session_id,
|
||||
candidates=candidates,
|
||||
)
|
||||
|
||||
if sub_agent_runner is None:
|
||||
logger.info(f"表达方式选择已跳过:缺少 sub_agent_runner,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
logger.info(
|
||||
f"表达方式选择开始:session_id={session_id} 候选数={len(candidates)} "
|
||||
f"候选预览={self._format_candidate_preview(candidates)}"
|
||||
@@ -273,10 +314,9 @@ class MaisakaExpressionSelector:
|
||||
logger.exception("表达方式选择子代理执行失败")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
# logger.info(f"表达方式子代理原始结果:session_id={session_id} response={raw_response!r}")
|
||||
selected_ids = self._parse_selected_ids(raw_response, candidates)
|
||||
if not selected_ids:
|
||||
logger.info(f"表达方式选择完成但未命中:session_id={session_id}")
|
||||
logger.info(f"表达方式选择完成但未命中,session_id={session_id}")
|
||||
return MaisakaExpressionSelectionResult()
|
||||
|
||||
selected_expressions = [candidate for candidate in candidates if candidate.get("id") in selected_ids]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("common_utils")
|
||||
|
||||
@@ -10,23 +10,14 @@ class TempMethodsExpression:
|
||||
"""用于临时存放一些方法的类"""
|
||||
|
||||
@staticmethod
|
||||
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
Args:
|
||||
chat_stream_id: 聊天流ID,格式为哈希值
|
||||
|
||||
Returns:
|
||||
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
|
||||
"""
|
||||
def _find_expression_config_item(chat_stream_id: Optional[str] = None):
|
||||
if not global_config.expression.learning_list:
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
if chat_stream_id:
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
continue # 这是全局的
|
||||
continue
|
||||
stream_id = TempMethodsExpression._get_stream_id(
|
||||
config_item.platform,
|
||||
str(config_item.item_id),
|
||||
@@ -34,14 +25,44 @@ class TempMethodsExpression:
|
||||
)
|
||||
if stream_id is None:
|
||||
continue
|
||||
if stream_id == chat_stream_id:
|
||||
if stream_id != chat_stream_id:
|
||||
continue
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
for config_item in global_config.expression.learning_list:
|
||||
if not config_item.platform and not config_item.item_id:
|
||||
return config_item.use_expression, config_item.enable_learning, config_item.enable_jargon_learning
|
||||
return config_item
|
||||
|
||||
return True, True, True
|
||||
return None
|
||||
|
||||
@staticmethod
def get_expression_advanced_chosen_for_chat(chat_stream_id: Optional[str] = None) -> bool:
    """Report whether secondary expression selection is enabled for a chat stream.

    Args:
        chat_stream_id: Chat stream ID used to look up the config entry.

    Returns:
        The ``advanced_chosen`` value of the matching config entry, or
        ``False`` when no entry is found.
    """
    matched_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
    return False if matched_item is None else matched_item.advanced_chosen
|
||||
|
||||
@staticmethod
def get_expression_config_for_chat(chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
    """
    Look up the expression switches that apply to a chat stream.

    Args:
        chat_stream_id: Chat stream ID (a hash value).

    Returns:
        tuple: (use expressions, learn expressions, jargon learning enabled)
    """
    matched_item = TempMethodsExpression._find_expression_config_item(chat_stream_id)
    if matched_item is not None:
        return (
            matched_item.use_expression,
            matched_item.enable_learning,
            matched_item.enable_jargon_learning,
        )

    # No config entry was found: all three switches default to enabled.
    return True, True, True
|
||||
|
||||
@staticmethod
|
||||
def _get_stream_id(
|
||||
@@ -50,15 +71,15 @@ class TempMethodsExpression:
|
||||
is_group: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
根据平台、ID字符串和是否为群聊生成聊天流ID
|
||||
根据平台、ID 字符串和是否为群聊生成聊天流 ID。
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
id_str: 用户或群组的原始ID字符串
|
||||
id_str: 用户或群组的原始 ID 字符串
|
||||
is_group: 是否为群聊
|
||||
|
||||
Returns:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
str: 生成的聊天流 ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
@@ -68,5 +89,5 @@ class TempMethodsExpression:
|
||||
else:
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
logger.error(f"生成聊天流 ID 失败: {e}")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user