Files
mai-bot/src/learners/expression_learner.py

505 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from src.chat.utils.utils import is_bot_self
from src.common.utils.utils_message import MessageUtils
from .expression_utils import check_expression_suitability, parse_expression_response
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from .jargon_miner import JargonMiner, JargonEntry
logger = get_logger("expressor")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner")
summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary")
check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check")
class ExpressionLearner:
def __init__(self, session_id: str) -> None:
self.session_id = session_id
# 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock()
# 消息缓存
self._messages_cache: List["SessionMessage"] = []
def add_messages(self, messages: List["SessionMessage"]) -> None:
"""添加消息到缓存"""
self._messages_cache.extend(messages)
def get_cache_size(self) -> int:
"""获取当前消息缓存的大小"""
return len(self._messages_cache)
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
"""学习主流程"""
if not self._messages_cache:
logger.debug("没有消息可供学习,跳过学习过程")
return
# 构建可读消息
readable_message, _, _ = await MessageUtils.build_readable_message(
self._messages_cache,
anonymize=True,
show_lineno=True,
extract_pictures=True,
replace_bot_name=True,
target_bot_name="SELF",
)
# 准备提示词
prompt_template = prompt_manager.get_prompt("learn_style")
prompt_template.add_context("bot_name", global_config.bot.nickname)
prompt_template.add_context("chat_str", readable_message)
prompt = await prompt_manager.render_prompt(prompt_template)
# 调用 LLM 学习表达方式
try:
response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错:{e}")
return
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
expressions: List[Tuple[str, str, str]]
jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = parse_expression_response(response)
# 从缓存中检查 jargon 是否出现在 messages 中
if cached_jargon_entries := self._check_cached_jargons_in_messages(jargon_miner):
# 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
existing_contents = {content for content, _ in jargon_entries}
for content, source_id in cached_jargon_entries:
if content not in existing_contents:
jargon_entries.append((content, source_id))
existing_contents.add(content)
logger.info(f"从缓存中检查到黑话:{content}")
# 检查表达方式数量,如果超过 20 个则放弃本次表达学习
if len(expressions) > 20:
logger.info(f"表达方式提取数量超过 20 个(实际{len(expressions)}个),放弃本次表达学习")
expressions = []
# 检查黑话数量,如果超过 30 个则放弃本次黑话学习
if len(jargon_entries) > 30:
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
# TODO: 检测是否开启了
if jargon_entries:
await self._process_jargon_entries(jargon_entries, jargon_miner)
# 如果没有表达方式,直接返回
if not expressions:
logger.info("解析后没有可用的表达方式")
return
logger.info(f"学习的 expressions: {expressions}")
logger.info(f"学习的 jargon_entries: {jargon_entries}")
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
learnt_expressions = self._filter_expressions(expressions)
if not learnt_expressions:
logger.info("没有学习到表达风格")
return
# 展示学到的表达方式
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
logger.info(f"{self.session_id} 学习到表达风格:\n{learnt_expressions_str}")
# 存储到数据库 Expression 表
for situation, style in learnt_expressions:
await self._upsert_expression_to_db(situation, style)
# ====== 黑话相关 ======
def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None) -> List[Tuple[str, str]]:
"""
检查缓存中的 jargon 是否出现在 messages 中
Args:
jargon_miner: JargonMiner 实例,用于获取缓存的黑话
Returns:
List[Tuple[str, str]]: 匹配到的黑话条目列表,每个元素是 (content, source_id)
"""
if not jargon_miner:
return []
# 获取缓存的所有 jargon 实例
cached_jargons = jargon_miner.get_cached_jargons()
if not cached_jargons:
return []
matched_entries: List[Tuple[str, str]] = []
for i, msg in enumerate(self._messages_cache):
# 跳过机器人自己的消息
if is_bot_self(msg.platform, msg.message_info.user_info.user_id):
continue
# 获取消息文本
msg_text = (msg.processed_plain_text or "").strip()
if not msg_text:
continue
# 检查每个缓存中的 jargon 是否出现在消息文本中
for jargon in cached_jargons:
if not jargon or not jargon.strip():
continue
jargon_content = jargon.strip()
# 使用正则匹配,考虑单词边界(类似 jargon_explainer 中的逻辑)
pattern = re.escape(jargon_content)
# 对于中文,使用更宽松的匹配;对于英文/数字,使用单词边界
if re.search(r"[\u4e00-\u9fff]", jargon_content):
# 包含中文,使用更宽松的匹配
search_pattern = pattern
else:
# 纯英文/数字,使用单词边界
search_pattern = r"\b" + pattern + r"\b"
if re.search(search_pattern, msg_text, re.IGNORECASE):
# 找到匹配构建条目source_id 从 1 开始,因为 build_readable_message 的编号从 1 开始)
source_id = str(i + 1)
matched_entries.append((jargon_content, source_id))
return matched_entries
async def _process_jargon_entries(
self, jargon_entries: List[Tuple[str, str]], jargon_miner: Optional["JargonMiner"] = None
):
"""
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
Args:
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
jargon_miner: JargonMiner 实例
"""
if not jargon_entries or not self._messages_cache:
return
if not jargon_miner:
logger.warning("缺少 JargonMiner 实例,无法处理黑话条目")
return
# 构建黑话条目格式
entries: List["JargonEntry"] = []
for content, source_id in jargon_entries:
content = content.strip()
if not content:
continue
# 过滤掉包含 SELF 的黑话,不学习
if "SELF" in content:
logger.info(f"跳过包含 SELF 的黑话:{content}")
continue
# TODO: 多平台兼容
# 检查是否包含机器人名称
bot_nickname = global_config.bot.nickname
if bot_nickname and bot_nickname in content:
logger.info(f"跳过包含机器人昵称的黑话:{content}")
continue
# 解析 source_id
if not source_id.isdigit():
logger.warning(f"黑话条目 source_id 无效content={content}, source_id={source_id}")
continue
# build_readable_message 的编号从 1 开始
line_index = int(source_id) - 1
if line_index < 0 or line_index >= len(self._messages_cache):
logger.warning(f"黑话条目 source_id 超出范围content={content}, source_id={source_id}")
continue
# 检查是否是机器人自己的消息
target_msg = self._messages_cache[line_index]
if is_bot_self(target_msg.platform, target_msg.message_info.user_info.user_id):
logger.info(f"跳过引用机器人自身消息的黑话content={content}, source_id={source_id}")
continue
# 构建上下文段落(取前后各 3 条消息)
start_idx = max(0, line_index - 3)
end_idx = min(len(self._messages_cache), line_index + 4)
context_msgs = self._messages_cache[start_idx:end_idx]
context_paragraph = "\n".join(
[f"[{i + 1}] {msg.processed_plain_text or ''}" for i, msg in enumerate(context_msgs)]
)
if not context_paragraph:
logger.warning(f"黑话条目上下文为空content={content}, source_id={source_id}")
continue
entries.append({"content": content, "raw_content": {context_paragraph}}) # type: ignore
if not entries:
return
await jargon_miner.process_extracted_entries(entries)
logger.info(f"成功处理 {len(entries)} 个黑话条目")
# ====== 过滤方法 ======
def _filter_expressions(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str]]:
"""
过滤表达方式,移除不符合条件的条目
Args:
expressions: 表达方式列表,每个元素是 (situation, style, source_id)
Returns:
过滤后的表达方式列表,每个元素是 (situation, style)
"""
filtered_expressions: List[Tuple[str, str]] = []
# 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
# TODO: 完善这里的机器人名称检测逻辑(考虑别名、不同平台的名称等)
banned_names: set[str] = set()
bot_nickname = global_config.bot.nickname
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
if alias_stripped := alias.strip():
banned_names.add(alias_stripped)
banned_casefold = {name.casefold() for name in banned_names if name}
for situation, style, source_id in expressions:
source_id_str = source_id.strip()
if not source_id_str.isdigit():
continue # 无效的来源行编号,跳过
line_index = int(source_id_str) - 1 # build_readable_message 的编号从 1 开始
if line_index < 0 or line_index >= len(self._messages_cache):
continue # 超出范围,跳过
# 当前行的原始消息
current_msg = self._messages_cache[line_index]
# 过滤掉从 bot 自己发言中提取到的表达方式
if is_bot_self(current_msg.platform, current_msg.message_info.user_info.user_id):
continue
# 过滤掉无上下文的表达方式
context = (current_msg.processed_plain_text or "").strip()
if not context:
continue
# 过滤掉包含 SELF 的内容(不学习)
if "SELF" in situation or "SELF" in style or "SELF" in context:
logger.info(f"跳过包含 SELF 的表达方式situation={situation}, style={style}, source_id={source_id}")
continue
# 过滤掉 style 与机器人名称/昵称重复的表达
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() in banned_casefold:
logger.debug(
f"跳过 style 与机器人名称重复的表达方式situation={situation}, style={style}, source_id={source_id}"
)
continue
# 过滤掉包含 "[表情" 的内容
if "[表情包" in situation or "[表情包" in style or "[表情包" in context:
logger.info(f"跳过包含表情标记的表达方式situation={situation}, style={style}, source_id={source_id}")
continue
# 过滤掉包含 "[图片" 的内容
if "[图片" in situation or "[图片" in style or "[图片" in context:
logger.info(f"跳过包含图片标记的表达方式situation={situation}, style={style}, source_id={source_id}")
continue
filtered_expressions.append((situation, style))
return filtered_expressions
# ====== DB 操作相关 ======
async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
"""将表达方式写入数据库,存在时更新,不存在时新增。
Args:
situation: 表达方式对应的使用情景。
style: 表达方式风格。
"""
expr, similarity = self._find_similar_expression(situation) or (None, 0)
if expr:
# 根据相似度决定是否使用 LLM 总结
# 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结
use_llm_summary = similarity < 1.0
await self._update_existing_expression(expr, situation, use_llm_summary=use_llm_summary)
return
# 没有找到匹配的记录,创建新记录
self._create_expression(situation, style)
def _create_expression(self, situation: str, style: str) -> None:
"""创建新的表达方式记录。
Args:
situation: 表达方式对应的使用情景。
style: 表达方式风格。
"""
content_list = [situation]
try:
with get_db_session() as db:
new_expr = Expression(
situation=situation,
style=style,
content_list=json.dumps(content_list),
count=1,
session_id=self.session_id,
last_active_time=datetime.now(),
)
db.add(new_expr)
db.flush()
except Exception as e:
logger.error(f"创建表达方式失败: {e}")
async def _update_existing_expression(self, expr: "MaiExpression", situation: str, use_llm_summary: bool = True):
expr.content.append(situation)
expr.count += 1
expr.checked = False # count 增加时重置 checked 为 False
expr.last_active_time = datetime.now()
if use_llm_summary:
# 相似匹配时,使用 LLM 重新组合 situation
new_situation = await self._compose_situation_text(expr.content)
if new_situation:
expr.situation = new_situation
try:
with get_db_session() as session:
if expr.item_id is None:
raise ValueError("表达方式对象缺少 item_id无法更新数据库记录")
statement = select(Expression).filter_by(id=expr.item_id).limit(1)
if db_expr := session.exec(statement).first():
db_expr.content_list = json.dumps(expr.content)
db_expr.count = expr.count
db_expr.checked = expr.checked
db_expr.last_active_time = expr.last_active_time
db_expr.situation = expr.situation # 更新 situation
session.add(db_expr)
else:
logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新")
except Exception as e:
logger.error(f"更新表达方式失败: {e}")
# count 增加后,立即进行一次检查
await self._check_expression(expr)
# ====== 概括方法 ======
async def _compose_situation_text(self, content_list: List[str]) -> Optional[str]:
texts = [c.strip() for c in content_list if c.strip()]
if not texts:
return None
description = "\n".join(f"- {s}" for s in texts[-10:]) # 只取最近10条进行概括
prompt = (
"请阅读以下多个聊天情境描述并将它们概括成一句简短的话长度不超过20个字保留共同特点\n"
f"{description}\n"
"只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
if summary := summary.strip():
return summary
except Exception as e:
logger.error(f"使用 LLM 生成表达方式概括失败: {e}")
return None
async def _check_expression(self, expr: "MaiExpression"):
"""
检查表达方式(在 count 增加后调用)
Args:
expr (MaiExpression): 要检查的表达方式对象
"""
if not global_config.expression.expression_self_reflect:
logger.debug("表达方式自我反思功能未启用,跳过检查")
return
suitable, reason, error = await check_expression_suitability(expr.situation, expr.style)
if error:
logger.error(f"检查表达方式时发生错误: {error}")
return
expr.checked = True
expr.rejected = not suitable
try:
with get_db_session() as session:
statement = select(Expression).filter_by(id=expr.item_id).limit(1)
if db_expr := session.exec(statement).first():
db_expr.checked = expr.checked
db_expr.rejected = expr.rejected
session.add(db_expr)
else:
logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新检查结果")
except Exception as e:
logger.error(f"更新表达方式检查结果失败: {e}")
status = "通过" if suitable else "不通过"
logger.info(
f"表达方式检查完成 [ID: {expr.item_id}] - {status} | "
f"Situation: {expr.situation[:30]}... | "
f"Style: {expr.style[:30]}... | "
f"Reason: {reason[:50] if reason else ''}..."
)
def _find_similar_expression(
self, situation: str, similarity_threshold: float = 0.75
) -> Optional[Tuple[MaiExpression, float]]:
"""在数据库中查找相似的表达方式。
Args:
situation: 当前待匹配的情景描述。
similarity_threshold: 认定为相似表达方式的最低相似度阈值。
Returns:
Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
``(表达方式对象, 相似度)``;否则返回 ``None``。
"""
try:
with get_db_session(auto_commit=False) as session:
statement = select(Expression).filter_by(session_id=self.session_id)
expressions = session.exec(statement).all()
best_match: Optional[MaiExpression] = None
best_similarity = 0.0
for db_expression in expressions:
expression = MaiExpression.from_db_instance(db_expression)
candidate_situations = [expression.situation, *expression.content]
for candidate_situation in candidate_situations:
normalized_candidate_situation = candidate_situation.strip()
if not normalized_candidate_situation:
continue
similarity = difflib.SequenceMatcher(
None,
situation,
normalized_candidate_situation,
).ratio()
if similarity > similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expression
if best_match:
logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
return best_match, best_similarity
except Exception as e:
logger.error(f"查找相似表达方式失败: {e}")
return None