feat:修复表达方式的学习和使用,用subagent使用表达

1
This commit is contained in:
SengokuCola
2026-04-04 23:18:21 +08:00
parent 2fb911a8d5
commit 7b924774be
10 changed files with 497 additions and 569 deletions

View File

@@ -159,7 +159,8 @@ class ExpressionLearner:
self._learning_lock = asyncio.Lock()
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
self._last_processed_index = 0
self.min_messages_for_extraction = 10
@staticmethod
def _get_runtime_manager() -> Any:
@@ -265,27 +266,25 @@ class ExpressionLearner:
normalized_entries.append((content, source_id))
return normalized_entries
def add_messages(self, messages: List["SessionMessage"]) -> None:
    """Append the given messages to the end of the message cache.

    Args:
        messages: Messages to add; they are appended in iteration order
            to ``self._messages_cache``.
    """
    cache = self._messages_cache
    cache.extend(messages)
def get_pending_count(self, message_cache: List["SessionMessage"]) -> int:
    """Return how many messages in ``message_cache`` lie past the processed index.

    Args:
        message_cache: The message list whose length is compared with
            ``self._last_processed_index``.

    Returns:
        ``len(message_cache) - self._last_processed_index``, or 0 when that
        difference is negative.
    """
    remaining = len(message_cache) - self._last_processed_index
    return remaining if remaining > 0 else 0
def get_cache_size(self) -> int:
    """Return the number of messages currently held in the message cache."""
    cached_messages = self._messages_cache
    return len(cached_messages)
async def learn(
self,
message_cache: List["SessionMessage"],
jargon_miner: Optional["JargonMiner"] = None,
) -> bool:
"""?????????????????????"""
pending_messages = message_cache[self._last_processed_index :]
if not pending_messages:
logger.debug("??????????????????")
return False
if len(pending_messages) < self.min_messages_for_extraction:
return False
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
"""执行表达方式学习主流程。
Args:
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
# 构建可读消息
readable_message, _, _ = await MessageUtils.build_readable_message(
self._messages_cache,
pending_messages,
anonymize=True,
show_lineno=True,
extract_pictures=True,
@@ -293,57 +292,54 @@ class ExpressionLearner:
target_bot_name="SELF",
)
# 准备提示词
prompt_template = prompt_manager.get_prompt("learn_style")
prompt_template.add_context("bot_name", global_config.bot.nickname)
prompt_template.add_context("chat_str", readable_message)
prompt = await prompt_manager.render_prompt(prompt_template)
# 调用 LLM 学习表达方式
try:
generation_result = await express_learn_model.generate_response(
prompt, options=LLMGenerationOptions(temperature=0.3)
prompt,
options=LLMGenerationOptions(temperature=0.3),
)
response = generation_result.response
except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错:{e}")
return
logger.error(f"????????????????{e}")
return False
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
expressions: List[Tuple[str, str, str]]
jargon_entries: List[Tuple[str, str]] # (content, source_id)
jargon_entries: List[Tuple[str, str]]
expressions, jargon_entries = parse_expression_response(response)
# 从缓存中检查 jargon 是否出现在 messages 中
if cached_jargon_entries := self._check_cached_jargons_in_messages(jargon_miner):
# 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
cached_jargon_entries = self._check_cached_jargons_in_messages(pending_messages, jargon_miner)
if cached_jargon_entries:
existing_contents = {content for content, _ in jargon_entries}
for content, source_id in cached_jargon_entries:
if content not in existing_contents:
jargon_entries.append((content, source_id))
existing_contents.add(content)
logger.info(f"从缓存中检查到黑话:{content}")
if content in existing_contents:
continue
jargon_entries.append((content, source_id))
existing_contents.add(content)
logger.info(f"??????????{content}")
# 检查表达方式数量,如果超过 20 个则放弃本次表达学习
if len(expressions) > 20:
logger.info(f"表达方式提取数量超过 20 个(实际{len(expressions)}个),放弃本次表达学习")
logger.info(f"?????????? 20 ???????????{len(expressions)}")
expressions = []
# 检查黑话数量,如果超过 30 个则放弃本次黑话学习
if len(jargon_entries) > 30:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
logger.info(f"???????? 30 ???????????{len(jargon_entries)}")
jargon_entries = []
after_extract_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.after_extract",
session_id=self.session_id,
message_count=len(self._messages_cache),
message_count=len(pending_messages),
expressions=self._serialize_expressions(expressions),
jargon_entries=self._serialize_jargon_entries(jargon_entries),
)
if after_extract_result.aborted:
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
return
logger.info(f"{self.session_id} ?????????? Hook ??")
self._last_processed_index = len(message_cache)
return False
after_extract_kwargs = after_extract_result.kwargs
raw_expressions = after_extract_kwargs.get("expressions")
@@ -353,31 +349,26 @@ class ExpressionLearner:
if raw_jargon_entries is not None:
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
await self._process_jargon_entries(jargon_entries, jargon_miner)
await self._process_jargon_entries(jargon_entries, pending_messages, jargon_miner)
# 如果没有表达方式,直接返回
if not expressions:
logger.info("解析后没有可用的表达方式")
return
logger.info("????????????")
self._last_processed_index = len(message_cache)
return False
logger.info(f"学习的 expressions: {expressions}")
logger.info(f"学习的 jargon_entries: {jargon_entries}")
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
learnt_expressions = self._filter_expressions(expressions)
logger.info(f"???? expressions: {expressions}")
logger.info(f"???? jargon_entries: {jargon_entries}")
learnt_expressions = self._filter_expressions(expressions, pending_messages)
if not learnt_expressions:
logger.info("没有学习到表达风格")
return
logger.info("????????????")
self._last_processed_index = len(message_cache)
return False
# 展示学到的表达方式
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
logger.info(f" {self.session_id} 学习到表达风格:\n{learnt_expressions_str}")
logger.info(f"? {self.session_id} ????????\n{learnt_expressions_str}")
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
before_upsert_result = await self._get_runtime_manager().invoke_hook(
"expression.learn.before_upsert",
@@ -386,19 +377,25 @@ class ExpressionLearner:
style=style,
)
if before_upsert_result.aborted:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
logger.info(f"{self.session_id} ???????? Hook ??: situation={situation!r}")
continue
upsert_kwargs = before_upsert_result.kwargs
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
style = str(upsert_kwargs.get("style", style) or "").strip()
if not situation or not style:
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
logger.info(f"{self.session_id} ???????? Hook ??????")
continue
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======
def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None) -> List[Tuple[str, str]]:
self._last_processed_index = len(message_cache)
return True
def _check_cached_jargons_in_messages(
self,
messages: List["SessionMessage"],
jargon_miner: Optional["JargonMiner"] = None,
) -> List[Tuple[str, str]]:
"""
检查缓存中的 jargon 是否出现在 messages 中
@@ -418,7 +415,7 @@ class ExpressionLearner:
matched_entries: List[Tuple[str, str]] = []
for i, msg in enumerate(self._messages_cache):
for i, msg in enumerate(messages):
# 跳过机器人自己的消息
if is_bot_self(msg.platform, msg.message_info.user_info.user_id):
continue
@@ -454,7 +451,10 @@ class ExpressionLearner:
return matched_entries
async def _process_jargon_entries(
self, jargon_entries: List[Tuple[str, str]], jargon_miner: Optional["JargonMiner"] = None
self,
jargon_entries: List[Tuple[str, str]],
messages: List["SessionMessage"],
jargon_miner: Optional["JargonMiner"] = None,
):
"""
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
@@ -463,7 +463,7 @@ class ExpressionLearner:
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
jargon_miner: JargonMiner 实例
"""
if not jargon_entries or not self._messages_cache:
if not jargon_entries or not messages:
return
if not jargon_miner:
@@ -497,20 +497,20 @@ class ExpressionLearner:
# build_readable_message 的编号从 1 开始
line_index = int(source_id) - 1
if line_index < 0 or line_index >= len(self._messages_cache):
if line_index < 0 or line_index >= len(messages):
logger.warning(f"黑话条目 source_id 超出范围content={content}, source_id={source_id}")
continue
# 检查是否是机器人自己的消息
target_msg = self._messages_cache[line_index]
target_msg = messages[line_index]
if is_bot_self(target_msg.platform, target_msg.message_info.user_info.user_id):
logger.info(f"跳过引用机器人自身消息的黑话content={content}, source_id={source_id}")
continue
# 构建上下文段落(取前后各 3 条消息)
start_idx = max(0, line_index - 3)
end_idx = min(len(self._messages_cache), line_index + 4)
context_msgs = self._messages_cache[start_idx:end_idx]
end_idx = min(len(messages), line_index + 4)
context_msgs = messages[start_idx:end_idx]
context_paragraph = "\n".join(
[f"[{i + 1}] {msg.processed_plain_text or ''}" for i, msg in enumerate(context_msgs)]
@@ -529,7 +529,11 @@ class ExpressionLearner:
logger.info(f"成功处理 {len(entries)} 个黑话条目")
# ====== 过滤方法 ======
def _filter_expressions(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str]]:
def _filter_expressions(
self,
expressions: List[Tuple[str, str, str]],
messages: List["SessionMessage"],
) -> List[Tuple[str, str]]:
"""
过滤表达方式,移除不符合条件的条目
@@ -558,10 +562,10 @@ class ExpressionLearner:
if not source_id_str.isdigit():
continue # 无效的来源行编号,跳过
line_index = int(source_id_str) - 1 # build_readable_message 的编号从 1 开始
if line_index < 0 or line_index >= len(self._messages_cache):
if line_index < 0 or line_index >= len(messages):
continue # 超出范围,跳过
# 当前行的原始消息
current_msg = self._messages_cache[line_index]
current_msg = messages[line_index]
# 过滤掉从 bot 自己发言中提取到的表达方式
if is_bot_self(current_msg.platform, current_msg.message_info.user_info.user_id):
continue