feat:修复表达方式的学习和使用,用subagent使用表达
1
This commit is contained in:
@@ -159,7 +159,8 @@ class ExpressionLearner:
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
# 消息缓存
|
||||
self._messages_cache: List["SessionMessage"] = []
|
||||
self._last_processed_index = 0
|
||||
self.min_messages_for_extraction = 10
|
||||
|
||||
@staticmethod
|
||||
def _get_runtime_manager() -> Any:
|
||||
@@ -265,27 +266,25 @@ class ExpressionLearner:
|
||||
normalized_entries.append((content, source_id))
|
||||
return normalized_entries
|
||||
|
||||
def add_messages(self, messages: List["SessionMessage"]) -> None:
|
||||
"""添加消息到缓存"""
|
||||
self._messages_cache.extend(messages)
|
||||
def get_pending_count(self, message_cache: List["SessionMessage"]) -> int:
|
||||
"""??????????????"""
|
||||
return max(0, len(message_cache) - self._last_processed_index)
|
||||
|
||||
def get_cache_size(self) -> int:
|
||||
"""获取当前消息缓存的大小"""
|
||||
return len(self._messages_cache)
|
||||
async def learn(
|
||||
self,
|
||||
message_cache: List["SessionMessage"],
|
||||
jargon_miner: Optional["JargonMiner"] = None,
|
||||
) -> bool:
|
||||
"""?????????????????????"""
|
||||
pending_messages = message_cache[self._last_processed_index :]
|
||||
if not pending_messages:
|
||||
logger.debug("??????????????????")
|
||||
return False
|
||||
if len(pending_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None) -> None:
|
||||
"""执行表达方式学习主流程。
|
||||
|
||||
Args:
|
||||
jargon_miner: 可选的黑话学习器实例,用于同步处理黑话候选。
|
||||
"""
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
|
||||
# 构建可读消息
|
||||
readable_message, _, _ = await MessageUtils.build_readable_message(
|
||||
self._messages_cache,
|
||||
pending_messages,
|
||||
anonymize=True,
|
||||
show_lineno=True,
|
||||
extract_pictures=True,
|
||||
@@ -293,57 +292,54 @@ class ExpressionLearner:
|
||||
target_bot_name="SELF",
|
||||
)
|
||||
|
||||
# 准备提示词
|
||||
prompt_template = prompt_manager.get_prompt("learn_style")
|
||||
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
||||
prompt_template.add_context("chat_str", readable_message)
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
# 调用 LLM 学习表达方式
|
||||
try:
|
||||
generation_result = await express_learn_model.generate_response(
|
||||
prompt, options=LLMGenerationOptions(temperature=0.3)
|
||||
prompt,
|
||||
options=LLMGenerationOptions(temperature=0.3),
|
||||
)
|
||||
response = generation_result.response
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错:{e}")
|
||||
return
|
||||
logger.error(f"学习表达方式失败,模型生成出错:{e}")
|
||||
return False
|
||||
|
||||
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
jargon_entries: List[Tuple[str, str]] # (content, source_id)
|
||||
jargon_entries: List[Tuple[str, str]]
|
||||
expressions, jargon_entries = parse_expression_response(response)
|
||||
|
||||
# 从缓存中检查 jargon 是否出现在 messages 中
|
||||
if cached_jargon_entries := self._check_cached_jargons_in_messages(jargon_miner):
|
||||
# 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
|
||||
cached_jargon_entries = self._check_cached_jargons_in_messages(pending_messages, jargon_miner)
|
||||
if cached_jargon_entries:
|
||||
existing_contents = {content for content, _ in jargon_entries}
|
||||
for content, source_id in cached_jargon_entries:
|
||||
if content not in existing_contents:
|
||||
jargon_entries.append((content, source_id))
|
||||
existing_contents.add(content)
|
||||
logger.info(f"从缓存中检查到黑话:{content}")
|
||||
if content in existing_contents:
|
||||
continue
|
||||
jargon_entries.append((content, source_id))
|
||||
existing_contents.add(content)
|
||||
logger.info(f"从缓存中检查到黑话:{content}")
|
||||
|
||||
# 检查表达方式数量,如果超过 20 个则放弃本次表达学习
|
||||
if len(expressions) > 20:
|
||||
logger.info(f"表达方式提取数量超过 20 个(实际{len(expressions)}个),放弃本次表达学习")
|
||||
logger.info(f"?????????? 20 ???????????{len(expressions)}")
|
||||
expressions = []
|
||||
|
||||
# 检查黑话数量,如果超过 30 个则放弃本次黑话学习
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
logger.info(f"???????? 30 ???????????{len(jargon_entries)}")
|
||||
jargon_entries = []
|
||||
|
||||
after_extract_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.after_extract",
|
||||
session_id=self.session_id,
|
||||
message_count=len(self._messages_cache),
|
||||
message_count=len(pending_messages),
|
||||
expressions=self._serialize_expressions(expressions),
|
||||
jargon_entries=self._serialize_jargon_entries(jargon_entries),
|
||||
)
|
||||
if after_extract_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
|
||||
return
|
||||
logger.info(f"{self.session_id} 的表达方式学习结果被 Hook 中止")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
after_extract_kwargs = after_extract_result.kwargs
|
||||
raw_expressions = after_extract_kwargs.get("expressions")
|
||||
@@ -353,31 +349,26 @@ class ExpressionLearner:
|
||||
if raw_jargon_entries is not None:
|
||||
jargon_entries = self._deserialize_jargon_entries(raw_jargon_entries)
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
# TODO: 检测是否开启了
|
||||
if jargon_entries:
|
||||
await self._process_jargon_entries(jargon_entries, jargon_miner)
|
||||
await self._process_jargon_entries(jargon_entries, pending_messages, jargon_miner)
|
||||
|
||||
# 如果没有表达方式,直接返回
|
||||
if not expressions:
|
||||
logger.info("解析后没有可用的表达方式")
|
||||
return
|
||||
logger.info("解析后没有可用的表达方式")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
logger.info(f"学习的 expressions: {expressions}")
|
||||
logger.info(f"学习的 jargon_entries: {jargon_entries}")
|
||||
|
||||
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
|
||||
learnt_expressions = self._filter_expressions(expressions)
|
||||
logger.info(f"???? expressions: {expressions}")
|
||||
logger.info(f"???? jargon_entries: {jargon_entries}")
|
||||
|
||||
learnt_expressions = self._filter_expressions(expressions, pending_messages)
|
||||
if not learnt_expressions:
|
||||
logger.info("没有学习到表达风格")
|
||||
return
|
||||
logger.info("????????????")
|
||||
self._last_processed_index = len(message_cache)
|
||||
return False
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
|
||||
logger.info(f"在 {self.session_id} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
logger.info(f"在 {self.session_id} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for situation, style in learnt_expressions:
|
||||
before_upsert_result = await self._get_runtime_manager().invoke_hook(
|
||||
"expression.learn.before_upsert",
|
||||
@@ -386,19 +377,25 @@ class ExpressionLearner:
|
||||
style=style,
|
||||
)
|
||||
if before_upsert_result.aborted:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 跳过: situation={situation!r}")
|
||||
continue
|
||||
|
||||
upsert_kwargs = before_upsert_result.kwargs
|
||||
situation = str(upsert_kwargs.get("situation", situation) or "").strip()
|
||||
style = str(upsert_kwargs.get("style", style) or "").strip()
|
||||
if not situation or not style:
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
|
||||
logger.info(f"{self.session_id} 的表达方式写入被 Hook 清空,已跳过")
|
||||
continue
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
# ====== 黑话相关 ======
|
||||
def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None) -> List[Tuple[str, str]]:
|
||||
self._last_processed_index = len(message_cache)
|
||||
return True
|
||||
|
||||
def _check_cached_jargons_in_messages(
|
||||
self,
|
||||
messages: List["SessionMessage"],
|
||||
jargon_miner: Optional["JargonMiner"] = None,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
检查缓存中的 jargon 是否出现在 messages 中
|
||||
|
||||
@@ -418,7 +415,7 @@ class ExpressionLearner:
|
||||
|
||||
matched_entries: List[Tuple[str, str]] = []
|
||||
|
||||
for i, msg in enumerate(self._messages_cache):
|
||||
for i, msg in enumerate(messages):
|
||||
# 跳过机器人自己的消息
|
||||
if is_bot_self(msg.platform, msg.message_info.user_info.user_id):
|
||||
continue
|
||||
@@ -454,7 +451,10 @@ class ExpressionLearner:
|
||||
return matched_entries
|
||||
|
||||
async def _process_jargon_entries(
|
||||
self, jargon_entries: List[Tuple[str, str]], jargon_miner: Optional["JargonMiner"] = None
|
||||
self,
|
||||
jargon_entries: List[Tuple[str, str]],
|
||||
messages: List["SessionMessage"],
|
||||
jargon_miner: Optional["JargonMiner"] = None,
|
||||
):
|
||||
"""
|
||||
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
||||
@@ -463,7 +463,7 @@ class ExpressionLearner:
|
||||
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
||||
jargon_miner: JargonMiner 实例
|
||||
"""
|
||||
if not jargon_entries or not self._messages_cache:
|
||||
if not jargon_entries or not messages:
|
||||
return
|
||||
|
||||
if not jargon_miner:
|
||||
@@ -497,20 +497,20 @@ class ExpressionLearner:
|
||||
|
||||
# build_readable_message 的编号从 1 开始
|
||||
line_index = int(source_id) - 1
|
||||
if line_index < 0 or line_index >= len(self._messages_cache):
|
||||
if line_index < 0 or line_index >= len(messages):
|
||||
logger.warning(f"黑话条目 source_id 超出范围:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
# 检查是否是机器人自己的消息
|
||||
target_msg = self._messages_cache[line_index]
|
||||
target_msg = messages[line_index]
|
||||
if is_bot_self(target_msg.platform, target_msg.message_info.user_info.user_id):
|
||||
logger.info(f"跳过引用机器人自身消息的黑话:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
# 构建上下文段落(取前后各 3 条消息)
|
||||
start_idx = max(0, line_index - 3)
|
||||
end_idx = min(len(self._messages_cache), line_index + 4)
|
||||
context_msgs = self._messages_cache[start_idx:end_idx]
|
||||
end_idx = min(len(messages), line_index + 4)
|
||||
context_msgs = messages[start_idx:end_idx]
|
||||
|
||||
context_paragraph = "\n".join(
|
||||
[f"[{i + 1}] {msg.processed_plain_text or ''}" for i, msg in enumerate(context_msgs)]
|
||||
@@ -529,7 +529,11 @@ class ExpressionLearner:
|
||||
logger.info(f"成功处理 {len(entries)} 个黑话条目")
|
||||
|
||||
# ====== 过滤方法 ======
|
||||
def _filter_expressions(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str]]:
|
||||
def _filter_expressions(
|
||||
self,
|
||||
expressions: List[Tuple[str, str, str]],
|
||||
messages: List["SessionMessage"],
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
过滤表达方式,移除不符合条件的条目
|
||||
|
||||
@@ -558,10 +562,10 @@ class ExpressionLearner:
|
||||
if not source_id_str.isdigit():
|
||||
continue # 无效的来源行编号,跳过
|
||||
line_index = int(source_id_str) - 1 # build_readable_message 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(self._messages_cache):
|
||||
if line_index < 0 or line_index >= len(messages):
|
||||
continue # 超出范围,跳过
|
||||
# 当前行的原始消息
|
||||
current_msg = self._messages_cache[line_index]
|
||||
current_msg = messages[line_index]
|
||||
# 过滤掉从 bot 自己发言中提取到的表达方式
|
||||
if is_bot_self(current_msg.platform, current_msg.message_info.user_info.user_id):
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user