merge: sync upstream/r-dev and resolve real conflicts

This commit is contained in:
A-Dawn
2026-03-24 15:36:26 +08:00
114 changed files with 15841 additions and 5236 deletions

View File

@@ -0,0 +1,250 @@
"""
表达方式自动检查定时任务
功能:
1. 定期随机选取指定数量的表达方式
2. 使用 LLM 进行评估
3. 通过评估的rejected=0, checked=1
4. 未通过评估的rejected=1, checked=1
"""
import asyncio
import json
import random
from typing import List
from sqlmodel import select
from src.learners.expression_review_store import get_review_state, set_review_state
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.config.config import global_config
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask
logger = get_logger("expression_auto_check_task")
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM prompt used to judge one expression's suitability.

    Args:
        situation: The usage situation / context of the expression.
        style: The expression style text.

    Returns:
        A prompt string instructing the LLM to answer with a JSON verdict.
    """
    # Built-in criteria every expression is judged against.
    criteria = [
        "表达方式或言语风格 是否与使用条件或使用情景 匹配",
        "允许部分语法错误或口头化或缺省出现",
        "表达方式不能太过特指,需要具有泛用性",
        "一般不涉及具体的人名或名称",
    ]
    # Optional user-defined criteria appended from the global configuration.
    extra = global_config.expression.expression_auto_check_custom_criteria
    if extra:
        criteria = criteria + list(extra)
    # Render the criteria as a 1-based numbered list.
    numbered = "\n".join(f"{idx}. {text}" for idx, text in enumerate(criteria, start=1))
    return f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}
请从以下方面进行评估:
{numbered}
请以JSON格式输出评估结果
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因。
请严格按照JSON格式输出不要包含其他内容。"""
# Shared LLM client used for every expression-suitability evaluation in this task.
# NOTE(review): uses the "tool_use" model set — presumably for reliable JSON output; confirm.
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
    """Run a single LLM evaluation for one expression.

    Args:
        situation: The usage situation / context.
        style: The expression style text.

    Returns:
        A ``(suitable, reason, error)`` tuple. On any failure ``suitable`` is
        False and ``error`` carries the error message, otherwise ``error`` is None.
    """
    try:
        evaluation_prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
            prompt=evaluation_prompt, temperature=0.6, max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON verdict; if the raw response is not valid JSON, try to
        # extract a flat JSON object containing the "suitable" field.
        try:
            verdict = json.loads(response)
        except json.JSONDecodeError as decode_err:
            import re

            match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if match is None:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from decode_err
            verdict = json.loads(match.group())
        is_suitable = verdict.get("suitable", False)
        verdict_reason = verdict.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if is_suitable else '不通过'}")
        return is_suitable, verdict_reason, None
    except Exception as e:
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
class ExpressionAutoCheckTask(AsyncTask):
    """Periodic task that auto-reviews learned expressions with an LLM.

    Workflow:
      1. Periodically sample a configured number of unchecked expressions.
      2. Evaluate each one with the LLM judge.
      3. Passing items are stored as rejected=0, checked=1;
         failing items as rejected=1, checked=1.
    """

    def __init__(self):
        # The run interval comes from global config; the per-run count is
        # re-read on every run() so config changes take effect without restart.
        check_interval = global_config.expression.expression_auto_check_interval
        super().__init__(
            task_name="Expression Auto Check Task",
            wait_before_start=60,  # wait 60s after startup before the first check
            run_interval=check_interval,
        )

    async def _select_expressions(self, count: int) -> List[Expression]:
        """Randomly pick up to `count` expressions that are not yet checked.

        Args:
            count: Maximum number of expressions to select.

        Returns:
            The selected Expression rows (empty on error or when none are pending).
        """
        try:
            # Read-only query; auto_commit=False prevents the session exit from
            # committing and expiring the ORM instances we return.
            with get_db_session(auto_commit=False) as session:
                statement = select(Expression)
                all_expressions = session.exec(statement).all()
            # Review state (checked/rejected) lives in local storage, keyed by id,
            # not on the Expression row itself.
            unevaluated_expressions = [expr for expr in all_expressions if not get_review_state(expr.id)["checked"]]
            if not unevaluated_expressions:
                logger.info("没有未检查的表达方式")
                return []
            # Sample without replacement, capped at what is available.
            selected_count = min(count, len(unevaluated_expressions))
            selected = random.sample(unevaluated_expressions, selected_count)
            logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条")
            return selected
        except Exception as e:
            logger.error(f"选择表达方式时出错: {e}")
            return []

    async def _evaluate_expression(self, expression: Expression) -> bool:
        """Evaluate one expression and persist the verdict.

        Args:
            expression: The expression row to evaluate.

        Returns:
            True when the expression passed the LLM check; False otherwise
            (including when persisting the verdict fails).
        """
        suitable, reason, error = await single_expression_check(
            expression.situation,
            expression.style,
        )
        # Persist the verdict: checked is always set; rejected mirrors the result.
        try:
            set_review_state(expression.id, True, not suitable, "ai")
            status = "通过" if suitable else "不通过"
            logger.info(
                f"表达方式评估完成 [ID: {expression.id}] - {status} | "
                f"Situation: {expression.situation}... | "
                f"Style: {expression.style}... | "
                f"Reason: {reason[:50]}..."
            )
            if error:
                logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")
            return suitable
        except Exception as e:
            logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}")
            return False

    async def run(self):
        """Execute one check cycle (invoked periodically by the task manager)."""
        try:
            # Feature gate: skip entirely when self-reflection is disabled.
            if not global_config.expression.expression_self_reflect:
                logger.debug("表达方式自动检查未启用,跳过本次执行")
                return
            check_count = global_config.expression.expression_auto_check_count
            if check_count <= 0:
                logger.warning(f"检查数量配置无效: {check_count},跳过本次执行")
                return
            logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")
            # Pick the batch to review.
            expressions = await self._select_expressions(check_count)
            if not expressions:
                logger.info("没有需要检查的表达方式")
                return
            # Evaluate sequentially, tallying outcomes.
            passed_count = 0
            failed_count = 0
            for i, expression in enumerate(expressions, 1):
                logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
                if await self._evaluate_expression(expression):
                    passed_count += 1
                else:
                    failed_count += 1
                # Throttle to avoid hammering the LLM backend.
                await asyncio.sleep(0.3)
            logger.info(
                f"表达方式自动检查完成: 总计 {len(expressions)} 条,通过 {passed_count} 条,不通过 {failed_count} 条"
            )
        except Exception as e:
            logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)

View File

@@ -0,0 +1,504 @@
from datetime import datetime
from sqlmodel import select
from typing import TYPE_CHECKING, List, Optional, Tuple
import asyncio
import difflib
import json
import re
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.database.database import get_db_session
from src.common.data_models.expression_data_model import MaiExpression
from src.chat.utils.utils import is_bot_self
from src.common.utils.utils_message import MessageUtils
from .expression_utils import check_expression_suitability, parse_expression_response
if TYPE_CHECKING:
from src.chat.message_receive.message import SessionMessage
from .jargon_miner import JargonMiner, JargonEntry
logger = get_logger("expressor")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner")
summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary")
check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check")
class ExpressionLearner:
    """Learns expression styles and jargon from cached session messages.

    Messages are buffered via :meth:`add_messages`; :meth:`learn` renders a
    readable transcript, asks the LLM to extract expressions/jargon, filters
    the results against their source messages, and persists them.
    """

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id
        # Guards learn() against concurrent execution.
        self._learning_lock = asyncio.Lock()
        # Buffered messages awaiting the next learning pass.
        self._messages_cache: List["SessionMessage"] = []

    def add_messages(self, messages: List["SessionMessage"]) -> None:
        """Append messages to the learning cache."""
        self._messages_cache.extend(messages)

    def get_cache_size(self) -> int:
        """Return the number of messages currently cached."""
        return len(self._messages_cache)

    async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
        """Main learning pipeline: extract, filter, and persist expressions/jargon."""
        if not self._messages_cache:
            logger.debug("没有消息可供学习,跳过学习过程")
            return
        # Render the cached messages into an anonymized, line-numbered transcript.
        readable_message, _, _ = await MessageUtils.build_readable_message(
            self._messages_cache,
            anonymize=True,
            show_lineno=True,
            extract_pictures=True,
            replace_bot_name=True,
            target_bot_name="SELF",
        )
        # Prepare the learning prompt.
        prompt_template = prompt_manager.get_prompt("learn_style")
        prompt_template.add_context("bot_name", global_config.bot.nickname)
        prompt_template.add_context("chat_str", readable_message)
        prompt = await prompt_manager.render_prompt(prompt_template)
        # Ask the LLM to extract expression styles / jargon.
        try:
            response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3)
        except Exception as e:
            logger.error(f"学习表达方式失败,模型生成出错:{e}")
            return
        # Parse expressions (situation, style, source line id) and jargon (content, source id).
        expressions: List[Tuple[str, str, str]]
        jargon_entries: List[Tuple[str, str]]  # (content, source_id)
        expressions, jargon_entries = parse_expression_response(response)
        # Merge jargon found directly in the cached messages (dedup by content).
        if cached_jargon_entries := self._check_cached_jargons_in_messages(jargon_miner):
            existing_contents = {content for content, _ in jargon_entries}
            for content, source_id in cached_jargon_entries:
                if content not in existing_contents:
                    jargon_entries.append((content, source_id))
                    existing_contents.add(content)
                    logger.info(f"从缓存中检查到黑话:{content}")
        # Sanity caps: an implausibly large haul usually indicates a bad LLM response.
        if len(expressions) > 20:
            logger.info(f"表达方式提取数量超过 20 个(实际{len(expressions)}个),放弃本次表达学习")
            expressions = []
        if len(jargon_entries) > 30:
            logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
            jargon_entries = []
        # Route jargon to the miner even when no expressions were extracted.
        # TODO: check whether the jargon feature is enabled before processing.
        if jargon_entries:
            await self._process_jargon_entries(jargon_entries, jargon_miner)
        if not expressions:
            logger.info("解析后没有可用的表达方式")
            return
        logger.info(f"学习的 expressions: {expressions}")
        logger.info(f"学习的 jargon_entries: {jargon_entries}")
        # Trace each expression back to its source message and apply filters.
        learnt_expressions = self._filter_expressions(expressions)
        if not learnt_expressions:
            logger.info("没有学习到表达风格")
            return
        learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
        logger.info(f"{self.session_id} 学习到表达风格:\n{learnt_expressions_str}")
        # Upsert each learned expression into the Expression table.
        for situation, style in learnt_expressions:
            await self._upsert_expression_to_db(situation, style)

    # ====== jargon handling ======
    def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None) -> List[Tuple[str, str]]:
        """Scan cached messages for occurrences of already-known (cached) jargon.

        Args:
            jargon_miner: JargonMiner instance providing the cached jargon list.

        Returns:
            List[Tuple[str, str]]: matched (content, source_id) pairs; source_id
            is the 1-based line number used by build_readable_message.
        """
        if not jargon_miner:
            return []
        cached_jargons = jargon_miner.get_cached_jargons()
        if not cached_jargons:
            return []
        matched_entries: List[Tuple[str, str]] = []
        for i, msg in enumerate(self._messages_cache):
            # Never learn jargon from the bot's own messages.
            if is_bot_self(msg.platform, msg.message_info.user_info.user_id):
                continue
            msg_text = (msg.processed_plain_text or "").strip()
            if not msg_text:
                continue
            for jargon in cached_jargons:
                if not jargon or not jargon.strip():
                    continue
                jargon_content = jargon.strip()
                pattern = re.escape(jargon_content)
                # CJK terms have no word boundaries; only pure ASCII terms get \b anchors.
                if re.search(r"[\u4e00-\u9fff]", jargon_content):
                    search_pattern = pattern
                else:
                    search_pattern = r"\b" + pattern + r"\b"
                if re.search(search_pattern, msg_text, re.IGNORECASE):
                    # 1-based to match build_readable_message numbering.
                    source_id = str(i + 1)
                    matched_entries.append((jargon_content, source_id))
        return matched_entries

    async def _process_jargon_entries(
        self, jargon_entries: List[Tuple[str, str]], jargon_miner: Optional["JargonMiner"] = None
    ):
        """Validate extracted jargon entries and hand them to the jargon miner.

        Args:
            jargon_entries: (content, source_id) pairs from the LLM / cache scan.
            jargon_miner: JargonMiner instance that persists the entries.
        """
        if not jargon_entries or not self._messages_cache:
            return
        if not jargon_miner:
            logger.warning("缺少 JargonMiner 实例,无法处理黑话条目")
            return
        entries: List["JargonEntry"] = []
        for content, source_id in jargon_entries:
            content = content.strip()
            if not content:
                continue
            # Skip jargon containing the anonymized bot placeholder.
            if "SELF" in content:
                logger.info(f"跳过包含 SELF 的黑话:{content}")
                continue
            # TODO: multi-platform compatibility for name detection.
            # Skip jargon containing the bot's nickname.
            bot_nickname = global_config.bot.nickname
            if bot_nickname and bot_nickname in content:
                logger.info(f"跳过包含机器人昵称的黑话:{content}")
                continue
            # source_id must be a valid 1-based line number into the cache.
            if not source_id.isdigit():
                logger.warning(f"黑话条目 source_id 无效content={content}, source_id={source_id}")
                continue
            line_index = int(source_id) - 1
            if line_index < 0 or line_index >= len(self._messages_cache):
                logger.warning(f"黑话条目 source_id 超出范围content={content}, source_id={source_id}")
                continue
            # Skip entries that point at the bot's own message.
            target_msg = self._messages_cache[line_index]
            if is_bot_self(target_msg.platform, target_msg.message_info.user_info.user_id):
                logger.info(f"跳过引用机器人自身消息的黑话content={content}, source_id={source_id}")
                continue
            # Context window: up to 3 messages on each side of the source message.
            start_idx = max(0, line_index - 3)
            end_idx = min(len(self._messages_cache), line_index + 4)
            context_msgs = self._messages_cache[start_idx:end_idx]
            context_paragraph = "\n".join(
                [f"[{i + 1}] {msg.processed_plain_text or ''}" for i, msg in enumerate(context_msgs)]
            )
            if not context_paragraph:
                logger.warning(f"黑话条目上下文为空content={content}, source_id={source_id}")
                continue
            # BUGFIX: raw_content was previously wrapped in a set literal
            # ({context_paragraph}); the miner expects the plain context string.
            entries.append({"content": content, "raw_content": context_paragraph})
        if not entries:
            return
        await jargon_miner.process_extracted_entries(entries)
        logger.info(f"成功处理 {len(entries)} 个黑话条目")

    # ====== filtering ======
    def _filter_expressions(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str]]:
        """Filter extracted expressions, dropping entries that fail validation.

        Args:
            expressions: (situation, style, source_id) triples.

        Returns:
            Filtered (situation, style) pairs.
        """
        filtered_expressions: List[Tuple[str, str]] = []
        # Bot names used to drop styles that are just the bot's own name.
        # TODO: improve bot-name detection (aliases, per-platform names).
        banned_names: set[str] = set()
        bot_nickname = global_config.bot.nickname
        if bot_nickname:
            banned_names.add(bot_nickname)
        alias_names = global_config.bot.alias_names or []
        for alias in alias_names:
            if alias_stripped := alias.strip():
                banned_names.add(alias_stripped)
        banned_casefold = {name.casefold() for name in banned_names if name}
        for situation, style, source_id in expressions:
            source_id_str = source_id.strip()
            if not source_id_str.isdigit():
                continue  # invalid source line id
            line_index = int(source_id_str) - 1  # build_readable_message numbers from 1
            if line_index < 0 or line_index >= len(self._messages_cache):
                continue  # out of range
            # Original message the expression was extracted from.
            current_msg = self._messages_cache[line_index]
            # Drop expressions extracted from the bot's own messages.
            if is_bot_self(current_msg.platform, current_msg.message_info.user_info.user_id):
                continue
            # Drop expressions without usable source context.
            context = (current_msg.processed_plain_text or "").strip()
            if not context:
                continue
            # Drop anything containing the anonymized bot placeholder.
            if "SELF" in situation or "SELF" in style or "SELF" in context:
                logger.info(f"跳过包含 SELF 的表达方式situation={situation}, style={style}, source_id={source_id}")
                continue
            # Drop styles that equal a bot name/alias.
            normalized_style = (style or "").strip()
            if normalized_style and normalized_style.casefold() in banned_casefold:
                logger.debug(
                    f"跳过 style 与机器人名称重复的表达方式situation={situation}, style={style}, source_id={source_id}"
                )
                continue
            # Drop entries carrying sticker markers.
            if "[表情包" in situation or "[表情包" in style or "[表情包" in context:
                logger.info(f"跳过包含表情标记的表达方式situation={situation}, style={style}, source_id={source_id}")
                continue
            # Drop entries carrying image markers.
            if "[图片" in situation or "[图片" in style or "[图片" in context:
                logger.info(f"跳过包含图片标记的表达方式situation={situation}, style={style}, source_id={source_id}")
                continue
            filtered_expressions.append((situation, style))
        return filtered_expressions

    # ====== DB operations ======
    async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
        """Write an expression to the database: update when similar, else insert.

        Args:
            situation: Usage situation of the expression.
            style: Expression style text.
        """
        expr, similarity = self._find_similar_expression(situation) or (None, 0)
        if expr:
            # Exact match (similarity == 1.0) skips the LLM re-summary;
            # fuzzy matches re-summarize the accumulated situations.
            use_llm_summary = similarity < 1.0
            await self._update_existing_expression(expr, situation, use_llm_summary=use_llm_summary)
            return
        # No similar record found — create a new one.
        self._create_expression(situation, style)

    def _create_expression(self, situation: str, style: str) -> None:
        """Insert a new expression record.

        Args:
            situation: Usage situation of the expression.
            style: Expression style text.
        """
        content_list = [situation]
        try:
            with get_db_session() as db:
                new_expr = Expression(
                    situation=situation,
                    style=style,
                    content_list=json.dumps(content_list),
                    count=1,
                    session_id=self.session_id,
                    last_active_time=datetime.now(),
                )
                db.add(new_expr)
                db.flush()
        except Exception as e:
            logger.error(f"创建表达方式失败: {e}")

    async def _update_existing_expression(self, expr: "MaiExpression", situation: str, use_llm_summary: bool = True):
        """Merge a new situation into an existing expression and persist it.

        Args:
            expr: The matched in-memory expression object.
            situation: The newly observed situation text.
            use_llm_summary: When True, re-summarize the situation via LLM.
        """
        expr.content.append(situation)
        expr.count += 1
        expr.checked = False  # a count bump invalidates the previous review
        expr.last_active_time = datetime.now()
        if use_llm_summary:
            # Fuzzy match: recompose the situation from the accumulated texts.
            new_situation = await self._compose_situation_text(expr.content)
            if new_situation:
                expr.situation = new_situation
        try:
            with get_db_session() as session:
                if expr.item_id is None:
                    raise ValueError("表达方式对象缺少 item_id无法更新数据库记录")
                statement = select(Expression).filter_by(id=expr.item_id).limit(1)
                if db_expr := session.exec(statement).first():
                    db_expr.content_list = json.dumps(expr.content)
                    db_expr.count = expr.count
                    db_expr.checked = expr.checked
                    db_expr.last_active_time = expr.last_active_time
                    db_expr.situation = expr.situation  # persist the recomposed situation
                    session.add(db_expr)
                else:
                    logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新")
        except Exception as e:
            logger.error(f"更新表达方式失败: {e}")
        # Re-check the expression immediately after the count increase.
        await self._check_expression(expr)

    # ====== summarization ======
    async def _compose_situation_text(self, content_list: List[str]) -> Optional[str]:
        """Summarize accumulated situation texts into one short sentence via LLM."""
        texts = [c.strip() for c in content_list if c.strip()]
        if not texts:
            return None
        description = "\n".join(f"- {s}" for s in texts[-10:])  # summarize only the 10 most recent
        prompt = (
            "请阅读以下多个聊天情境描述并将它们概括成一句简短的话长度不超过20个字保留共同特点\n"
            f"{description}\n"
            "只输出概括内容。"
        )
        try:
            summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
            if summary := summary.strip():
                return summary
        except Exception as e:
            logger.error(f"使用 LLM 生成表达方式概括失败: {e}")
        return None

    async def _check_expression(self, expr: "MaiExpression"):
        """Re-evaluate an expression's suitability (called after a count bump).

        Args:
            expr (MaiExpression): The expression to check.
        """
        if not global_config.expression.expression_self_reflect:
            logger.debug("表达方式自我反思功能未启用,跳过检查")
            return
        suitable, reason, error = await check_expression_suitability(expr.situation, expr.style)
        if error:
            # Leave checked/rejected untouched on evaluation error.
            logger.error(f"检查表达方式时发生错误: {error}")
            return
        expr.checked = True
        expr.rejected = not suitable
        try:
            with get_db_session() as session:
                statement = select(Expression).filter_by(id=expr.item_id).limit(1)
                if db_expr := session.exec(statement).first():
                    db_expr.checked = expr.checked
                    db_expr.rejected = expr.rejected
                    session.add(db_expr)
                else:
                    logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新检查结果")
        except Exception as e:
            logger.error(f"更新表达方式检查结果失败: {e}")
        status = "通过" if suitable else "不通过"
        logger.info(
            f"表达方式检查完成 [ID: {expr.item_id}] - {status} | "
            f"Situation: {expr.situation[:30]}... | "
            f"Style: {expr.style[:30]}... | "
            f"Reason: {reason[:50] if reason else ''}..."
        )

    def _find_similar_expression(
        self, situation: str, similarity_threshold: float = 0.75
    ) -> Optional[Tuple[MaiExpression, float]]:
        """Find the most similar stored expression for a situation.

        Args:
            situation: The situation text to match.
            similarity_threshold: Minimum ratio to consider two situations similar.

        Returns:
            Optional[Tuple[MaiExpression, float]]: the best match and its
            similarity, or None when nothing clears the threshold.
        """
        try:
            with get_db_session(auto_commit=False) as session:
                statement = select(Expression).filter_by(session_id=self.session_id)
                expressions = session.exec(statement).all()
            best_match: Optional[MaiExpression] = None
            best_similarity = 0.0
            for db_expression in expressions:
                expression = MaiExpression.from_db_instance(db_expression)
                # Compare against the canonical situation and every accumulated text.
                candidate_situations = [expression.situation, *expression.content]
                for candidate_situation in candidate_situations:
                    normalized_candidate_situation = candidate_situation.strip()
                    if not normalized_candidate_situation:
                        continue
                    similarity = difflib.SequenceMatcher(
                        None,
                        situation,
                        normalized_candidate_situation,
                    ).ratio()
                    if similarity > similarity_threshold and similarity > best_similarity:
                        best_similarity = similarity
                        best_match = expression
            if best_match:
                logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
                return best_match, best_similarity
        except Exception as e:
            logger.error(f"查找相似表达方式失败: {e}")
        return None

View File

@@ -0,0 +1,35 @@
from typing import Any, Dict, Optional
from src.manager.local_store_manager import local_storage
def _review_key(expression_id: int) -> str:
return f"expression_review:{expression_id}"
def get_review_state(expression_id: Optional[int]) -> Dict[str, Any]:
    """Read the persisted review state for an expression.

    Args:
        expression_id: Database id of the expression, or None.

    Returns:
        A dict with "checked", "rejected" (bools) and "modified_by"; the
        default unchecked state is returned for a None id or a malformed
        stored value.
    """
    if expression_id is None:
        return {"checked": False, "rejected": False, "modified_by": None}
    # NOTE(review): assumes local_storage yields a non-dict (e.g. None) for
    # missing keys rather than raising — confirm against local_store_manager.
    stored = local_storage[_review_key(expression_id)]
    if not isinstance(stored, dict):
        return {"checked": False, "rejected": False, "modified_by": None}
    return {
        "checked": bool(stored.get("checked", False)),
        "rejected": bool(stored.get("rejected", False)),
        "modified_by": stored.get("modified_by"),
    }
def set_review_state(
    expression_id: Optional[int],
    checked: bool,
    rejected: bool,
    modified_by: Optional[str],
) -> None:
    """Persist the review state for an expression; no-op when the id is None.

    Args:
        expression_id: Database id of the expression, or None.
        checked: Whether the expression has been reviewed.
        rejected: Whether the expression was rejected.
        modified_by: Who performed the review (e.g. "ai"), or None.
    """
    if expression_id is None:
        return
    state = {
        "checked": checked,
        "rejected": rejected,
        "modified_by": modified_by,
    }
    local_storage[_review_key(expression_id)] = state

View File

@@ -0,0 +1,456 @@
import json
import time
from typing import List, Dict, Optional, Any, Tuple
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager
from src.learners.learner_utils_old import weighted_sample
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
model_set=model_config.model_task_config.tool_use, request_type="expression.selector"
)
def can_use_expression_for_chat(self, chat_id: str) -> bool:
"""
检查指定聊天流是否允许使用表达
Args:
chat_id: 聊天流ID
Returns:
bool: 是否允许使用表达
"""
try:
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(chat_id)
return use_expression
except Exception as e:
logger.error(f"检查表达使用权限失败: {e}")
return False
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
"""解析'platform:id:type'为chat_id直接使用 ChatManager 提供的接口"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
return SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
)
except Exception:
return None
def get_related_chat_ids(self, chat_id: str) -> List[str]:
"""根据expression_groups配置获取与当前chat_id相关的所有chat_id包括自身"""
groups = global_config.expression.expression_groups
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组则返回所有可用的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
return list(all_chat_ids) if all_chat_ids else [chat_id]
# 否则使用现有的组逻辑
for group in groups:
group_chat_ids = []
for stream_config_str in group:
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
return [chat_id]
def _select_expressions_simple(self, chat_id: str, max_num: int) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
简单模式:只选择 count > 1 的项目要求至少有10个才进行选择随机选5个不进行LLM选择
Args:
chat_id: 聊天流ID
max_num: 最大选择数量此参数在此模式下不使用固定选择5个
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
# 如果 expression_checked_only 为 True则只选择 checked=True 且 rejected=False 的
base_conditions = (
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
)
if global_config.expression.expression_checked_only:
base_conditions = base_conditions & (Expression.checked)
style_query = Expression.select().where(base_conditions)
style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
}
for expr in style_query
]
# 要求至少有一定数量的 count > 1 的表达方式才进行“完整简单模式”选择
min_required = 8
if len(style_exprs) < min_required:
# 高 count 样本不足:如果还有候选,就降级为随机选 3 个;如果一个都没有,则直接返回空
if not style_exprs:
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
# 完全没有高 count 样本时退化为全量随机抽样不进入LLM流程
fallback_num = min(3, max_num) if max_num > 0 else 3
if fallback_selected := self._random_expressions(chat_id, fallback_num):
self.update_expressions_last_active_time(fallback_selected)
selected_ids = [expr["id"] for expr in fallback_selected]
logger.info(
f"聊天流 {chat_id} 使用简单模式降级随机抽选 {len(fallback_selected)} 个表达(无 count>1 样本)"
)
return fallback_selected, selected_ids
return [], []
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),"
f"简单模式降级为随机选择 3 个"
)
select_count = min(3, len(style_exprs))
else:
# 高 count 数量达标时,固定选择 5 个
select_count = 5
import random
selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time
if selected_style:
self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style]
logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids
except Exception as e:
logger.error(f"简单模式选择表达方式失败: {e}")
return [], []
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
随机选择表达方式
Args:
chat_id: 聊天室ID
total_num: 需要选择的数量
Returns:
List[Dict[str, Any]]: 随机选择的表达方式列表
"""
try:
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
# 如果 expression_checked_only 为 True则只选择 checked=True 且 rejected=False 的
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
if global_config.expression.expression_checked_only:
base_conditions = base_conditions & (Expression.checked)
style_query = Expression.select().where(base_conditions)
style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
}
for expr in style_query
]
# 随机抽样
return weighted_sample(style_exprs, total_num) if style_exprs else []
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")
return []
async def select_suitable_expressions(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
reply_reason: planner给出的回复理由
think_level: 思考级别0/1
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
# 检查是否允许在此聊天流中使用表达
if not self.can_use_expression_for_chat(chat_id):
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
return [], []
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic(
self,
chat_id: str,
chat_info: str,
max_num: int = 10,
target_message: Optional[str] = None,
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
classic模式随机选择+LLM选择
Args:
chat_id: 聊天流ID
chat_info: 聊天内容信息
max_num: 最大选择数量
target_message: 目标消息内容
reply_reason: planner给出的回复理由
think_level: 思考级别0/1
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# think_level == 0: 只选择 count > 1 的项目随机选10个不进行LLM选择
if think_level == 0:
return self._select_expressions_simple(chat_id, max_num)
# think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id)
# 如果 expression_checked_only 为 True则只选择 checked=True 且 rejected=False 的
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
if global_config.expression.expression_checked_only:
base_conditions = base_conditions & (Expression.checked)
style_query = Expression.select().where(base_conditions)
all_style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
}
for expr in style_query
]
# 分离 count > 1 和 count <= 1 的表达方式
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
# 根据 think_level 设置要求(仅支持 0/10 已在上方返回)
min_high_count = 10
min_total_count = 10
select_high_count = 5
select_random_count = 5
# 检查数量要求
# 对于高 count 表达:如果数量不足,不再直接停止,而是仅跳过“高 count 优先选择”
if len(high_count_exprs) < min_high_count:
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),"
f"将跳过高 count 优先选择,仅从全部表达中随机抽样"
)
high_count_valid = False
else:
high_count_valid = True
# 总量不足仍然直接返回,避免样本过少导致选择质量过低
if len(all_style_exprs) < min_total_count:
logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], []
# 先选取高count的表达方式如果数量达标
if high_count_valid:
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
else:
selected_high = []
# 然后从所有表达方式中随机抽样(使用加权抽样)
remaining_num = select_random_count
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
# 合并候选池(去重,避免重复)
candidate_exprs = selected_high.copy()
candidate_ids = {expr["id"] for expr in candidate_exprs}
for expr in selected_random:
if expr["id"] not in candidate_ids:
candidate_exprs.append(expr)
candidate_ids.add(expr["id"])
# 打乱顺序避免高count的都在前面
import random
random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表
all_expressions: List[Dict[str, Any]] = []
all_situations: List[str] = []
# 添加style表达方式
for expr in candidate_exprs:
expr = expr.copy()
all_expressions.append(expr)
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
if not all_expressions:
logger.warning("没有找到可用的表达方式")
return [], []
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""
target_message_extra_block = ""
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
# 构建reply_reason块
if reply_reason:
reply_reason_block = f"你的回复理由是:{reply_reason}"
chat_context = ""
else:
reply_reason_block = ""
# 3. 构建prompt只包含情境不包含完整的表达方式
prompt_template = prompt_manager.get_prompt("expression_select")
prompt_template.add_context("bot_name", global_config.bot.nickname)
prompt_template.add_context("chat_observe_info", chat_context)
prompt_template.add_context("all_situations", all_situations_str)
prompt_template.add_context("max_num", str(max_num))
prompt_template.add_context("target_message", target_message_str)
prompt_template.add_context("target_message_extra_block", target_message_extra_block)
prompt_template.add_context("reply_reason_block", reply_reason_block)
prompt = await prompt_manager.render_prompt(prompt_template)
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# print(prompt)
# print(content)
if not content:
logger.warning("LLM返回空结果")
return [], []
# 5. 解析结果
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict) or "selected_situations" not in result:
logger.error("LLM返回格式错误")
logger.info(f"LLM返回结果: \n{content}")
return [], []
selected_indices = result["selected_situations"]
# 根据索引获取完整的表达方式
valid_expressions: List[Dict[str, Any]] = []
selected_ids = []
for idx in selected_indices:
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
expression = all_expressions[idx - 1] # 索引从1开始
selected_ids.append(expression["id"])
valid_expressions.append(expression)
# 对选中的所有表达方式更新last_active_time
if valid_expressions:
self.update_expressions_last_active_time(valid_expressions)
logger.debug(f"{len(all_expressions)}个情境中选择了{len(valid_expressions)}")
return valid_expressions, selected_ids
except Exception as e:
logger.error(f"classic模式处理表达方式选择时出错: {e}")
return [], []
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
    """Refresh ``last_active_time`` in the DB for every distinct expression in a batch.

    Args:
        expressions_to_update: expression dicts; each needs "source_id",
            "situation" and "style" — entries missing any of them are skipped
            with a warning.
    """
    if not expressions_to_update:
        return
    # Deduplicate by (source_id, situation, style) so each DB row is touched once.
    updates_by_key = {}
    for expr in expressions_to_update:
        source_id: str = expr.get("source_id")  # type: ignore
        situation: str = expr.get("situation")  # type: ignore
        style: str = expr.get("style")  # type: ignore
        if not source_id or not situation or not style:
            logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
            continue
        key = (source_id, situation, style)
        if key not in updates_by_key:
            updates_by_key[key] = expr
    # NOTE(review): the key's first element is source_id but is matched against
    # Expression.chat_id below — presumably source_id holds a chat id; confirm.
    # Also NOTE(review): this uses a peewee-style query API (Expression.select());
    # sibling modules use sqlmodel sessions — confirm which ORM this model uses.
    for chat_id, situation, style in updates_by_key:
        query = Expression.select().where(
            (Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
        )
        if query.exists():
            expr_obj = query.get()
            expr_obj.last_active_time = time.time()  # mark as just used
            expr_obj.save()
            logger.debug("表达方式激活: 更新last_active_time in db")
# Module-level singleton. Bind None on failure so importers can test
# `expression_selector is None` instead of hitting a NameError later:
# previously a failed construction left the name entirely unbound.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")
    expression_selector = None

View File

@@ -0,0 +1,212 @@
from json_repair import repair_json
from typing import Tuple, Optional, List
import json
import re
from src.config.config import model_config
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
from src.common.logger import get_logger
logger = get_logger("expression_utils")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
    """
    Run a single LLM evaluation of one expression entry.

    Args:
        situation: the usage situation / condition text
        style: the expression / style text

    Returns:
        (suitable, reason, error) tuple. ``suitable`` is False on failure;
        ``error`` holds the error description (None on success). All parse
        failures are reported through the tuple — the function does not raise,
        matching the documented contract (previously an unparseable response
        escaped as a ValueError because the sibling ``except Exception`` cannot
        catch an exception raised inside another handler).
    """
    # Build the evaluation criteria: fixed base rules plus user-configured extras.
    base_criteria = [
        "表达方式或言语风格是否与使用条件或使用情景匹配",
        "允许部分语法错误或口头化或缺省出现",
        "表达方式不能太过特指,需要具有泛用性",
        "一般不涉及具体的人名或名称",
    ]
    if custom_criteria := global_config.expression.expression_auto_check_custom_criteria:
        base_criteria.extend(custom_criteria)
    # Numbered criteria list injected into the prompt template
    criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(base_criteria)])
    prompt_template = prompt_manager.get_prompt("expression_evaluation")
    prompt_template.add_context("situation", situation)
    prompt_template.add_context("style", style)
    prompt_template.add_context("criteria_list", criteria_list)
    prompt = await prompt_manager.render_prompt(prompt_template)
    logger.info(f"正在评估表达方式: situation={situation}, style={style}")
    response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024)
    logger.debug(f"评估结果: {response}")
    try:
        evaluation = json.loads(response)
    except json.JSONDecodeError:
        # Strict parse failed: retry through json_repair, but report any failure
        # via the return tuple instead of raising.
        try:
            evaluation = json.loads(repair_json(response))
        except Exception as e:
            return False, f"无法解析LLM响应为JSON: {response}", str(e)
    except Exception as e:
        return False, f"评估表达方式时发生错误: {e}", str(e)
    try:
        # repair_json may yield valid-but-wrong JSON (e.g. a list); guard it.
        if not isinstance(evaluation, dict):
            raise TypeError(f"期望JSON对象实际为: {type(evaluation).__name__}")
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        return False, f"评估结果格式错误: {e}", str(e)
def fix_chinese_quotes_in_json(text):
    """Escape Chinese double quotes appearing inside JSON string values.

    A small state machine walks the text once: it tracks whether the cursor is
    inside a double-quoted string (honoring backslash escapes) and, while
    inside, rewrites Chinese quotes as escaped ASCII quotes so the result is
    parseable JSON. Text outside string values is left untouched.
    """
    pieces = []
    inside_string = False
    pending_escape = False
    for ch in text:
        if pending_escape:
            # Character following a backslash: copy verbatim.
            pieces.append(ch)
            pending_escape = False
        elif ch == "\\":
            pieces.append(ch)
            pending_escape = True
        elif ch == '"':
            # Unescaped ASCII quote toggles the in-string state.
            inside_string = not inside_string
            pieces.append(ch)
        elif inside_string and ch in ("\u201c", "\u201d"):
            # Chinese double quote inside a string value -> escaped ASCII quote.
            pieces.append('\\"')
        else:
            pieces.append(ch)
    return "".join(pieces)
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """
    Parse the LLM's expression-style / jargon summary JSON into two lists.

    Expected JSON shape:
        [
            {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},  # expression
            {"content": "term", "source_id": "12"},                      # jargon
            ...
        ]

    Returns:
        Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
            expressions as (situation, style, source_id) and
            jargon entries as (content, source_id).
    """
    if not response:
        return [], []
    payload = response.strip()
    # Unwrap a ```json fenced block, otherwise strip generic ``` fences.
    fenced = re.search(r"```json\s*(.*?)\s*```", payload, re.DOTALL)
    if fenced:
        payload = fenced[1].strip()
    else:
        payload = re.sub(r"^```\s*", "", payload, flags=re.MULTILINE)
        payload = re.sub(r"```\s*$", "", payload, flags=re.MULTILINE)
        payload = payload.strip()
    # Parse; retry once after escaping Chinese quotes inside string values.
    parsed = _try_parse(payload)
    if parsed is None:
        parsed = _try_parse(fix_chinese_quotes_in_json(payload))
    if parsed is None:
        logger.error(f"处理后的 JSON 字符串前500字符{payload[:500]}")
        return [], []
    if isinstance(parsed, dict):
        items = [parsed]
    elif isinstance(parsed, list):
        items = parsed
    else:
        logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
        return [], []
    expressions: List[Tuple[str, str, str]] = []  # (situation, style, source_id)
    jargon_entries: List[Tuple[str, str]] = []  # (content, source_id)
    for item in items:
        if not isinstance(item, dict):
            continue
        situation = str(item.get("situation", "")).strip()
        style = str(item.get("style", "")).strip()
        source_id = str(item.get("source_id", "")).strip()
        # Entries with both situation and style are expressions.
        if situation and style and source_id:
            expressions.append((situation, style, source_id))
            continue
        # Otherwise treat as jargon when content is present.
        content = str(item.get("content", "")).strip()
        if content and source_id:
            jargon_entries.append((content, source_id))
    return expressions, jargon_entries
def is_single_char_jargon(content: str) -> bool:
    """
    Decide whether a term is single-character jargon.

    Args:
        content: the candidate term

    Returns:
        bool: True when the term is exactly one CJK character, one ASCII
        letter, or one ASCII digit; False otherwise.
    """
    if not content or len(content) != 1:
        return False
    ch = content[0]
    # CJK unified ideograph range
    if "\u4e00" <= ch <= "\u9fff":
        return True
    # ASCII alphanumerics: a-z, A-Z, 0-9
    return ch.isascii() and ch.isalnum()
def _try_parse(text):
try:
return json.loads(text)
except Exception:
try:
repaired = repair_json(text)
return json.loads(repaired)
except Exception:
return None

View File

@@ -0,0 +1,86 @@
from typing import Optional, Dict, List
from sqlmodel import select, func as fn
import json
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("jargon_explainer")
def search_jargon(
    keyword: str,
    chat_id: Optional[str] = None,
    limit: int = 10,
    case_sensitive: bool = False,
    fuzzy: bool = True,
) -> List[Dict[str, str]]:
    """
    Search jargon records, supporting case-insensitive and fuzzy (LIKE) matching.

    Args:
        keyword: search keyword
        chat_id: optional chat/session id
            - when all_global is on, this is ignored and only is_global=True rows match
            - when all_global is off, non-global rows must list this chat in their
              session_id_dict to be returned
        limit: maximum number of results, default 10
        case_sensitive: default False (case-insensitive)
        fuzzy: default True (substring/LIKE match); False means exact match

    Returns:
        List[Dict[str, str]]: dicts with "content" and "meaning" keys; only rows
        with a non-empty meaning are returned.
    """
    if not keyword or not keyword.strip():
        return []
    keyword = keyword.strip()
    # Build the SQL-side match condition
    if case_sensitive:  # case-sensitive
        search_condition = Jargon.content.contains(keyword) if fuzzy else Jargon.content == keyword  # type: ignore
    else:
        keyword_lower = keyword.lower()
        search_condition = (
            fn.LOWER(Jargon.content).contains(keyword_lower) if fuzzy else fn.LOWER(Jargon.content) == keyword_lower
        )
    # Scope the query per the all_global setting; fetch limit*2 rows up front
    # because the Python-side filters below may drop some of them.
    if global_config.expression.all_global_jargon:
        # all_global on: every record is global — query all is_global=True rows (ignore chat_id)
        query = select(Jargon).where(search_condition, Jargon.is_global).order_by(Jargon.count.desc()).limit(limit * 2)  # type: ignore
    else:
        # all_global off: query all rows; chat_id filtering happens in Python below
        query = select(Jargon).where(search_condition).order_by(Jargon.count.desc()).limit(limit * 2)  # type: ignore
    # Execute and post-filter
    results: List[Dict[str, str]] = []
    with get_db_session() as session:
        jargons = session.exec(query).all()
        for jargon in jargons:
            # With a chat_id and all_global off, non-global rows must reference
            # the target chat in their session_id_dict.
            if chat_id and not global_config.expression.all_global_jargon and not jargon.is_global:
                try:  # parse session_id_dict
                    session_id_dict = json.loads(jargon.session_id_dict) if jargon.session_id_dict else {}
                except (json.JSONDecodeError, TypeError):
                    session_id_dict = {}
                    logger.warning(
                        f"解析 session_id_dict 失败jargon_id={jargon.id},原始数据:{jargon.session_id_dict}"
                    )
                # Must contain the target chat_id
                if chat_id not in session_id_dict:
                    continue
            # Only return rows with a non-empty meaning. Guard against NULL
            # meaning (the column is nullable elsewhere in this package);
            # previously `jargon.meaning.strip()` raised AttributeError on None.
            if not (jargon.meaning or "").strip():
                continue
            results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
            # Stop once the requested limit is reached
            if len(results) >= limit:
                break
    return results

View File

@@ -0,0 +1,344 @@
import re
import time
from typing import List, Dict, Optional, Any
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.prompt.prompt_manager import prompt_manager
from src.learners.jargon_miner_old import search_jargon
from src.learners.learner_utils_old import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon")
class JargonExplainer:
    """Jargon explainer: detects known jargon in the chat context before a reply
    and produces a short explanation text for the reply pipeline."""

    def __init__(self, chat_id: str) -> None:
        # chat_id scopes non-global jargon lookups to the current conversation.
        self.chat_id = chat_id
        self.llm = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="jargon.explain",
        )

    def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
        """
        Extract jargon by direct string matching against jargon stored in the database.

        Args:
            messages: message list

        Returns:
            List[Dict[str, str]]: matched jargon entries, each containing "content"
        """
        start_time = time.time()
        if not messages:
            return []
        # Collect the text content of every (non-bot) message
        message_texts: List[str] = []
        for msg in messages:
            # Skip the bot's own messages
            if is_bot_message(msg):
                continue
            msg_text = (
                getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
            ).strip()
            if msg_text:
                message_texts.append(msg_text)
        if not message_texts:
            return []
        # Merge all message texts into one searchable string
        combined_text = " ".join(message_texts)
        # Query all jargon records that have a non-empty meaning.
        # NOTE(review): this uses a peewee-style query API (Jargon.select());
        # sibling modules query Jargon through sqlmodel sessions — confirm which
        # ORM this model actually supports.
        query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
        # Choose the query scope according to the all_global setting
        if global_config.expression.all_global_jargon:
            # all_global on: only records flagged is_global=True
            query = query.where(Jargon.is_global)
        else:
            # all_global off: fetch is_global=True rows plus rows whose chat_id
            # list contains the current chat — filtered in Python below
            pass
        # Sort by count descending so frequent jargon is matched first
        query = query.order_by(Jargon.count.desc())
        # Run the query and match each candidate against the combined text
        matched_jargon: Dict[str, Dict[str, str]] = {}
        query_time = time.time()
        for jargon in query:
            content = jargon.content or ""
            if not content or not content.strip():
                continue
            # Skip entries containing the bot's own nickname
            if contains_bot_self_name(content):
                continue
            # chat_id check (only when all_global=False)
            if not global_config.expression.all_global_jargon:
                if jargon.is_global:
                    # Global jargon: always included
                    pass
                else:
                    # Include only if the record's chat_id list contains this chat
                    chat_id_list = parse_chat_id_list(jargon.chat_id)
                    if not chat_id_list_contains(chat_id_list, self.chat_id):
                        continue
            # Search the text for the term (case-insensitive)
            pattern = re.escape(content)
            # Use word boundaries for ASCII-only terms to avoid partial matches;
            # CJK text has no word boundaries, so match the raw pattern there.
            if re.search(r"[\u4e00-\u9fff]", content):
                # Contains CJK: use the looser match
                search_pattern = pattern
            else:
                # Pure ASCII letters/digits: require word boundaries
                search_pattern = r"\b" + pattern + r"\b"
            if re.search(search_pattern, combined_text, re.IGNORECASE):
                # Record the hit (deduplicated by content)
                if content not in matched_jargon:
                    matched_jargon[content] = {"content": content}
        match_time = time.time()
        total_time = match_time - start_time
        query_duration = query_time - start_time
        match_duration = match_time - query_time
        logger.debug(
            f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
            f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
        )
        return list(matched_jargon.values())

    async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
        """
        Explain the jargon found in the current context.

        Args:
            messages: message list
            chat_context: textual rendering of the chat context

        Returns:
            Optional[str]: a summarized explanation, or None when no jargon was found
        """
        if not messages:
            return None
        # Direct matching: look up known jargon in the DB and match it in the messages
        jargon_entries = self.match_jargon_from_messages(messages)
        if not jargon_entries:
            return None
        # Deduplicate by content
        unique_jargon: Dict[str, Dict[str, str]] = {}
        for entry in jargon_entries:
            content = entry["content"]
            if content not in unique_jargon:
                unique_jargon[content] = entry
        jargon_list = list(unique_jargon.values())
        logger.info(f"从上下文中提取到 {len(jargon_list)} 个黑话: {[j['content'] for j in jargon_list]}")
        # Look up the meaning of each matched term
        jargon_explanations: List[str] = []
        for entry in jargon_list:
            content = entry["content"]
            # Query scope depends on whether global jargon is enabled
            if global_config.expression.all_global_jargon:
                # Global jargon on: search all is_global=True records
                results = search_jargon(
                    keyword=content,
                    chat_id=None,  # no chat_id: query global jargon
                    limit=1,
                    case_sensitive=False,
                    fuzzy=False,  # exact match
                )
            else:
                # Global jargon off: prefer jargon from this chat or marked global
                results = search_jargon(
                    keyword=content,
                    chat_id=self.chat_id,
                    limit=1,
                    case_sensitive=False,
                    fuzzy=False,  # exact match
                )
            if results and len(results) > 0:
                meaning = results[0].get("meaning", "").strip()
                if meaning:
                    jargon_explanations.append(f"- {content}: {meaning}")
                else:
                    logger.info(f"黑话 {content} 没有找到含义")
            else:
                logger.info(f"黑话 {content} 未在数据库中找到")
        if not jargon_explanations:
            logger.info("没有找到任何黑话的含义,跳过解释")
            return None
        # Join all explanations into one text block
        explanations_text = "\n".join(jargon_explanations)
        # Summarize the explanations with the LLM
        prompt_of_summarize = prompt_manager.get_prompt("jargon_explainer_summarize")
        prompt_of_summarize.add_context("chat_context", lambda _: chat_context)
        prompt_of_summarize.add_context("jargon_explanations", lambda _: explanations_text)
        summarize_prompt = await prompt_manager.render_prompt(prompt_of_summarize)
        summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3)
        if not summary:
            # LLM summarization failed: fall back to the raw explanation list
            return f"上下文中的黑话解释:\n{explanations_text}"
        summary = summary.strip()
        if not summary:
            return f"上下文中的黑话解释:\n{explanations_text}"
        return summary
async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
    """
    Convenience wrapper: build a JargonExplainer and explain context jargon.

    Args:
        chat_id: chat id used to scope jargon lookups
        messages: message list to scan
        chat_context: textual rendering of the chat context

    Returns:
        Optional[str]: summarized jargon explanation, or None when nothing matched
    """
    return await JargonExplainer(chat_id).explain_jargon(messages, chat_context)
def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
    """Match known jargon terms directly inside a chat transcript.

    Args:
        chat_text: the text to scan
        chat_id: chat id used to scope non-global jargon

    Returns:
        List[str]: distinct jargon terms found in the text
    """
    if not chat_text or not chat_text.strip():
        return []
    # Only consider records that already carry a meaning.
    query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
    if global_config.expression.all_global_jargon:
        query = query.where(Jargon.is_global)
    # Frequent terms first.
    query = query.order_by(Jargon.count.desc())
    hits: Dict[str, None] = {}  # insertion-ordered de-dup
    for record in query:
        term = (record.content or "").strip()
        if not term:
            continue
        # Non-global records must list this chat when all_global is off.
        if not global_config.expression.all_global_jargon and not record.is_global:
            id_list = parse_chat_id_list(record.chat_id)
            if not chat_id_list_contains(id_list, chat_id):
                continue
        escaped = re.escape(term)
        # Word boundaries for ASCII terms; CJK terms match as-is.
        if re.search(r"[\u4e00-\u9fff]", term):
            candidate = escaped
        else:
            candidate = r"\b" + escaped + r"\b"
        if re.search(candidate, chat_text, re.IGNORECASE):
            hits[term] = None
    logger.info(f"匹配到 {len(hits)} 个黑话")
    return list(hits.keys())
async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
    """Look up each concept in the jargon store and format the findings.

    Args:
        concepts: concept strings to look up
        chat_id: chat id used to scope the lookup

    Returns:
        str: readable summary of known meanings, or "" when nothing matched
    """
    if not concepts:
        return ""
    formatted: List[str] = []
    exact_hits: List[str] = []  # exactly-matched concepts, logged once at the end
    for raw_concept in concepts:
        concept = raw_concept.strip()
        if not concept:
            continue
        # Exact lookup first, falling back to fuzzy search.
        hits = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
        fuzzy_used = False
        if not hits:
            hits = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
            fuzzy_used = True
        if not hits:
            # Nothing found: log only, emit no placeholder text.
            logger.info(f"在jargon库中未找到匹配: {concept}")
            continue
        if fuzzy_used:
            # Fuzzy match: report the near-misses found.
            parts = [f"未精确匹配到'{concept}'"]
            for hit in hits:
                found_content = hit.get("content", "").strip()
                meaning = hit.get("meaning", "").strip()
                if found_content and meaning:
                    parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
            formatted.append("\n".join(parts))  # newline-separated explanations
            logger.info(f"在jargon库中找到匹配模糊搜索: {concept},找到{len(hits)}条结果")
        else:
            # Exact match.
            parts = []
            for hit in hits:
                meaning = hit.get("meaning", "").strip()
                if meaning:
                    parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
            formatted.append("\n".join(parts) if len(parts) > 1 else parts[0])
            exact_hits.append(concept)
    # One combined log line for all exact matches.
    if exact_hits:
        logger.info(f"找到黑话: {', '.join(exact_hits)},共找到{len(exact_hits)}条结果")
    if formatted:
        return "你了解以下词语可能的含义:\n" + "\n".join(formatted) + "\n"
    return ""

View File

@@ -0,0 +1,419 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TypedDict
import asyncio
import json
import random
from json_repair import repair_json
from sqlmodel import select
from src.common.data_models.jargon_data_model import MaiJargon
from src.common.database.database import get_db_session
from src.common.database.database_model import Jargon
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
from .expression_utils import is_single_char_jargon
logger = get_logger("jargon")
# TODO: 重构完LLM相关内容后替换成新的模型调用方式
llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract")
llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference")
class JargonEntry(TypedDict):
    """Shape of a freshly-extracted jargon candidate before persistence."""

    content: str  # the candidate term
    raw_content: Set[str]  # raw message snippets the term was observed in
class JargonMeaningEntry(TypedDict):
    """A jargon term together with its inferred meaning."""

    content: str  # the jargon term
    meaning: str  # its inferred meaning
class JargonMiner:
    """Persists extracted jargon candidates and incrementally infers their
    meaning with an LLM as each term's observation count grows."""

    def __init__(self, session_id: str, session_name: str) -> None:
        self.session_id = session_id
        self.session_name = session_name
        # LRU cache of recently-seen jargon contents, bounded by cache_limit
        self.cache_limit = 50
        self.cache: OrderedDict[str, None] = OrderedDict()
        # Extraction lock to prevent concurrent extraction runs
        self._extraction_lock = asyncio.Lock()

    def get_cached_jargons(self) -> List[str]:
        """Return all jargon contents currently held in the cache."""
        return list(self.cache.keys())

    async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
        """
        Infer the meaning of one jargon entry via a three-step LLM pipeline:
        (1) infer with context, (2) infer from the term alone, (3) compare —
        a term is considered jargon only when the two inferences disagree.
        """
        content = jargon_obj.content
        # Parse the raw_content JSON list (tolerate plain strings / bad JSON)
        raw_content_list = []
        if raw_content_str := jargon_obj.raw_content:
            try:
                raw_content_list = json.loads(raw_content_str)
                if not isinstance(raw_content_list, list):
                    raw_content_list = [raw_content_list] if raw_content_list else []
            except (json.JSONDecodeError, TypeError):
                raw_content_list = [raw_content_str] if raw_content_str else []
        if not raw_content_list:
            logger.warning(f"jargon {content} 没有raw_content跳过推断")
            return
        # Current observation count and the previously-inferred meaning
        current_count = jargon_obj.count
        previous_meaning = jargon_obj.meaning
        # Step 1: infer from raw_content + content
        raw_content_text = "\n".join(raw_content_list)
        # At count 24 or 60, randomly drop half of the raw_content items.
        # NOTE(review): raw_content_text was already built above, and the trimmed
        # list is not persisted — this trim appears to have no effect; confirm intent.
        if current_count in [24, 60] and len(raw_content_list) > 1:
            # Keep at least one item
            keep_count = max(1, len(raw_content_list) // 2)
            raw_content_list = random.sample(raw_content_list, keep_count)
            logger.info(
                f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
            )
        # At count 24/60/100, include the previous meaning in the prompt as reference
        previous_meaning_section = ""
        previous_meaning_instruction = ""
        if current_count in [24, 60, 100] and previous_meaning:
            previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}"
            previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
        prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context")
        prompt1_template.add_context("bot_name", global_config.bot.nickname)
        prompt1_template.add_context("content", str(content))
        prompt1_template.add_context("raw_content_list", raw_content_text)
        prompt1_template.add_context("previous_meaning_section", previous_meaning_section)
        prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction)
        prompt1 = await prompt_manager.render_prompt(prompt1_template)
        llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3)
        if not llm_response_1:
            logger.warning(f"jargon {content} 推断1失败无响应")
            return
        # Parse inference-1 result
        inference1 = self._parse_result(llm_response_1)
        if not inference1:
            logger.warning(f"jargon {content} 推断1解析失败")
            return
        no_info = inference1.get("no_info", False)
        meaning1: str = inference1.get("meaning", "").strip()
        if no_info or not meaning1:
            logger.info(f"jargon {content} 推断1表示信息不足无法推断放弃本次推断待下次更新")
            # Record the count at which inference was last attempted so the same
            # threshold is not retried repeatedly
            jargon_obj.last_inference_count = jargon_obj.count or 0
            try:
                self._modify_jargon_entry(jargon_obj)
            except Exception as e:
                logger.error(f"jargon {content} 推断1更新last_inference_count失败: {e}")
            return
        # Step 2: infer from the term alone (content-only)
        prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only")
        prompt2_template.add_context("content", content)
        prompt2 = await prompt_manager.render_prompt(prompt2_template)
        llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3)
        if not llm_response_2:
            logger.warning(f"jargon {content} 推断2失败无响应")
            return
        # Parse inference-2 result
        inference2 = self._parse_result(llm_response_2)
        if not inference2:
            logger.warning(f"jargon {content} 推断2解析失败")
            return
        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 推断1提示词: {prompt1}")
            logger.info(f"jargon {content} 推断2提示词: {prompt2}")
        # Step 3: compare the two inference results
        prompt3_template = prompt_manager.get_prompt("jargon_compare_inference")
        prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False))
        prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False))
        prompt3 = await prompt_manager.render_prompt(prompt3_template)
        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 比较提示词: {prompt3}")
        llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3)
        if not llm_response_3:
            logger.warning(f"jargon {content} 比较失败:无响应")
            return
        comparison_result = self._parse_result(llm_response_3)
        if not comparison_result:
            logger.warning(f"jargon {content} 比较解析失败")
            return
        is_similar = comparison_result.get("is_similar", False)
        # If context-aware and context-free meanings agree the term is plain
        # language; if they differ the context adds meaning, i.e. real jargon.
        is_jargon = not is_similar
        # Update the database record
        jargon_obj.is_jargon = is_jargon
        jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
        # Record the last-judged count so a restart does not re-trigger inference
        jargon_obj.last_inference_count = jargon_obj.count or 0
        # Once count reaches 100, mark inference as complete — no further runs
        if (jargon_obj.count or 0) >= 100:
            jargon_obj.is_complete = True
        try:
            self._modify_jargon_entry(jargon_obj)
        except Exception as e:
            logger.error(f"jargon {content} 推断结果更新失败: {e}")
        logger.debug(
            f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
        )
        # Always log the outcome in a readable form
        if is_jargon:
            # Jargon: log "[chat] term means ..."
            meaning = jargon_obj.meaning or "无详细说明"
            is_global = jargon_obj.is_global  # whether the record is global
            if is_global:
                logger.info(f"[黑话]{content}的含义是 {meaning}")
            else:
                logger.info(f"[{self.session_name}]{content}的含义是 {meaning}")
        else:
            # Not jargon: log "[chat] term is not jargon"
            logger.info(f"[{self.session_name}]{content} 不是黑话")

    async def process_extracted_entries(
        self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
    ) -> None:
        """
        Persist jargon entries already extracted elsewhere (routed from expression_learner).

        Args:
            entries: jargon entries
            person_name_filter: optional predicate; entries whose content it flags
                (i.e. contains a person's name) are dropped
        """
        if not entries:
            return
        # Merge duplicate contents, unioning their raw_content sets
        merged_entries: Dict[str, JargonEntry] = {}
        for entry in entries:
            content = entry["content"].strip()
            if person_name_filter and person_name_filter(content):
                logger.info(f"条目 '{content}' 包含人物名称,已过滤")
                continue
            raw_list = entry["raw_content"] or set()
            if content in merged_entries:
                merged_entries[content]["raw_content"].update(raw_list)
            else:
                merged_entries[content] = {"content": content, "raw_content": set(raw_list)}
        uniq_entries: List[JargonEntry] = list(merged_entries.values())
        saved = 0
        updated = 0
        for entry in uniq_entries:
            content = entry["content"]
            raw_content_set = entry["raw_content"]
            try:
                with get_db_session(auto_commit=False) as session:
                    jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
            except Exception as e:
                logger.error(f"查询黑话 '{content}' 失败: {e}")
                continue
            # Find the row matching this session
            matched_jargon: Optional[Jargon] = None
            for item in jargon_items:
                if global_config.expression.all_global_jargon:
                    # all_global on: any row with matching content qualifies
                    matched_jargon = item
                    break
                else:
                    # Otherwise the row's session dict must contain this session_id
                    if item.session_id_dict:
                        try:
                            session_id_dict = json.loads(item.session_id_dict)
                            if self.session_id in session_id_dict:
                                matched_jargon = item
                                break
                        except Exception as e:
                            logger.error(f"解析Jargon id={item.id} session_id_list失败: {e}")
                            continue
            if matched_jargon:
                # Existing row: bump count and merge raw_content
                self._update_jargon(matched_jargon, raw_content_set)
                if self._should_infer_meaning(matched_jargon):
                    asyncio.create_task(self._infer_meaning_by_id(matched_jargon.id))  # type: ignore
                updated += 1
            else:
                # No match found: create a new row
                is_global_new = global_config.expression.all_global_jargon
                session_dict_str = json.dumps({self.session_id: 1})
                new_jargon = Jargon(
                    content=content,
                    raw_content=json.dumps(list(raw_content_set), ensure_ascii=False),
                    session_id_dict=session_dict_str,
                    is_global=is_global_new,
                    count=1,
                    meaning="",
                )
                try:
                    with get_db_session() as session:
                        session.add(new_jargon)
                        session.flush()
                    saved += 1
                    self._add_to_cache(content)
                except Exception as e:
                    logger.error(f"保存新黑话 '{content}' 失败: {e}")
                    continue
        # Always log extracted candidates in readable form whenever any were found
        if uniq_entries:
            # Collect all extracted jargon contents
            jargon_list = [entry["content"] for entry in uniq_entries]
            jargon_str = ",".join(jargon_list)
            logger.info(f"[{self.session_name}]疑似黑话: {jargon_str}")
        if saved or updated:
            logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated}session_id={self.session_id}")

    def _add_to_cache(self, content: str):
        """Add a jargon content to the LRU cache, evicting the oldest entry if full."""
        content = content.strip()
        # Single-character terms are too noisy to cache
        if is_single_char_jargon(content):
            return
        if content in self.cache:
            # Already present: move to the end to mark it most-recently used
            self.cache.move_to_end(content)
        else:
            # New content: append to the cache
            self.cache[content] = None
            # Evict the oldest entry when over the limit
            if len(self.cache) > self.cache_limit:
                removed_content, _ = self.cache.popitem(last=False)
                logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")

    def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
        """Update an existing jargon record and write it back to the database.

        Args:
            db_jargon: the matched jargon ORM object.
            raw_content_set: newly-observed raw context snippets to merge in.
        """
        db_jargon.count += 1
        existing_raw_content: List[str] = []
        if db_jargon.raw_content:
            try:
                existing_raw_content = json.loads(db_jargon.raw_content)
            except Exception:
                existing_raw_content = []
        # Merge and deduplicate raw contexts
        merged_list = list(set(existing_raw_content).union(raw_content_set))
        db_jargon.raw_content = json.dumps(merged_list, ensure_ascii=False)
        session_id_dict: Dict[str, int] = json.loads(db_jargon.session_id_dict)
        session_id_dict[self.session_id] = session_id_dict.get(self.session_id, 0) + 1
        db_jargon.session_id_dict = json.dumps(session_id_dict)
        # With all_global on, ensure the record is flagged is_global=True
        if global_config.expression.all_global_jargon:
            db_jargon.is_global = True
        try:
            with get_db_session() as session:
                if db_jargon.id is None:
                    raise ValueError("黑话记录缺少 id无法更新数据库")
                statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
                if persisted_jargon := session.exec(statement).first():
                    persisted_jargon.count = db_jargon.count
                    persisted_jargon.raw_content = db_jargon.raw_content
                    persisted_jargon.session_id_dict = db_jargon.session_id_dict
                    persisted_jargon.is_global = db_jargon.is_global
                    session.add(persisted_jargon)
                else:
                    logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
        except Exception as e:
            logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")

    def _parse_result(self, response: str) -> Optional[Dict[str, str]]:
        # Parse an LLM response as JSON, repairing it on failure; None when unusable.
        try:
            result = json.loads(response.strip())
        except Exception:
            try:
                repaired = repair_json(response.strip())
                result = json.loads(repaired)
            except Exception as e2:
                logger.error(f"推断结果解析失败: {e2}")
                return None
        if not isinstance(result, dict):
            logger.warning("推断结果格式错误")
            return None
        return result

    def _modify_jargon_entry(self, jargon_obj: MaiJargon) -> None:
        # Persist the inference-related fields of jargon_obj back to its DB row.
        with get_db_session() as session:
            if not jargon_obj.item_id:
                raise ValueError("jargon_obj must have item_id to update")
            statement = select(Jargon).filter_by(id=jargon_obj.item_id).limit(1)
            if db_record := session.exec(statement).first():
                db_record.is_jargon = jargon_obj.is_jargon
                db_record.meaning = jargon_obj.meaning
                db_record.last_inference_count = jargon_obj.last_inference_count
                db_record.is_complete = jargon_obj.is_complete
                session.add(db_record)

    def _should_infer_meaning(self, jargon_obj: Jargon) -> bool:
        """
        Decide whether meaning inference should run now.

        Inference runs when count crosses the next threshold in
        [2, 4, 8, 12, 24, 60, 100] and count exceeds last_inference_count
        (so a restart does not re-trigger the same threshold). No inference
        once is_complete is True.

        NOTE(review): other comments in this method mentioned thresholds
        3, 6, 10, 20, 40, 60, 100 — the code uses the list below; confirm
        which set is intended.
        """
        # Already finished all inference rounds: never infer again
        if jargon_obj.is_complete:
            return False
        count = jargon_obj.count or 0
        last_inference = jargon_obj.last_inference_count or 0
        # Threshold list
        thresholds = [2, 4, 8, 12, 24, 60, 100]
        if count < thresholds[0]:
            return False
        # count has not advanced past the last judged value: nothing to do
        if count <= last_inference:
            return False
        next_threshold = next(
            (threshold for threshold in thresholds if threshold > last_inference),
            None,
        )
        # No next threshold means we are already past 100: stop inferring
        return False if next_threshold is None else count >= next_threshold

    async def _infer_meaning_by_id(self, jargon_id: int):
        # Load the row by id into a MaiJargon snapshot, then run inference on it.
        jargon_obj: Optional[MaiJargon] = None
        try:
            with get_db_session() as session:
                statement = select(Jargon).filter_by(id=jargon_id).limit(1)
                if db_record := session.exec(statement).first():
                    jargon_obj = MaiJargon.from_db_instance(db_record)
        except Exception as e:
            logger.error(f"查询Jargon id={jargon_id}失败: {e}")
            return
        if jargon_obj:
            await self.infer_meaning(jargon_obj)

View File

@@ -0,0 +1,134 @@
from json_repair import repair_json
from typing import List, Tuple
import re
import json
from src.common.logger import get_logger
logger = get_logger("learner_utils")
def fix_chinese_quotes_in_json(text):
    """Escape Chinese double quotes inside JSON string values (state machine).

    Walks the input once, tracking whether the cursor is inside a double-quoted
    string value while honoring backslash escapes; Chinese quotes found inside a
    string are rewritten as escaped ASCII quotes so the output parses as JSON.
    """
    buf = []
    in_value = False
    after_backslash = False
    for current in text:
        if after_backslash:
            # Escaped character: copy through untouched.
            buf.append(current)
            after_backslash = False
            continue
        if current == "\\":
            buf.append(current)
            after_backslash = True
            continue
        if current == '"':
            # Unescaped ASCII quote toggles string state.
            in_value = not in_value
            buf.append(current)
            continue
        if in_value and current in ("\u201c", "\u201d"):
            # Chinese double quote inside a string value.
            buf.append('\\"')
        else:
            buf.append(current)
    return "".join(buf)
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """
    Parse the LLM's expression-style / jargon summary JSON into two lists.

    Expected JSON shape:
        [
            {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},  # expression
            {"content": "term", "source_id": "12"},                      # jargon
            ...
        ]

    Returns:
        Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
            expressions as (situation, style, source_id) and
            jargon entries as (content, source_id).
    """
    if not response:
        return [], []
    payload = response.strip()
    # Unwrap a ```json fenced block, otherwise strip generic ``` fences.
    if fence := re.search(r"```json\s*(.*?)\s*```", payload, re.DOTALL):
        payload = fence[1].strip()
    else:
        payload = re.sub(r"^```\s*", "", payload, flags=re.MULTILINE)
        payload = re.sub(r"```\s*$", "", payload, flags=re.MULTILINE)
        payload = payload.strip()

    def _attempt(candidate: str):
        # Direct parse for well-formed arrays; otherwise repair first.
        if candidate.startswith("[") and candidate.endswith("]"):
            return json.loads(candidate)
        repaired = repair_json(candidate)
        return json.loads(repaired) if isinstance(repaired, str) else repaired

    try:
        parsed = _attempt(payload)
    except Exception as parse_error:
        # Retry once after escaping Chinese quotes inside JSON string values.
        try:
            parsed = _attempt(fix_chinese_quotes_in_json(payload))
        except Exception as fix_error:
            logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
            logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
            logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
            logger.error(f"处理后的 JSON 字符串前500字符{payload[:500]}")
            return [], []
    if isinstance(parsed, dict):
        items = [parsed]
    elif isinstance(parsed, list):
        items = parsed
    else:
        logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
        return [], []
    expressions: List[Tuple[str, str, str]] = []  # (situation, style, source_id)
    jargon_entries: List[Tuple[str, str]] = []  # (content, source_id)
    for item in items:
        if not isinstance(item, dict):
            continue
        situation = str(item.get("situation", "")).strip()
        style = str(item.get("style", "")).strip()
        source_id = str(item.get("source_id", "")).strip()
        if situation and style and source_id:
            # Expression entry (has both situation and style)
            expressions.append((situation, style, source_id))
        elif item.get("content"):
            # Jargon entry (has a content field)
            content = str(item.get("content", "")).strip()
            source_id = str(item.get("source_id", "")).strip()
            if content and source_id:
                jargon_entries.append((content, source_id))
    return expressions, jargon_entries

# ---- diff artifact retained from merge (was: "View File") ----

# ---- diff artifact retained from merge (was: "@@ -0,0 +1,445 @@") ----
import random
import json
from typing import Optional, List, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("learner_utils")
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间。
count越高权重越高但最多为基础权重的5倍。
"""
if not population:
return []
counts = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
weights = [1.0 for _ in counts]
else:
weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
weights.append(1.0 + normalized * 4.0) # 1~5
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Draw ``k`` items from ``population`` without replacement, favoring
    items with higher count-derived weights.

    Args:
        population: candidate items.
        k: number of items to draw.

    Returns:
        List[Dict]: the drawn items. Returns a shallow copy of
        ``population`` when it has no more than ``k`` items, and ``[]``
        for empty input or ``k <= 0``.
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()

    remaining = population.copy()
    picked: List[Dict] = []
    for _ in range(min(k, len(remaining))):
        # Recompute weights each round: removing items shifts the
        # min/max normalization of the remaining pool.
        weights = _compute_weights(remaining)
        total = sum(weights)
        if total <= 0:
            # Degenerate weights: fall back to a uniform draw.
            picked.append(remaining.pop(random.randint(0, len(remaining) - 1)))
            continue
        threshold = random.uniform(0, total)
        running = 0.0
        for position, weight in enumerate(weights):
            running += weight
            if threshold <= running:
                picked.append(remaining.pop(position))
                break
    return picked
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
    """
    Normalize a chat_id field into the new list format, accepting both the
    legacy plain-string format and the JSON-encoded list format.

    Args:
        chat_id_value: a raw string (legacy), a JSON string (new), or an
            already-parsed list.

    Returns:
        List[List[Any]]: entries shaped as ``[[chat_id, count], ...]``.
    """
    if not chat_id_value:
        return []
    if isinstance(chat_id_value, list):
        # Already in the new list format.
        return chat_id_value
    if not isinstance(chat_id_value, str):
        # Unknown type: coerce to a single legacy entry.
        return [[str(chat_id_value), 1]]
    try:
        decoded = json.loads(chat_id_value)
    except (json.JSONDecodeError, TypeError):
        # Not JSON at all: treat as a legacy plain-string chat id.
        return [[str(chat_id_value), 1]]
    if isinstance(decoded, list):
        # New format: already a [[chat_id, count], ...] list.
        return decoded
    if isinstance(decoded, str):
        # JSON decoded back to a string: still the legacy format.
        return [[decoded, 1]]
    # Any other decoded type falls back to a legacy entry.
    return [[str(chat_id_value), 1]]
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
    """
    Update the chat_id list in place: bump the counter for target_chat_id
    if an entry exists, otherwise append a fresh entry.

    Args:
        chat_id_list: entries shaped as ``[[chat_id, count], ...]``.
        target_chat_id: the chat id to update or insert.
        increment: amount added to the counter (default 1).

    Returns:
        List[List[Any]]: the same (mutated) list.
    """
    wanted = str(target_chat_id)
    for entry in chat_id_list:
        if isinstance(entry, list) and len(entry) >= 1 and str(entry[0]) == wanted:
            if len(entry) >= 2:
                # Reset a non-numeric counter to 0 before adding.
                base = entry[1] if isinstance(entry[1], (int, float)) else 0
                entry[1] = base + increment
            else:
                # Entry had no counter yet: start one.
                entry.append(increment)
            return chat_id_list
    # No matching entry found: append a new one.
    chat_id_list.append([target_chat_id, increment])
    return chat_id_list
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
"""
在chat_id列表中查找匹配的项辅助函数
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
如果找到则返回匹配的项否则返回None
"""
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return item
return None
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
    """
    Report whether target_chat_id appears in the chat_id list.

    Args:
        chat_id_list: entries shaped as ``[[chat_id, count], ...]``.
        target_chat_id: the chat id to look for.

    Returns:
        bool: True when a matching entry exists.
    """
    wanted = str(target_chat_id)
    return any(
        isinstance(entry, list) and len(entry) >= 1 and str(entry[0]) == wanted
        for entry in chat_id_list
    )
def contains_bot_self_name(content: str) -> bool:
    """
    Check whether the given text mentions the bot's nickname or any alias.

    Matching is case-insensitive substring containment against the stripped
    content. Returns False for empty content or when no bot config exists.
    """
    if not content:
        return False
    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False
    haystack = content.strip().lower()
    # Gather the nickname plus every alias, normalized the same way.
    names = [str(getattr(bot_config, "nickname", "") or "").strip().lower()]
    for alias in getattr(bot_config, "alias_names", []) or []:
        names.append(str(alias or "").strip().lower())
    return any(name and name in haystack for name in names)
def is_bot_message(msg: Any) -> bool:
    """
    Return True when the message was sent by the bot itself.

    The sender id is taken from ``msg.user_id`` or ``msg.user_info.user_id``
    and compared against every account id the bot config declares
    (qq_account, telegram_account, and per-platform accounts).
    """
    if msg is None:
        return False
    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False
    user_info = getattr(msg, "user_info", None)
    sender_id = str(
        getattr(msg, "user_id", "") or getattr(user_info, "user_id", "") or ""
    ).strip()
    if not sender_id:
        return False
    # Collect every known bot account id across configured platforms.
    accounts = {
        str(getattr(bot_config, "qq_account", "") or "").strip(),
        str(getattr(bot_config, "telegram_account", "") or "").strip(),
    }
    for platform in getattr(bot_config, "platforms", []) or []:
        platform_account = str(
            getattr(platform, "account", "") or getattr(platform, "id", "") or ""
        ).strip()
        if platform_account:
            accounts.add(platform_account)
    # Drop the empty-string placeholder before membership testing.
    accounts.discard("")
    return sender_id in accounts
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
# """
# 构建包含中心消息上下文的段落前3条+后3条使用标准的 readable builder 输出
# """
# if not messages or center_index < 0 or center_index >= len(messages):
# return None
# context_start = max(0, center_index - 3)
# context_end = min(len(messages), center_index + 1 + 3)
# context_messages = messages[context_start:context_end]
# if not context_messages:
# return None
# try:
# paragraph = build_readable_messages(
# messages=context_messages,
# replace_bot_name=True,
# timestamp_mode="relative",
# read_mark=0.0,
# truncate=False,
# show_actions=False,
# show_pic=True,
# message_id_list=None,
# remove_emoji_stickers=False,
# pic_single=True,
# )
# except Exception as e:
# logger.warning(f"构建上下文段落失败: {e}")
# return None
# paragraph = paragraph.strip()
# return paragraph or None
# def is_bot_message(msg: Any) -> bool:
# """判断消息是否来自机器人自身"""
# if msg is None:
# return False
# bot_config = getattr(global_config, "bot", None)
# if not bot_config:
# return False
# platform = (
# str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
# .strip()
# .lower()
# )
# user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
# if not platform or not user_id:
# return False
# platform_accounts = {}
# try:
# platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
# except Exception:
# platform_accounts = {}
# bot_accounts: Dict[str, str] = {}
# qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
# if qq_account:
# bot_accounts["qq"] = qq_account
# telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
# if telegram_account:
# bot_accounts["telegram"] = telegram_account
# for plat, account in platform_accounts.items():
# if account and plat not in bot_accounts:
# bot_accounts[plat] = account
# bot_account = bot_accounts.get(platform)
# return bool(bot_account and user_id == bot_account)
# def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
# """
# 解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表。
# 期望的 JSON 结构:
# [
# {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
# {"content": "词条", "source_id": "12"}, // 黑话
# ...
# ]
# Returns:
# Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
# 第一个列表是表达方式 (situation, style, source_id)
# 第二个列表是黑话 (content, source_id)
# """
# if not response:
# return [], []
# raw = response.strip()
# # 尝试提取 ```json 代码块
# json_block_pattern = r"```json\s*(.*?)\s*```"
# match = re.search(json_block_pattern, raw, re.DOTALL)
# if match:
# raw = match.group(1).strip()
# else:
# # 去掉可能存在的通用 ``` 包裹
# raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
# raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
# raw = raw.strip()
# parsed = None
# expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
# jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
# try:
# # 优先尝试直接解析
# if raw.startswith("[") and raw.endswith("]"):
# parsed = json.loads(raw)
# else:
# repaired = repair_json(raw)
# if isinstance(repaired, str):
# parsed = json.loads(repaired)
# else:
# parsed = repaired
# except Exception as parse_error:
# # 如果解析失败,尝试修复中文引号问题
# # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
# try:
# def fix_chinese_quotes_in_json(text):
# """使用状态机修复 JSON 字符串值中的中文引号"""
# result = []
# i = 0
# in_string = False
# escape_next = False
# while i < len(text):
# char = text[i]
# if escape_next:
# # 当前字符是转义字符后的字符,直接添加
# result.append(char)
# escape_next = False
# i += 1
# continue
# if char == "\\":
# # 转义字符
# result.append(char)
# escape_next = True
# i += 1
# continue
# if char == '"' and not escape_next:
# # 遇到英文引号,切换字符串状态
# in_string = not in_string
# result.append(char)
# i += 1
# continue
# if in_string:
# # 在字符串值内部,将中文引号替换为转义的英文引号
# if char == '"': # 中文左引号 U+201C
# result.append('\\"')
# elif char == '"': # 中文右引号 U+201D
# result.append('\\"')
# else:
# result.append(char)
# else:
# # 不在字符串内,直接添加
# result.append(char)
# i += 1
# return "".join(result)
# fixed_raw = fix_chinese_quotes_in_json(raw)
# # 再次尝试解析
# if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
# parsed = json.loads(fixed_raw)
# else:
# repaired = repair_json(fixed_raw)
# if isinstance(repaired, str):
# parsed = json.loads(repaired)
# else:
# parsed = repaired
# except Exception as fix_error:
# logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
# logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
# logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
# logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
# return [], []
# if isinstance(parsed, dict):
# parsed_list = [parsed]
# elif isinstance(parsed, list):
# parsed_list = parsed
# else:
# logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
# return [], []
# for item in parsed_list:
# if not isinstance(item, dict):
# continue
# # 检查是否是表达方式条目(有 situation 和 style
# situation = str(item.get("situation", "")).strip()
# style = str(item.get("style", "")).strip()
# source_id = str(item.get("source_id", "")).strip()
# if situation and style and source_id:
# # 表达方式条目
# expressions.append((situation, style, source_id))
# elif item.get("content"):
# # 黑话条目(有 content 字段)
# content = str(item.get("content", "")).strip()
# source_id = str(item.get("source_id", "")).strip()
# if content and source_id:
# jargon_entries.append((content, source_id))
# return expressions, jargon_entries