404 lines
17 KiB
Python
404 lines
17 KiB
Python
from collections import OrderedDict
|
||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||
|
||
import asyncio
|
||
import json
|
||
import random
|
||
|
||
from json_repair import repair_json
|
||
from sqlmodel import select
|
||
|
||
from src.common.data_models.jargon_data_model import MaiJargon
|
||
from src.common.database.database import get_db_session
|
||
from src.common.database.database_model import Jargon
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config, model_config
|
||
from src.llm_models.utils_model import LLMRequest
|
||
from src.prompt.prompt_manager import prompt_manager
|
||
|
||
from .expression_utils import is_single_char_jargon
|
||
|
||
logger = get_logger("jargon")

# TODO: switch to the new model invocation API once the LLM layer refactor is finished.
# Both clients are built from the "utils" model set; only request_type differs.
llm_extract = LLMRequest(
    model_set=model_config.model_task_config.utils,
    request_type="jargon.extract",
)
llm_inference = LLMRequest(
    model_set=model_config.model_task_config.utils,
    request_type="jargon.inference",
)
|
||
|
||
|
||
class JargonEntry(TypedDict, total=True):
    """A jargon candidate handed to ``JargonMiner.process_extracted_entries``."""

    # The candidate term itself (stripped before it is matched against the database).
    content: str
    # Distinct raw context strings the term was seen in; merged into the stored raw_content.
    raw_content: Set[str]
|
||
|
||
|
||
class JargonMeaningEntry(TypedDict, total=True):
    """A jargon term paired with a meaning string.

    NOTE(review): not referenced within this module; presumably consumed elsewhere — confirm.
    """

    # The jargon term.
    content: str
    # The meaning text associated with the term.
    meaning: str
|
||
|
||
|
||
class JargonMiner:
    """Mine jargon for one chat session and infer what each term means.

    ``process_extracted_entries`` merges extracted candidates into the ``Jargon``
    table. Whenever a record's ``count`` reaches one of ``INFERENCE_THRESHOLDS``,
    a background task runs ``infer_meaning``. That method uses three LLM calls
    to decide whether the term is jargon, then stores the verdict.
    """

    # Counts at which meaning inference is (re)run; see _should_infer_meaning.
    INFERENCE_THRESHOLDS = (2, 4, 8, 12, 24, 60, 100)
    # At these counts, about half of the stored raw_content examples are randomly
    # dropped before the context prompt is built.
    RAW_CONTENT_RESAMPLE_COUNTS = (24, 60)
    # At these counts the previously inferred meaning is passed to the prompt as a reference.
    PREVIOUS_MEANING_COUNTS = (24, 60, 100)
    # Once count reaches this value the record is marked complete and never re-inferred.
    COMPLETE_COUNT = 100

    def __init__(self, session_id: str, session_name: str) -> None:
        self.session_id = session_id
        self.session_name = session_name

        # Cache of recently saved jargon contents, oldest first (see _add_to_cache).
        self.cache_limit = 50
        self.cache: OrderedDict[str, None] = OrderedDict()
        # Jargon extraction lock, prevents concurrent execution.
        self._extraction_lock = asyncio.Lock()
        # Strong references to fire-and-forget inference tasks: the event loop only
        # keeps weak references, so an unreferenced task can be garbage-collected.
        self._background_tasks: Set[asyncio.Task] = set()
        # Ids of jargon records whose inference task is still running; prevents
        # scheduling a duplicate inference while count keeps increasing.
        self._inferring_ids: Set[int] = set()

    def get_cached_jargons(self) -> List[str]:
        """Return all cached jargon contents, least recently added/refreshed first."""
        return list(self.cache.keys())

    async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
        """Decide whether ``jargon_obj`` is jargon and, if so, infer its meaning.

        Three LLM calls are made:

        1. Infer the meaning from the stored usage examples (``raw_content``).
        2. Infer the meaning from the bare ``content`` alone.
        3. Ask the model whether the two inferences are similar.

        The outcome of step 3 decides the verdict:

        - Similar results: the term is recorded as not jargon.
        - Different results: the term is recorded as jargon, with the meaning
          from step 1.

        ``last_inference_count`` is always updated. ``is_complete`` is set once
        ``count >= COMPLETE_COUNT``. Results are written via ``_modify_jargon_entry``.

        If step 1 reports ``no_info`` or an empty meaning, only
        ``last_inference_count`` is persisted. If any step gets no response or
        an unparseable one, the method returns without writing anything, so the
        next count increment retries. Exceptions raised by the LLM or prompt
        layer propagate to the caller.
        """
        content = jargon_obj.content

        raw_content_list = self._parse_raw_content(jargon_obj.raw_content)
        if not raw_content_list:
            logger.warning(f"jargon {content} 没有raw_content,跳过推断")
            return

        current_count = jargon_obj.count or 0
        previous_meaning = jargon_obj.meaning

        # At selected counts, randomly keep about half of the examples (at least one).
        # This must happen BEFORE the examples are joined into the prompt text;
        # otherwise the sampling has no effect on what the model sees.
        if current_count in self.RAW_CONTENT_RESAMPLE_COUNTS and len(raw_content_list) > 1:
            keep_count = max(1, len(raw_content_list) // 2)
            raw_content_list = random.sample(raw_content_list, keep_count)
            logger.info(
                f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
            )
        raw_content_text = "\n".join(raw_content_list)

        # At selected counts, include the previous meaning as a reference for the model.
        previous_meaning_section = ""
        previous_meaning_instruction = ""
        if current_count in self.PREVIOUS_MEANING_COUNTS and previous_meaning:
            previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}"
            previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"

        # Step 1: infer from raw_content + content.
        prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context")
        prompt1_template.add_context("bot_name", global_config.bot.nickname)
        prompt1_template.add_context("content", str(content))
        prompt1_template.add_context("raw_content_list", raw_content_text)
        prompt1_template.add_context("previous_meaning_section", previous_meaning_section)
        prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction)
        prompt1 = await prompt_manager.render_prompt(prompt1_template)
        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 推断1提示词: {prompt1}")

        inference1 = await self._request_json_result(
            prompt1,
            f"jargon {content} 推断1失败:无响应",
            f"jargon {content} 推断1解析失败",
        )
        if not inference1:
            return

        meaning1 = self._as_text(inference1.get("meaning"))
        if self._as_bool(inference1.get("no_info", False)) or not meaning1:
            logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新")
            # Record the count anyway so the same threshold is not retried.
            jargon_obj.last_inference_count = jargon_obj.count or 0
            try:
                self._modify_jargon_entry(jargon_obj)
            except Exception as e:
                logger.error(f"jargon {content} 推断1更新last_inference_count失败: {e}")
            return

        # Step 2: infer from content only.
        prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only")
        prompt2_template.add_context("content", content)
        prompt2 = await prompt_manager.render_prompt(prompt2_template)
        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 推断2提示词: {prompt2}")

        inference2 = await self._request_json_result(
            prompt2,
            f"jargon {content} 推断2失败:无响应",
            f"jargon {content} 推断2解析失败",
        )
        if not inference2:
            return

        # Step 3: compare the two inferences.
        prompt3_template = prompt_manager.get_prompt("jargon_compare_inference")
        prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False))
        prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False))
        prompt3 = await prompt_manager.render_prompt(prompt3_template)
        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 比较提示词: {prompt3}")

        comparison_result = await self._request_json_result(
            prompt3,
            f"jargon {content} 比较失败:无响应",
            f"jargon {content} 比较解析失败",
        )
        if not comparison_result:
            return

        # If the context-based and content-only inferences are similar, the term is
        # not jargon; if they differ, it is jargon. The model may return the flag
        # as a string ("false"), so coerce explicitly instead of using truthiness.
        is_similar = self._as_bool(comparison_result.get("is_similar", False))
        is_jargon = not is_similar

        jargon_obj.is_jargon = is_jargon
        jargon_obj.meaning = meaning1 if is_jargon else ""
        # Remember the count this verdict was made at, so restarts do not re-judge it.
        jargon_obj.last_inference_count = jargon_obj.count or 0
        if (jargon_obj.count or 0) >= self.COMPLETE_COUNT:
            jargon_obj.is_complete = True

        try:
            self._modify_jargon_entry(jargon_obj)
        except Exception as e:
            logger.error(f"jargon {content} 推断结果更新失败: {e}")
        logger.debug(
            f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
        )

        # Always emit a human-readable summary of the verdict.
        if is_jargon:
            meaning = jargon_obj.meaning or "无详细说明"
            if jargon_obj.is_global:
                logger.info(f"[黑话]{content}的含义是 {meaning}")
            else:
                logger.info(f"[{self.session_name}]{content}的含义是 {meaning}")
        else:
            logger.info(f"[{self.session_name}]{content} 不是黑话")

    async def process_extracted_entries(
        self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
    ):
        """Merge extracted jargon entries (routed from expression_learner) into the database.

        Processing steps:

        1. Strip entries and drop any that are empty.
        2. Drop entries rejected by ``person_name_filter``.
        3. Merge duplicates by content.
        4. For each unique term, update the matching record or insert a new one.

        Meaning inference is scheduled in the background when an updated
        record crosses an inference threshold.

        Args:
            entries: Jargon entries to process.
            person_name_filter: Optional predicate; entries whose content it
                returns True for (i.e. that contain a person's name) are skipped.
        """
        if not entries:
            return

        merged_entries: Dict[str, JargonEntry] = {}
        for entry in entries:
            content = entry["content"].strip()
            if not content:
                # Whitespace-only content would otherwise create an empty jargon record.
                continue
            if person_name_filter and person_name_filter(content):
                logger.info(f"条目 '{content}' 包含人物名称,已过滤")
                continue
            raw_list = entry["raw_content"] or set()
            if content in merged_entries:
                merged_entries[content]["raw_content"].update(raw_list)
            else:
                merged_entries[content] = {"content": content, "raw_content": set(raw_list)}

        uniq_entries: List[JargonEntry] = list(merged_entries.values())

        saved = 0
        updated = 0
        for entry in uniq_entries:
            content = entry["content"]
            raw_content_set = entry["raw_content"]
            try:
                with get_db_session() as session:
                    jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
            except Exception as e:
                logger.error(f"查询黑话 '{content}' 失败: {e}")
                continue

            matched_jargon = self._find_matching_jargon(jargon_items)
            if matched_jargon:
                # Existing record: bump count, merge raw_content, maybe schedule inference.
                self._update_jargon(matched_jargon, raw_content_set)
                if self._should_infer_meaning(matched_jargon):
                    self._schedule_inference(matched_jargon.id)
                updated += 1
            elif self._save_new_jargon(content, raw_content_set):
                saved += 1

        # Always log what was extracted, as long as anything survived filtering.
        if uniq_entries:
            jargon_str = ",".join(entry["content"] for entry in uniq_entries)
            logger.info(f"[{self.session_name}]疑似黑话: {jargon_str}")

        if saved or updated:
            logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,session_id={self.session_id}")

    def _find_matching_jargon(self, jargon_items: List[Jargon]) -> Optional[Jargon]:
        """Return the record (among rows with the same content) this session should update, or None.

        With ``all_global_jargon`` enabled, the first row is used. Otherwise the
        first row whose session_id_dict contains this session is used.
        """
        if global_config.expression.all_global_jargon:
            return jargon_items[0] if jargon_items else None

        for item in jargon_items:
            if not item.session_id_dict:
                continue
            try:
                session_counts = self._parse_session_id_dict(item.session_id_dict)
            except (ValueError, TypeError) as e:
                logger.error(f"解析Jargon id={item.id} session_id_list失败: {e}")
                continue
            if self.session_id in session_counts:
                return item
        return None

    def _save_new_jargon(self, content: str, raw_content_set: Set[str]) -> bool:
        """Insert a new Jargon row for ``content`` first seen in this session.

        Returns:
            True if the row was saved (and the content cached); False if the write failed.
        """
        new_jargon = Jargon(
            content=content,
            # Sorted so the stored JSON is deterministic across runs.
            raw_content=json.dumps(sorted(raw_content_set, key=str), ensure_ascii=False),
            session_id_dict=json.dumps({self.session_id: 1}),
            is_global=global_config.expression.all_global_jargon,
            count=1,
            meaning="",
        )
        try:
            with get_db_session() as session:
                session.add(new_jargon)
                session.flush()
        except Exception as e:
            logger.error(f"保存新黑话 '{content}' 失败: {e}")
            return False
        self._add_to_cache(content)
        return True

    def _schedule_inference(self, jargon_id: Optional[int]) -> None:
        """Start a background inference task for ``jargon_id`` unless one is already running."""
        if jargon_id is None or jargon_id in self._inferring_ids:
            return
        self._inferring_ids.add(jargon_id)
        task = asyncio.create_task(self._infer_meaning_by_id(jargon_id))
        self._background_tasks.add(task)

        def _on_done(finished: asyncio.Task) -> None:
            # Release the strong reference and the in-flight marker.
            self._background_tasks.discard(finished)
            self._inferring_ids.discard(jargon_id)

        task.add_done_callback(_on_done)

    def _add_to_cache(self, content: str):
        """Add ``content`` to the LRU cache, evicting the oldest entry beyond ``cache_limit``.

        Content for which ``is_single_char_jargon`` is true is not cached.
        """
        content = content.strip()
        if is_single_char_jargon(content):
            return

        if content in self.cache:
            # Already present: move to the end to mark it as most recently used.
            self.cache.move_to_end(content)
        else:
            self.cache[content] = None
            if len(self.cache) > self.cache_limit:
                removed_content, _ = self.cache.popitem(last=False)
                logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")

    def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]):
        """Update an existing jargon record and persist it.

        Steps performed:

        - Increment the record's count and this session's per-session count.
        - Merge ``raw_content_set`` into the stored examples, de-duplicated.
        - With ``all_global_jargon`` enabled, mark the record global.

        Database errors are logged, not raised.
        """
        db_jargon.count = (db_jargon.count or 0) + 1

        existing_raw_content = self._parse_raw_content(db_jargon.raw_content)
        # Order-preserving de-duplication: existing examples first, then new ones (sorted
        # so the stored JSON is deterministic).
        merged_list = list(dict.fromkeys([*existing_raw_content, *sorted(raw_content_set, key=str)]))
        db_jargon.raw_content = json.dumps(merged_list, ensure_ascii=False)

        try:
            session_id_dict = self._parse_session_id_dict(db_jargon.session_id_dict)
        except (ValueError, TypeError) as e:
            # Unreadable data would otherwise abort the whole batch; start a fresh mapping.
            logger.warning(f"解析黑话 '{db_jargon.content}' 的session_id_dict失败,已重置: {e}")
            session_id_dict = {}
        session_id_dict[self.session_id] = session_id_dict.get(self.session_id, 0) + 1
        db_jargon.session_id_dict = json.dumps(session_id_dict)

        # With all_global enabled, make sure the record is flagged as global.
        if global_config.expression.all_global_jargon:
            db_jargon.is_global = True

        try:
            with get_db_session() as session:
                session.add(db_jargon)
        except Exception as e:
            logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")

    @staticmethod
    def _parse_raw_content(raw: Optional[str]) -> List[str]:
        """Decode a stored raw_content JSON string into a list of example strings.

        Decoding rules:

        - A JSON list yields its non-null items, stringified.
        - Any other truthy JSON value becomes a one-item list.
        - Text that is not valid JSON is kept as a single example.
        - Empty or None yields ``[]``.
        """
        if not raw:
            return []
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return [raw if isinstance(raw, str) else str(raw)]
        if isinstance(parsed, list):
            return [item if isinstance(item, str) else str(item) for item in parsed if item is not None]
        return [parsed if isinstance(parsed, str) else str(parsed)] if parsed else []

    @staticmethod
    def _parse_session_id_dict(raw: Optional[str]) -> Dict[str, int]:
        """Decode a stored session_id_dict JSON string into ``{session_id: count}``.

        A JSON list of session ids is converted to a dict with a count of 1 per
        id. NOTE(review): this assumes a legacy ``session_id_list`` format
        (suggested by an existing log message) — confirm. Empty or None yields ``{}``.

        Raises:
            ValueError: if the text is not valid JSON or decodes to neither a dict nor a list.
            TypeError: if ``raw`` is not a str/bytes value.
        """
        if not raw:
            return {}
        parsed = json.loads(raw)
        if isinstance(parsed, dict):
            return parsed
        if isinstance(parsed, list):
            return {str(sid): 1 for sid in parsed}
        raise ValueError(f"unexpected session_id_dict type: {type(parsed).__name__}")

    async def _request_json_result(
        self, prompt: str, no_response_msg: str, parse_fail_msg: str
    ) -> Optional[Dict[str, object]]:
        """Send ``prompt`` to the inference model and parse the reply as a JSON object.

        Logs ``no_response_msg`` or ``parse_fail_msg`` and returns None on an empty
        response or on a reply that does not parse to a non-empty dict.
        """
        response, _ = await llm_inference.generate_response_async(prompt, temperature=0.3)
        if not response:
            logger.warning(no_response_msg)
            return None
        result = self._parse_result(response)
        if not result:
            logger.warning(parse_fail_msg)
            return None
        return result

    def _parse_result(self, response: str) -> Optional[Dict[str, object]]:
        """Parse an LLM reply as a JSON object, falling back to ``repair_json``.

        Returns:
            The decoded dict, or None if the reply cannot be decoded or is not a JSON object.
        """
        try:
            result = json.loads(response.strip())
        except Exception:
            try:
                repaired = repair_json(response.strip())
                result = json.loads(repaired)
            except Exception as e2:
                logger.error(f"推断结果解析失败: {e2}")
                return None
        if not isinstance(result, dict):
            logger.warning("推断结果格式错误")
            return None
        return result

    @staticmethod
    def _as_bool(value: object) -> bool:
        """Interpret an LLM-provided flag, treating strings like "false"/"0" as False."""
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes", "1"}
        return bool(value)

    @staticmethod
    def _as_text(value: object) -> str:
        """Return ``value`` as stripped text; None becomes ""."""
        if value is None:
            return ""
        return value.strip() if isinstance(value, str) else str(value).strip()

    def _modify_jargon_entry(self, jargon_obj: MaiJargon) -> None:
        """Write the inference fields of ``jargon_obj`` back to its Jargon row.

        Updated fields: ``is_jargon``, ``meaning``, ``last_inference_count`` and
        ``is_complete``.

        Raises:
            ValueError: if ``jargon_obj`` has no ``item_id``.
        """
        # Validate before opening a session.
        if not jargon_obj.item_id:
            raise ValueError("jargon_obj must have item_id to update")
        with get_db_session() as session:
            statement = select(Jargon).filter_by(id=jargon_obj.item_id).limit(1)
            db_record = session.exec(statement).first()
            if db_record is None:
                logger.warning(f"未找到Jargon id={jargon_obj.item_id},无法更新推断结果")
                return
            db_record.is_jargon = jargon_obj.is_jargon
            db_record.meaning = jargon_obj.meaning
            db_record.last_inference_count = jargon_obj.last_inference_count
            db_record.is_complete = jargon_obj.is_complete
            session.add(db_record)

    def _should_infer_meaning(self, jargon_obj: Jargon) -> bool:
        """Decide whether meaning inference should run for ``jargon_obj`` now.

        Inference runs when ``count`` reaches the next value of
        ``INFERENCE_THRESHOLDS`` (2, 4, 8, 12, 24, 60, 100) above
        ``last_inference_count``. ``count`` must also exceed
        ``last_inference_count``, so a restart does not repeat a finished
        inference. Inference never runs once ``is_complete`` is set.
        """
        if jargon_obj.is_complete:
            return False

        count = jargon_obj.count or 0
        last_inference = jargon_obj.last_inference_count or 0
        if count < self.INFERENCE_THRESHOLDS[0] or count <= last_inference:
            return False

        next_threshold = next(
            (threshold for threshold in self.INFERENCE_THRESHOLDS if threshold > last_inference),
            None,
        )
        # No threshold left (already past the last one): never infer again.
        return next_threshold is not None and count >= next_threshold

    async def _infer_meaning_by_id(self, jargon_id: int):
        """Reload the Jargon row ``jargon_id`` and run ``infer_meaning`` on it.

        This is the background-task entry point. Failures are logged here
        because nobody awaits the task.
        """
        jargon_obj: Optional[MaiJargon] = None
        try:
            with get_db_session() as session:
                statement = select(Jargon).filter_by(id=jargon_id).limit(1)
                if db_record := session.exec(statement).first():
                    jargon_obj = MaiJargon.from_db_instance(db_record)
        except Exception as e:
            logger.error(f"查询Jargon id={jargon_id}失败: {e}")
            return
        if jargon_obj is None:
            return
        try:
            await self.infer_meaning(jargon_obj)
        except Exception as e:
            logger.error(f"jargon id={jargon_id} 推断过程异常: {e}")
|