Files
mai-bot/src/learners/jargon_miner.py

404 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
import random
from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
from .expression_utils import is_single_char_jargon
# Module-level logger for all jargon-mining output.
logger = get_logger("jargon")
# TODO: replace with the new model-invocation API once the LLM refactor lands
# Two LLM endpoints sharing the same model set: one for extraction, one for meaning inference.
llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract")
llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference")
class JargonEntry(TypedDict):
    """A candidate jargon term together with the raw message snippets it was seen in."""

    content: str  # the jargon term itself
    raw_content: Set[str]  # raw message excerpts in which the term appeared
class JargonMeaningEntry(TypedDict):
    """A jargon term paired with its inferred meaning."""

    content: str  # the jargon term
    meaning: str  # inferred meaning text
class JargonMiner:
    """Per-session miner that tracks candidate jargon terms and infers their meanings via LLM calls."""

    def __init__(self, session_id: str, session_name: str) -> None:
        """
        Args:
            session_id: unique id of the chat session this miner serves
            session_name: human-readable session name, used in log output
        """
        self.session_id = session_id
        self.session_name = session_name
        # LRU cache of recently seen jargon terms (keys only; values are unused).
        self.cache_limit = 50
        self.cache: OrderedDict[str, None] = OrderedDict()
        # Jargon-extraction lock: serializes extraction to prevent concurrent runs.
        self._extraction_lock = asyncio.Lock()
def get_cached_jargons(self) -> List[str]:
    """Return every jargon term currently held in the LRU cache, oldest first."""
    return [term for term in self.cache]
async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
    """
    Run the full three-step meaning-inference pipeline for one jargon entry.

    Step 1 infers a meaning from the term plus its usage contexts
    (raw_content). Step 2 infers from the term alone. Step 3 compares the
    two: if they are similar the term is plain language (not jargon); if
    they differ, the context-based meaning from step 1 is kept. The outcome
    is persisted via _modify_jargon_entry.

    Args:
        jargon_obj: the entry to infer; must carry content, raw_content
            (JSON-encoded list), count and meaning.
    """
    content = jargon_obj.content
    # Decode the stored raw_content JSON; tolerate scalars and broken JSON
    # by falling back to a single-item list.
    raw_content_list = []
    if raw_content_str := jargon_obj.raw_content:
        try:
            raw_content_list = json.loads(raw_content_str)
            if not isinstance(raw_content_list, list):
                raw_content_list = [raw_content_list] if raw_content_list else []
        except (json.JSONDecodeError, TypeError):
            raw_content_list = [raw_content_str] if raw_content_str else []
    if not raw_content_list:
        logger.warning(f"jargon {content} 没有raw_content跳过推断")
        return
    # Current sighting count and the meaning from the previous inference round.
    current_count = jargon_obj.count
    previous_meaning = jargon_obj.meaning
    # At counts 24 and 60, randomly drop half of the context samples (always
    # keeping at least one) to bound the prompt size.
    if current_count in [24, 60] and len(raw_content_list) > 1:
        keep_count = max(1, len(raw_content_list) // 2)
        raw_content_list = random.sample(raw_content_list, keep_count)
        logger.info(
            f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
        )
    # BUGFIX: join AFTER the sampling above. Previously the text was joined
    # before sampling, so the random reduction never affected the prompt.
    raw_content_text = "\n".join(raw_content_list)
    # At counts 24/60/100, feed the previous inference back in as a reference.
    previous_meaning_section = ""
    previous_meaning_instruction = ""
    if current_count in [24, 60, 100] and previous_meaning:
        previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}"
        previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
    # Step 1: inference from content plus usage contexts.
    prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context")
    prompt1_template.add_context("bot_name", global_config.bot.nickname)
    prompt1_template.add_context("content", str(content))
    prompt1_template.add_context("raw_content_list", raw_content_text)
    prompt1_template.add_context("previous_meaning_section", previous_meaning_section)
    prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction)
    prompt1 = await prompt_manager.render_prompt(prompt1_template)
    llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3)
    if not llm_response_1:
        logger.warning(f"jargon {content} 推断1失败无响应")
        return
    # Parse step-1 result.
    inference1 = self._parse_result(llm_response_1)
    if not inference1:
        logger.warning(f"jargon {content} 推断1解析失败")
        return
    no_info = inference1.get("no_info", False)
    meaning1: str = inference1.get("meaning", "").strip()
    if no_info or not meaning1:
        logger.info(f"jargon {content} 推断1表示信息不足无法推断放弃本次推断待下次更新")
        # Record the count of this attempt so the same threshold is not retried.
        jargon_obj.last_inference_count = jargon_obj.count or 0
        try:
            self._modify_jargon_entry(jargon_obj)
        except Exception as e:
            logger.error(f"jargon {content} 推断1更新last_inference_count失败: {e}")
        return
    # Step 2: inference from the term alone (no context).
    prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only")
    prompt2_template.add_context("content", content)
    prompt2 = await prompt_manager.render_prompt(prompt2_template)
    llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3)
    if not llm_response_2:
        logger.warning(f"jargon {content} 推断2失败无响应")
        return
    # Parse step-2 result.
    inference2 = self._parse_result(llm_response_2)
    if not inference2:
        logger.warning(f"jargon {content} 推断2解析失败")
        return
    if global_config.debug.show_jargon_prompt:
        logger.info(f"jargon {content} 推断1提示词: {prompt1}")
        logger.info(f"jargon {content} 推断2提示词: {prompt2}")
    # Step 3: compare the two inference results.
    prompt3_template = prompt_manager.get_prompt("jargon_compare_inference")
    prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False))
    prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False))
    prompt3 = await prompt_manager.render_prompt(prompt3_template)
    if global_config.debug.show_jargon_prompt:
        logger.info(f"jargon {content} 比较提示词: {prompt3}")
    llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3)
    if not llm_response_3:
        logger.warning(f"jargon {content} 比较失败:无响应")
        return
    comparison_result = self._parse_result(llm_response_3)
    if not comparison_result:
        logger.warning(f"jargon {content} 比较解析失败")
        return
    is_similar = comparison_result.get("is_similar", False)
    # Similar results mean plain language; differing results mean the context
    # gives the term a special meaning, i.e. it IS jargon.
    is_jargon = not is_similar
    # Update the record fields before persisting.
    jargon_obj.is_jargon = is_jargon
    jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
    # Remember the count of this inference so a restart does not re-run it.
    jargon_obj.last_inference_count = jargon_obj.count or 0
    # count >= 100 finishes the pipeline; no further inference for this term.
    if (jargon_obj.count or 0) >= 100:
        jargon_obj.is_complete = True
    try:
        self._modify_jargon_entry(jargon_obj)
    except Exception as e:
        logger.error(f"jargon {content} 推断结果更新失败: {e}")
    logger.debug(
        f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
    )
    # Always emit a human-readable summary of the outcome.
    if is_jargon:
        meaning = jargon_obj.meaning or "无详细说明"
        is_global = jargon_obj.is_global  # global jargon logs without a session prefix
        if is_global:
            logger.info(f"[黑话]{content}的含义是 {meaning}")
        else:
            logger.info(f"[{self.session_name}]{content}的含义是 {meaning}")
    else:
        logger.info(f"[{self.session_name}]{content} 不是黑话")
async def process_extracted_entries(
    self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
):
    """
    Process jargon entries already extracted upstream (routed from expression_learner).

    Entries are de-duplicated by content, then each is either merged into a
    matching DB record (count/raw_content updated, meaning inference possibly
    scheduled) or inserted as a new record bound to this session.

    Args:
        entries: jargon entries to process
        person_name_filter: optional predicate; entries whose content matches
            (i.e. contain a person's name) are dropped
    """
    if not entries:
        return
    # Merge duplicates: union the raw_content sets of identical content strings.
    merged_entries: Dict[str, JargonEntry] = {}
    for entry in entries:
        content = entry["content"].strip()
        if not content:
            # Robustness: an empty term after stripping carries no information.
            continue
        if person_name_filter and person_name_filter(content):
            logger.info(f"条目 '{content}' 包含人物名称,已过滤")
            continue
        raw_list = entry["raw_content"] or set()
        if content in merged_entries:
            merged_entries[content]["raw_content"].update(raw_list)
        else:
            merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
    uniq_entries: List[JargonEntry] = list(merged_entries.values())
    saved = 0
    updated = 0
    for entry in uniq_entries:
        content = entry["content"]
        raw_content_set = entry["raw_content"]
        try:
            with get_db_session() as session:
                jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
        except Exception as e:
            logger.error(f"查询黑话 '{content}' 失败: {e}")
            continue
        # Find a record matching this session (or any record in all-global mode).
        matched_jargon: Optional[Jargon] = None
        for item in jargon_items:
            if global_config.expression.all_global_jargon:
                # all_global enabled: any record with matching content qualifies.
                matched_jargon = item
                break
            # Otherwise the record's session_id_dict must contain this session.
            if item.session_id_dict:
                try:
                    session_id_dict = json.loads(item.session_id_dict)
                    if self.session_id in session_id_dict:
                        matched_jargon = item
                        break
                except Exception as e:
                    logger.error(f"解析Jargon id={item.id} session_id_list失败: {e}")
                    continue
        if matched_jargon:
            # Existing record: bump count and merge raw_content.
            self._update_jargon(matched_jargon, raw_content_set)
            if self._should_infer_meaning(matched_jargon):
                # BUGFIX: keep a strong reference to the task. asyncio keeps
                # only a weak reference to tasks, so a fire-and-forget task
                # can be garbage-collected before it ever runs.
                pending = getattr(self, "_inference_tasks", None)
                if pending is None:
                    pending = set()
                    self._inference_tasks = pending
                task = asyncio.create_task(self._infer_meaning_by_id(matched_jargon.id))  # type: ignore
                pending.add(task)
                task.add_done_callback(pending.discard)
            updated += 1
        else:
            # No match found: create a fresh record bound to this session.
            is_global_new = global_config.expression.all_global_jargon
            session_dict_str = json.dumps({self.session_id: 1})
            new_jargon = Jargon(
                content=content,
                raw_content=json.dumps(list(raw_content_set), ensure_ascii=False),
                session_id_dict=session_dict_str,
                is_global=is_global_new,
                count=1,
                meaning="",
            )
            try:
                with get_db_session() as session:
                    session.add(new_jargon)
                    session.flush()
                saved += 1
                self._add_to_cache(content)
            except Exception as e:
                logger.error(f"保存新黑话 '{content}' 失败: {e}")
                continue
    # Always surface what was extracted, even when nothing was written.
    if uniq_entries:
        jargon_list = [entry["content"] for entry in uniq_entries]
        jargon_str = ",".join(jargon_list)
        logger.info(f"[{self.session_name}]疑似黑话: {jargon_str}")
    if saved or updated:
        logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated}session_id={self.session_id}")
def _add_to_cache(self, content: str):
    """Insert a jargon term into the LRU cache (most recently used at the end).

    Single-character jargon is ignored; when the cache grows past
    cache_limit the least-recently-used term is evicted.
    """
    term = content.strip()
    # Single-character terms are filtered out entirely.
    if is_single_char_jargon(term):
        return
    if term in self.cache:
        # Known term: just refresh its recency.
        self.cache.move_to_end(term)
        return
    # New term: append, then evict the stalest entry if over capacity.
    self.cache[term] = None
    if len(self.cache) > self.cache_limit:
        removed_content, _ = self.cache.popitem(last=False)
        logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")
def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
    """
    Merge a new sighting into an existing Jargon record: bump the count,
    union the raw_content samples, and credit this session in
    session_id_dict. Persists the record in its own DB session; persistence
    failures are logged, not raised.
    """
    # FIX: None-safe increment — count is treated as possibly None everywhere
    # else in this module (`jargon_obj.count or 0`), so mirror that here
    # instead of crashing on `None += 1`.
    db_jargon.count = (db_jargon.count or 0) + 1
    existing_raw_content: List[str] = []
    if db_jargon.raw_content:
        try:
            existing_raw_content = json.loads(db_jargon.raw_content)
        except Exception:
            # Corrupt stored JSON: start over from the new samples only.
            existing_raw_content = []
    # Union and de-duplicate the context samples.
    merged_list = list(set(existing_raw_content).union(raw_content_set))
    db_jargon.raw_content = json.dumps(merged_list, ensure_ascii=False)
    # Increment this session's sighting counter.
    session_id_dict: Dict[str, int] = json.loads(db_jargon.session_id_dict)
    session_id_dict[self.session_id] = session_id_dict.get(self.session_id, 0) + 1
    db_jargon.session_id_dict = json.dumps(session_id_dict)
    # In all-global mode every record is forced to is_global=True.
    if global_config.expression.all_global_jargon:
        db_jargon.is_global = True
    try:
        with get_db_session() as session:
            session.add(db_jargon)
    except Exception as e:
        logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")
def _parse_result(self, response: str) -> Optional[Dict[str, str]]:
    """Parse an LLM JSON response into a dict, attempting repair on malformed output.

    Returns None when the text cannot be parsed even after repair, or when
    the parsed value is not a dict.
    """
    text = response.strip()
    try:
        parsed = json.loads(text)
    except Exception:
        # First parse failed: try to repair the JSON, then parse again.
        try:
            parsed = json.loads(repair_json(text))
        except Exception as e2:
            logger.error(f"推断结果解析失败: {e2}")
            return None
    if isinstance(parsed, dict):
        return parsed
    # Parsed fine but is not an object (e.g. a list or scalar).
    logger.warning("推断结果格式错误")
    return None
def _modify_jargon_entry(self, jargon_obj: MaiJargon) -> None:
    """Persist the inference fields of *jargon_obj* onto its DB row.

    Writes is_jargon, meaning, last_inference_count and is_complete. Silently
    does nothing when no row with the given id exists.

    Raises:
        ValueError: if jargon_obj has no item_id to locate the row by.
    """
    with get_db_session() as session:
        if not jargon_obj.item_id:
            raise ValueError("jargon_obj must have item_id to update")
        stmt = select(Jargon).filter_by(id=jargon_obj.item_id).limit(1)
        record = session.exec(stmt).first()
        if record is None:
            return
        record.is_jargon = jargon_obj.is_jargon
        record.meaning = jargon_obj.meaning
        record.last_inference_count = jargon_obj.last_inference_count
        record.is_complete = jargon_obj.is_complete
        session.add(record)
def _should_infer_meaning(self, jargon_obj: Jargon) -> bool:
    """
    Decide whether meaning inference should run for this record.

    Inference fires when count reaches one of the thresholds
    2, 4, 8, 12, 24, 60, 100 (the old docstring claimed 3,6,10,20,40,60,100,
    which contradicted the implementation), and only when count has grown
    beyond last_inference_count, so a restart never re-triggers the same
    threshold. Records flagged is_complete never re-infer.
    """
    # Finished records are never re-inferred.
    if jargon_obj.is_complete:
        return False
    count = jargon_obj.count or 0
    last_done = jargon_obj.last_inference_count or 0
    # Inference thresholds, ascending.
    thresholds = (2, 4, 8, 12, 24, 60, 100)
    # No inference before the first threshold, and none without new sightings.
    if count < thresholds[0] or count <= last_done:
        return False
    # Fire at the first threshold not yet covered by a previous inference run.
    for threshold in thresholds:
        if threshold > last_done:
            return count >= threshold
    # last_done is already past the final threshold: stop inferring.
    return False
async def _infer_meaning_by_id(self, jargon_id: int):
    """Load the Jargon row with *jargon_id* and run meaning inference on it.

    DB lookup failures are logged and swallowed; a missing row is a no-op.
    """
    loaded: Optional[MaiJargon] = None
    try:
        with get_db_session() as session:
            stmt = select(Jargon).filter_by(id=jargon_id).limit(1)
            row = session.exec(stmt).first()
            if row:
                loaded = MaiJargon.from_db_instance(row)
    except Exception as e:
        logger.error(f"查询Jargon id={jargon_id}失败: {e}")
        return
    if loaded:
        await self.infer_meaning(loaded)