merge: sync upstream/r-dev and resolve real conflicts
This commit is contained in:
250
src/learners/expression_auto_check_task.py
Normal file
250
src/learners/expression_auto_check_task.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
表达方式自动检查定时任务
|
||||
|
||||
功能:
|
||||
1. 定期随机选取指定数量的表达方式
|
||||
2. 使用 LLM 进行评估
|
||||
3. 通过评估的:rejected=0, checked=1
|
||||
4. 未通过评估的:rejected=1, checked=1
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
from sqlmodel import select
|
||||
|
||||
from src.learners.expression_review_store import get_review_state, set_review_state
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
|
||||
logger = get_logger("expression_auto_check_task")
|
||||
|
||||
|
||||
def create_evaluation_prompt(situation: str, style: str) -> str:
    """Build the LLM prompt used to evaluate one expression.

    Args:
        situation: Usage condition / scenario of the expression.
        style: The expression / speaking-style text.

    Returns:
        The complete evaluation prompt string.
    """
    # Built-in evaluation criteria; config may append extra ones below.
    base_criteria = [
        "表达方式或言语风格 是否与使用条件或使用情景 匹配",
        "允许部分语法错误或口头化或缺省出现",
        "表达方式不能太过特指,需要具有泛用性",
        "一般不涉及具体的人名或名称",
    ]

    # User-defined criteria from configuration.
    extra_criteria = global_config.expression.expression_auto_check_custom_criteria

    # Merge built-in and configured criteria, keeping order.
    merged_criteria = list(base_criteria)
    if extra_criteria:
        merged_criteria.extend(extra_criteria)

    # Render as a 1-based numbered list.
    criteria_list = "\n".join(
        f"{index + 1}. {criterion}" for index, criterion in enumerate(merged_criteria)
    )

    prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景:{situation}
表达方式或言语风格:{style}

请从以下方面进行评估:
{criteria_list}

请以JSON格式输出评估结果:
{{
"suitable": true/false,
"reason": "评估理由(如果不合适,请说明原因)"

}}
如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。
请严格按照JSON格式输出,不要包含其他内容。"""

    return prompt
||||
|
||||
|
||||
# Shared module-level LLM client for expression review; reuses the tool-use model set.
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
|
||||
|
||||
|
||||
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
    """Run one LLM evaluation of an expression.

    Args:
        situation: Usage condition / scenario of the expression.
        style: The expression / speaking-style text.

    Returns:
        A ``(suitable, reason, error)`` tuple. On failure ``suitable`` is
        False and ``error`` carries the error message; otherwise ``error``
        is None.
    """
    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")

        response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )

        logger.debug(f"LLM响应: {response}")

        # Parse the JSON verdict; fall back to a regex scan in case the
        # model wrapped the JSON object in extra prose.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as decode_err:
            import re

            extracted = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if not extracted:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from decode_err
            evaluation = json.loads(extracted.group())

        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")

        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None

    except Exception as exc:
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {exc}")
        return False, f"评估过程出错: {str(exc)}", str(exc)
||||
|
||||
|
||||
class ExpressionAutoCheckTask(AsyncTask):
    """Periodic task that auto-reviews learned expressions with an LLM.

    Workflow:
        1. Periodically pick a random batch of never-reviewed expressions.
        2. Evaluate each one with the LLM.
        3. Passing ones are stored as rejected=False, checked=True;
           failing ones as rejected=True, checked=True.
    """

    def __init__(self):
        # Check interval comes from config; batch size is re-read each run.
        check_interval = global_config.expression.expression_auto_check_interval
        super().__init__(
            task_name="Expression Auto Check Task",
            wait_before_start=60,  # wait 60s after startup before the first run
            run_interval=check_interval,
        )

    async def _select_expressions(self, count: int) -> List[Expression]:
        """Randomly pick up to ``count`` expressions that were never reviewed.

        Args:
            count: Maximum number of expressions to select.

        Returns:
            The selected expressions; empty on error or when nothing is pending.
        """
        try:
            # Read-only query: auto_commit=False so leaving the context does
            # not commit and expire the ORM instances.
            with get_db_session(auto_commit=False) as session:
                statement = select(Expression)
                all_expressions = session.exec(statement).all()

            # Review state lives in the local store, so filtering must happen
            # in Python rather than in SQL.
            unevaluated_expressions = [expr for expr in all_expressions if not get_review_state(expr.id)["checked"]]

            if not unevaluated_expressions:
                logger.info("没有未检查的表达方式")
                return []

            # Sample without replacement, capped by the available amount.
            selected_count = min(count, len(unevaluated_expressions))
            selected = random.sample(unevaluated_expressions, selected_count)

            logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条")
            return selected

        except Exception as e:
            logger.error(f"选择表达方式时出错: {e}")
            return []

    async def _evaluate_expression(self, expression: Expression) -> bool:
        """Evaluate one expression and persist the review verdict.

        Args:
            expression: The expression to evaluate.

        Returns:
            True when the expression passed review, False otherwise
            (including when persisting the verdict fails).
        """
        suitable, reason, error = await single_expression_check(
            expression.situation,
            expression.style,
        )

        # Persist verdict: checked=True, rejected = not suitable.
        try:
            set_review_state(expression.id, True, not suitable, "ai")

            status = "通过" if suitable else "不通过"
            # FIX: truncate situation/style to 30 chars — the original logged
            # the full text followed by a literal "...", which misleadingly
            # implied truncation; now consistent with expression_learner logs.
            logger.info(
                f"表达方式评估完成 [ID: {expression.id}] - {status} | "
                f"Situation: {expression.situation[:30]}... | "
                f"Style: {expression.style[:30]}... | "
                f"Reason: {reason[:50]}..."
            )

            if error:
                logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")

            return suitable

        except Exception as e:
            logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}")
            return False

    async def run(self):
        """Execute one check cycle: select a batch and evaluate each item."""
        try:
            # Feature switch: skip silently when self-reflection is disabled.
            if not global_config.expression.expression_self_reflect:
                logger.debug("表达方式自动检查未启用,跳过本次执行")
                return

            check_count = global_config.expression.expression_auto_check_count
            if check_count <= 0:
                logger.warning(f"检查数量配置无效: {check_count},跳过本次执行")
                return

            logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")

            expressions = await self._select_expressions(check_count)

            if not expressions:
                logger.info("没有需要检查的表达方式")
                return

            passed_count = 0
            failed_count = 0

            for i, expression in enumerate(expressions, 1):
                logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")

                if await self._evaluate_expression(expression):
                    passed_count += 1
                else:
                    failed_count += 1

                # Throttle LLM requests between items.
                await asyncio.sleep(0.3)

            logger.info(
                f"表达方式自动检查完成: 总计 {len(expressions)} 条,通过 {passed_count} 条,不通过 {failed_count} 条"
            )

        except Exception as e:
            logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)
||||
504
src/learners/expression_learner.py
Normal file
504
src/learners/expression_learner.py
Normal file
@@ -0,0 +1,504 @@
|
||||
from datetime import datetime
|
||||
from sqlmodel import select
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import asyncio
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.data_models.expression_data_model import MaiExpression
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
from .expression_utils import check_expression_suitability, parse_expression_response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from .jargon_miner import JargonMiner, JargonEntry
|
||||
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
# TODO: replace with the new model-invocation API once the LLM refactor lands
# Extracts expressions / jargon from a chat transcript.
express_learn_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="expression.learner")
# Summarizes multiple learned situations into one short description.
summary_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.summary")
# Reviews a (situation, style) pair for suitability.
check_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression.check")
||||
|
||||
|
||||
class ExpressionLearner:
    """Learns expression styles (and jargon) from one session's chat messages."""

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id

        # Lock preventing concurrent learning runs.
        self._learning_lock = asyncio.Lock()

        # Messages waiting to be learned from.
        self._messages_cache: List["SessionMessage"] = []

    def add_messages(self, messages: List["SessionMessage"]) -> None:
        """Append messages to the learning cache."""
        self._messages_cache.extend(messages)

    def get_cache_size(self) -> int:
        """Return the current size of the message cache."""
        return len(self._messages_cache)
||||
|
||||
async def learn(self, jargon_miner: Optional["JargonMiner"] = None):
|
||||
"""学习主流程"""
|
||||
if not self._messages_cache:
|
||||
logger.debug("没有消息可供学习,跳过学习过程")
|
||||
return
|
||||
|
||||
# 构建可读消息
|
||||
readable_message, _, _ = await MessageUtils.build_readable_message(
|
||||
self._messages_cache,
|
||||
anonymize=True,
|
||||
show_lineno=True,
|
||||
extract_pictures=True,
|
||||
replace_bot_name=True,
|
||||
target_bot_name="SELF",
|
||||
)
|
||||
|
||||
# 准备提示词
|
||||
prompt_template = prompt_manager.get_prompt("learn_style")
|
||||
prompt_template.add_context("bot_name", global_config.bot.nickname)
|
||||
prompt_template.add_context("chat_str", readable_message)
|
||||
prompt = await prompt_manager.render_prompt(prompt_template)
|
||||
|
||||
# 调用 LLM 学习表达方式
|
||||
try:
|
||||
response, _ = await express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错:{e}")
|
||||
return
|
||||
|
||||
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]]
|
||||
jargon_entries: List[Tuple[str, str]] # (content, source_id)
|
||||
expressions, jargon_entries = parse_expression_response(response)
|
||||
|
||||
# 从缓存中检查 jargon 是否出现在 messages 中
|
||||
if cached_jargon_entries := self._check_cached_jargons_in_messages(jargon_miner):
|
||||
# 合并缓存中的 jargon 条目(去重:如果 content 已存在则跳过)
|
||||
existing_contents = {content for content, _ in jargon_entries}
|
||||
for content, source_id in cached_jargon_entries:
|
||||
if content not in existing_contents:
|
||||
jargon_entries.append((content, source_id))
|
||||
existing_contents.add(content)
|
||||
logger.info(f"从缓存中检查到黑话:{content}")
|
||||
|
||||
# 检查表达方式数量,如果超过 20 个则放弃本次表达学习
|
||||
if len(expressions) > 20:
|
||||
logger.info(f"表达方式提取数量超过 20 个(实际{len(expressions)}个),放弃本次表达学习")
|
||||
expressions = []
|
||||
|
||||
# 检查黑话数量,如果超过 30 个则放弃本次黑话学习
|
||||
if len(jargon_entries) > 30:
|
||||
logger.info(f"黑话提取数量超过 30 个(实际{len(jargon_entries)}个),放弃本次黑话学习")
|
||||
jargon_entries = []
|
||||
|
||||
# 处理黑话条目,路由到 jargon_miner(即使没有表达方式也要处理黑话)
|
||||
# TODO: 检测是否开启了
|
||||
if jargon_entries:
|
||||
await self._process_jargon_entries(jargon_entries, jargon_miner)
|
||||
|
||||
# 如果没有表达方式,直接返回
|
||||
if not expressions:
|
||||
logger.info("解析后没有可用的表达方式")
|
||||
return
|
||||
|
||||
logger.info(f"学习的 expressions: {expressions}")
|
||||
logger.info(f"学习的 jargon_entries: {jargon_entries}")
|
||||
|
||||
# 过滤表达方式,根据 source_id 溯源并应用各种过滤规则
|
||||
learnt_expressions = self._filter_expressions(expressions)
|
||||
|
||||
if not learnt_expressions:
|
||||
logger.info("没有学习到表达风格")
|
||||
return
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = "\n".join(f"{situation}->{style}" for situation, style in learnt_expressions)
|
||||
logger.info(f"在 {self.session_id} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for situation, style in learnt_expressions:
|
||||
await self._upsert_expression_to_db(situation, style)
|
||||
|
||||
# ====== 黑话相关 ======
|
||||
def _check_cached_jargons_in_messages(self, jargon_miner: Optional["JargonMiner"] = None) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
检查缓存中的 jargon 是否出现在 messages 中
|
||||
|
||||
Args:
|
||||
jargon_miner: JargonMiner 实例,用于获取缓存的黑话
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 匹配到的黑话条目列表,每个元素是 (content, source_id)
|
||||
"""
|
||||
if not jargon_miner:
|
||||
return []
|
||||
|
||||
# 获取缓存的所有 jargon 实例
|
||||
cached_jargons = jargon_miner.get_cached_jargons()
|
||||
if not cached_jargons:
|
||||
return []
|
||||
|
||||
matched_entries: List[Tuple[str, str]] = []
|
||||
|
||||
for i, msg in enumerate(self._messages_cache):
|
||||
# 跳过机器人自己的消息
|
||||
if is_bot_self(msg.platform, msg.message_info.user_info.user_id):
|
||||
continue
|
||||
|
||||
# 获取消息文本
|
||||
msg_text = (msg.processed_plain_text or "").strip()
|
||||
|
||||
if not msg_text:
|
||||
continue
|
||||
|
||||
# 检查每个缓存中的 jargon 是否出现在消息文本中
|
||||
for jargon in cached_jargons:
|
||||
if not jargon or not jargon.strip():
|
||||
continue
|
||||
|
||||
jargon_content = jargon.strip()
|
||||
|
||||
# 使用正则匹配,考虑单词边界(类似 jargon_explainer 中的逻辑)
|
||||
pattern = re.escape(jargon_content)
|
||||
# 对于中文,使用更宽松的匹配;对于英文/数字,使用单词边界
|
||||
if re.search(r"[\u4e00-\u9fff]", jargon_content):
|
||||
# 包含中文,使用更宽松的匹配
|
||||
search_pattern = pattern
|
||||
else:
|
||||
# 纯英文/数字,使用单词边界
|
||||
search_pattern = r"\b" + pattern + r"\b"
|
||||
|
||||
if re.search(search_pattern, msg_text, re.IGNORECASE):
|
||||
# 找到匹配,构建条目(source_id 从 1 开始,因为 build_readable_message 的编号从 1 开始)
|
||||
source_id = str(i + 1)
|
||||
matched_entries.append((jargon_content, source_id))
|
||||
|
||||
return matched_entries
|
||||
|
||||
async def _process_jargon_entries(
|
||||
self, jargon_entries: List[Tuple[str, str]], jargon_miner: Optional["JargonMiner"] = None
|
||||
):
|
||||
"""
|
||||
处理从 expression learner 提取的黑话条目,路由到 jargon_miner
|
||||
|
||||
Args:
|
||||
jargon_entries: 黑话条目列表,每个元素是 (content, source_id)
|
||||
jargon_miner: JargonMiner 实例
|
||||
"""
|
||||
if not jargon_entries or not self._messages_cache:
|
||||
return
|
||||
|
||||
if not jargon_miner:
|
||||
logger.warning("缺少 JargonMiner 实例,无法处理黑话条目")
|
||||
return
|
||||
|
||||
# 构建黑话条目格式
|
||||
entries: List["JargonEntry"] = []
|
||||
|
||||
for content, source_id in jargon_entries:
|
||||
content = content.strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 过滤掉包含 SELF 的黑话,不学习
|
||||
if "SELF" in content:
|
||||
logger.info(f"跳过包含 SELF 的黑话:{content}")
|
||||
continue
|
||||
|
||||
# TODO: 多平台兼容
|
||||
# 检查是否包含机器人名称
|
||||
bot_nickname = global_config.bot.nickname
|
||||
if bot_nickname and bot_nickname in content:
|
||||
logger.info(f"跳过包含机器人昵称的黑话:{content}")
|
||||
continue
|
||||
|
||||
# 解析 source_id
|
||||
if not source_id.isdigit():
|
||||
logger.warning(f"黑话条目 source_id 无效:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
# build_readable_message 的编号从 1 开始
|
||||
line_index = int(source_id) - 1
|
||||
if line_index < 0 or line_index >= len(self._messages_cache):
|
||||
logger.warning(f"黑话条目 source_id 超出范围:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
# 检查是否是机器人自己的消息
|
||||
target_msg = self._messages_cache[line_index]
|
||||
if is_bot_self(target_msg.platform, target_msg.message_info.user_info.user_id):
|
||||
logger.info(f"跳过引用机器人自身消息的黑话:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
# 构建上下文段落(取前后各 3 条消息)
|
||||
start_idx = max(0, line_index - 3)
|
||||
end_idx = min(len(self._messages_cache), line_index + 4)
|
||||
context_msgs = self._messages_cache[start_idx:end_idx]
|
||||
|
||||
context_paragraph = "\n".join(
|
||||
[f"[{i + 1}] {msg.processed_plain_text or ''}" for i, msg in enumerate(context_msgs)]
|
||||
)
|
||||
|
||||
if not context_paragraph:
|
||||
logger.warning(f"黑话条目上下文为空:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
entries.append({"content": content, "raw_content": {context_paragraph}}) # type: ignore
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
await jargon_miner.process_extracted_entries(entries)
|
||||
logger.info(f"成功处理 {len(entries)} 个黑话条目")
|
||||
|
||||
# ====== 过滤方法 ======
|
||||
def _filter_expressions(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str]]:
|
||||
"""
|
||||
过滤表达方式,移除不符合条件的条目
|
||||
|
||||
Args:
|
||||
expressions: 表达方式列表,每个元素是 (situation, style, source_id)
|
||||
|
||||
Returns:
|
||||
过滤后的表达方式列表,每个元素是 (situation, style)
|
||||
"""
|
||||
filtered_expressions: List[Tuple[str, str]] = []
|
||||
|
||||
# 准备机器人名称集合(用于过滤 style 与机器人名称重复的表达)
|
||||
# TODO: 完善这里的机器人名称检测逻辑(考虑别名、不同平台的名称等)
|
||||
banned_names: set[str] = set()
|
||||
bot_nickname = global_config.bot.nickname
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
if alias_stripped := alias.strip():
|
||||
banned_names.add(alias_stripped)
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
for situation, style, source_id in expressions:
|
||||
source_id_str = source_id.strip()
|
||||
if not source_id_str.isdigit():
|
||||
continue # 无效的来源行编号,跳过
|
||||
line_index = int(source_id_str) - 1 # build_readable_message 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(self._messages_cache):
|
||||
continue # 超出范围,跳过
|
||||
# 当前行的原始消息
|
||||
current_msg = self._messages_cache[line_index]
|
||||
# 过滤掉从 bot 自己发言中提取到的表达方式
|
||||
if is_bot_self(current_msg.platform, current_msg.message_info.user_info.user_id):
|
||||
continue
|
||||
# 过滤掉无上下文的表达方式
|
||||
context = (current_msg.processed_plain_text or "").strip()
|
||||
if not context:
|
||||
continue
|
||||
# 过滤掉包含 SELF 的内容(不学习)
|
||||
if "SELF" in situation or "SELF" in style or "SELF" in context:
|
||||
logger.info(f"跳过包含 SELF 的表达方式:situation={situation}, style={style}, source_id={source_id}")
|
||||
continue
|
||||
# 过滤掉 style 与机器人名称/昵称重复的表达
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() in banned_casefold:
|
||||
logger.debug(
|
||||
f"跳过 style 与机器人名称重复的表达方式:situation={situation}, style={style}, source_id={source_id}"
|
||||
)
|
||||
continue
|
||||
# 过滤掉包含 "[表情" 的内容
|
||||
if "[表情包" in situation or "[表情包" in style or "[表情包" in context:
|
||||
logger.info(f"跳过包含表情标记的表达方式:situation={situation}, style={style}, source_id={source_id}")
|
||||
continue
|
||||
# 过滤掉包含 "[图片" 的内容
|
||||
if "[图片" in situation or "[图片" in style or "[图片" in context:
|
||||
logger.info(f"跳过包含图片标记的表达方式:situation={situation}, style={style}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
filtered_expressions.append((situation, style))
|
||||
|
||||
return filtered_expressions
|
||||
|
||||
# ====== DB 操作相关 ======
|
||||
async def _upsert_expression_to_db(self, situation: str, style: str) -> None:
|
||||
"""将表达方式写入数据库,存在时更新,不存在时新增。
|
||||
|
||||
Args:
|
||||
situation: 表达方式对应的使用情景。
|
||||
style: 表达方式风格。
|
||||
"""
|
||||
expr, similarity = self._find_similar_expression(situation) or (None, 0)
|
||||
if expr:
|
||||
# 根据相似度决定是否使用 LLM 总结
|
||||
# 完全匹配(相似度 == 1.0)时不总结,相似匹配时总结
|
||||
use_llm_summary = similarity < 1.0
|
||||
await self._update_existing_expression(expr, situation, use_llm_summary=use_llm_summary)
|
||||
return
|
||||
# 没有找到匹配的记录,创建新记录
|
||||
self._create_expression(situation, style)
|
||||
|
||||
def _create_expression(self, situation: str, style: str) -> None:
|
||||
"""创建新的表达方式记录。
|
||||
|
||||
Args:
|
||||
situation: 表达方式对应的使用情景。
|
||||
style: 表达方式风格。
|
||||
"""
|
||||
content_list = [situation]
|
||||
try:
|
||||
with get_db_session() as db:
|
||||
new_expr = Expression(
|
||||
situation=situation,
|
||||
style=style,
|
||||
content_list=json.dumps(content_list),
|
||||
count=1,
|
||||
session_id=self.session_id,
|
||||
last_active_time=datetime.now(),
|
||||
)
|
||||
db.add(new_expr)
|
||||
db.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"创建表达方式失败: {e}")
|
||||
|
||||
async def _update_existing_expression(self, expr: "MaiExpression", situation: str, use_llm_summary: bool = True):
|
||||
expr.content.append(situation)
|
||||
expr.count += 1
|
||||
expr.checked = False # count 增加时重置 checked 为 False
|
||||
expr.last_active_time = datetime.now()
|
||||
|
||||
if use_llm_summary:
|
||||
# 相似匹配时,使用 LLM 重新组合 situation
|
||||
new_situation = await self._compose_situation_text(expr.content)
|
||||
if new_situation:
|
||||
expr.situation = new_situation
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
if expr.item_id is None:
|
||||
raise ValueError("表达方式对象缺少 item_id,无法更新数据库记录")
|
||||
statement = select(Expression).filter_by(id=expr.item_id).limit(1)
|
||||
if db_expr := session.exec(statement).first():
|
||||
db_expr.content_list = json.dumps(expr.content)
|
||||
db_expr.count = expr.count
|
||||
db_expr.checked = expr.checked
|
||||
db_expr.last_active_time = expr.last_active_time
|
||||
db_expr.situation = expr.situation # 更新 situation
|
||||
session.add(db_expr)
|
||||
else:
|
||||
logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新")
|
||||
except Exception as e:
|
||||
logger.error(f"更新表达方式失败: {e}")
|
||||
|
||||
# count 增加后,立即进行一次检查
|
||||
await self._check_expression(expr)
|
||||
|
||||
# ====== 概括方法 ======
|
||||
async def _compose_situation_text(self, content_list: List[str]) -> Optional[str]:
|
||||
texts = [c.strip() for c in content_list if c.strip()]
|
||||
if not texts:
|
||||
return None
|
||||
description = "\n".join(f"- {s}" for s in texts[-10:]) # 只取最近10条进行概括
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,长度不超过20个字,保留共同特点:\n"
|
||||
f"{description}\n"
|
||||
"只输出概括内容。"
|
||||
)
|
||||
try:
|
||||
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
if summary := summary.strip():
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"使用 LLM 生成表达方式概括失败: {e}")
|
||||
return None
|
||||
|
||||
async def _check_expression(self, expr: "MaiExpression"):
|
||||
"""
|
||||
检查表达方式(在 count 增加后调用)
|
||||
|
||||
Args:
|
||||
expr (MaiExpression): 要检查的表达方式对象
|
||||
"""
|
||||
if not global_config.expression.expression_self_reflect:
|
||||
logger.debug("表达方式自我反思功能未启用,跳过检查")
|
||||
return
|
||||
|
||||
suitable, reason, error = await check_expression_suitability(expr.situation, expr.style)
|
||||
if error:
|
||||
logger.error(f"检查表达方式时发生错误: {error}")
|
||||
return
|
||||
expr.checked = True
|
||||
expr.rejected = not suitable
|
||||
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(Expression).filter_by(id=expr.item_id).limit(1)
|
||||
if db_expr := session.exec(statement).first():
|
||||
db_expr.checked = expr.checked
|
||||
db_expr.rejected = expr.rejected
|
||||
session.add(db_expr)
|
||||
else:
|
||||
logger.warning(f"表达方式 ID {expr.item_id} 在数据库中未找到,无法更新检查结果")
|
||||
except Exception as e:
|
||||
logger.error(f"更新表达方式检查结果失败: {e}")
|
||||
|
||||
status = "通过" if suitable else "不通过"
|
||||
logger.info(
|
||||
f"表达方式检查完成 [ID: {expr.item_id}] - {status} | "
|
||||
f"Situation: {expr.situation[:30]}... | "
|
||||
f"Style: {expr.style[:30]}... | "
|
||||
f"Reason: {reason[:50] if reason else '无'}..."
|
||||
)
|
||||
|
||||
def _find_similar_expression(
|
||||
self, situation: str, similarity_threshold: float = 0.75
|
||||
) -> Optional[Tuple[MaiExpression, float]]:
|
||||
"""在数据库中查找相似的表达方式。
|
||||
|
||||
Args:
|
||||
situation: 当前待匹配的情景描述。
|
||||
similarity_threshold: 认定为相似表达方式的最低相似度阈值。
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[MaiExpression, float]]: 若找到最相似的表达方式,则返回
|
||||
``(表达方式对象, 相似度)``;否则返回 ``None``。
|
||||
"""
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(Expression).filter_by(session_id=self.session_id)
|
||||
expressions = session.exec(statement).all()
|
||||
|
||||
best_match: Optional[MaiExpression] = None
|
||||
best_similarity = 0.0
|
||||
|
||||
for db_expression in expressions:
|
||||
expression = MaiExpression.from_db_instance(db_expression)
|
||||
candidate_situations = [expression.situation, *expression.content]
|
||||
for candidate_situation in candidate_situations:
|
||||
normalized_candidate_situation = candidate_situation.strip()
|
||||
if not normalized_candidate_situation:
|
||||
continue
|
||||
similarity = difflib.SequenceMatcher(
|
||||
None,
|
||||
situation,
|
||||
normalized_candidate_situation,
|
||||
).ratio()
|
||||
if similarity > similarity_threshold and similarity > best_similarity:
|
||||
best_similarity = similarity
|
||||
best_match = expression
|
||||
|
||||
if best_match:
|
||||
logger.debug(f"找到相似表达方式情景 [ID: {best_match.item_id}],相似度: {best_similarity:.2f}")
|
||||
return best_match, best_similarity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查找相似表达方式失败: {e}")
|
||||
return None
|
||||
35
src/learners/expression_review_store.py
Normal file
35
src/learners/expression_review_store.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
||||
def _review_key(expression_id: int) -> str:
|
||||
return f"expression_review:{expression_id}"
|
||||
|
||||
|
||||
def get_review_state(expression_id: Optional[int]) -> Dict[str, Any]:
    """Return the stored review state for an expression.

    Args:
        expression_id: Database id of the expression; ``None`` yields the
            default (unreviewed) state.

    Returns:
        Dict with ``checked`` (bool), ``rejected`` (bool) and
        ``modified_by`` (str | None). Missing or malformed entries fall
        back to the default state.
    """
    default_state = {"checked": False, "rejected": False, "modified_by": None}
    if expression_id is None:
        return default_state
    try:
        value = local_storage[_review_key(expression_id)]
    except KeyError:
        # Robustness fix: treat a missing entry as "never reviewed" instead
        # of letting a KeyError from the storage backend propagate.
        # NOTE(review): if local_storage's __getitem__ already returns None
        # for missing keys, this guard is a harmless no-op — confirm backend.
        return default_state
    if isinstance(value, dict):
        return {
            "checked": bool(value.get("checked", False)),
            "rejected": bool(value.get("rejected", False)),
            "modified_by": value.get("modified_by"),
        }
    return default_state
|
||||
|
||||
|
||||
def set_review_state(
    expression_id: Optional[int],
    checked: bool,
    rejected: bool,
    modified_by: Optional[str],
) -> None:
    """Persist the review verdict for an expression.

    No-op when ``expression_id`` is ``None``.

    Args:
        expression_id: Database id of the expression, or None.
        checked: Whether the expression has been reviewed.
        rejected: Whether the expression failed review.
        modified_by: Who set this state (e.g. "ai"), or None.
    """
    if expression_id is None:
        return
    state = {
        "checked": checked,
        "rejected": rejected,
        "modified_by": modified_by,
    }
    local_storage[_review_key(expression_id)] = state
|
||||
456
src/learners/expression_selector.py
Normal file
456
src/learners/expression_selector.py
Normal file
@@ -0,0 +1,456 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.learner_utils_old import weighted_sample
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
    def __init__(self):
        # LLM used to pick which expressions fit the current chat context.
        self.llm_model = LLMRequest(
            model_set=model_config.model_task_config.tool_use, request_type="expression.selector"
        )
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
    @staticmethod
    def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
        """Parse a 'platform:id:type' config string into a chat_id via ChatManager's interface."""
        # NOTE(review): SessionUtils does not appear among this file's visible
        # imports — confirm it is imported elsewhere; otherwise this raises
        # NameError, which the broad except below silently turns into None
        # for every parse.
        try:
            parts = stream_config_str.split(":")
            if len(parts) != 3:
                return None
            platform = parts[0]
            id_str = parts[1]
            stream_type = parts[2]
            is_group = stream_type == "group"
            # group streams pass group_id; private streams pass user_id.
            return SessionUtils.calculate_session_id(
                platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
            )
        except Exception:
            # Best-effort parsing: malformed configs yield None.
            return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def _select_expressions_simple(self, chat_id: str, max_num: int) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
简单模式:只选择 count > 1 的项目,要求至少有10个才进行选择,随机选5个,不进行LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
max_num: 最大选择数量(此参数在此模式下不使用,固定选择5个)
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 查询所有相关chat_id的表达方式,排除 rejected=1 的,且只选择 count > 1 的
|
||||
# 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的
|
||||
base_conditions = (
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
|
||||
)
|
||||
if global_config.expression.expression_checked_only:
|
||||
base_conditions = base_conditions & (Expression.checked)
|
||||
style_query = Expression.select().where(base_conditions)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 要求至少有一定数量的 count > 1 的表达方式才进行“完整简单模式”选择
|
||||
min_required = 8
|
||||
if len(style_exprs) < min_required:
|
||||
# 高 count 样本不足:如果还有候选,就降级为随机选 3 个;如果一个都没有,则直接返回空
|
||||
if not style_exprs:
|
||||
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
|
||||
# 完全没有高 count 样本时,退化为全量随机抽样(不进入LLM流程)
|
||||
fallback_num = min(3, max_num) if max_num > 0 else 3
|
||||
if fallback_selected := self._random_expressions(chat_id, fallback_num):
|
||||
self.update_expressions_last_active_time(fallback_selected)
|
||||
selected_ids = [expr["id"] for expr in fallback_selected]
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} 使用简单模式降级随机抽选 {len(fallback_selected)} 个表达(无 count>1 样本)"
|
||||
)
|
||||
return fallback_selected, selected_ids
|
||||
return [], []
|
||||
logger.info(
|
||||
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),"
|
||||
f"简单模式降级为随机选择 3 个"
|
||||
)
|
||||
select_count = min(3, len(style_exprs))
|
||||
else:
|
||||
# 高 count 数量达标时,固定选择 5 个
|
||||
select_count = 5
|
||||
import random
|
||||
|
||||
selected_style = random.sample(style_exprs, select_count)
|
||||
|
||||
# 更新last_active_time
|
||||
if selected_style:
|
||||
self.update_expressions_last_active_time(selected_style)
|
||||
|
||||
selected_ids = [expr["id"] for expr in selected_style]
|
||||
logger.debug(
|
||||
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)} 个"
|
||||
)
|
||||
return selected_style, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"简单模式选择表达方式失败: {e}")
|
||||
return [], []
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
total_num: 需要选择的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 随机选择的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
# 如果 expression_checked_only 为 True,则只选择 checked=True 且 rejected=False 的
|
||||
base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
if global_config.expression.expression_checked_only:
|
||||
base_conditions = base_conditions & (Expression.checked)
|
||||
style_query = Expression.select().where(base_conditions)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 随机抽样
|
||||
return weighted_sample(style_exprs, total_num) if style_exprs else []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
return []
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
think_level: int = 1,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
think_level: 思考级别,0/1
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式,think_level={think_level}")
|
||||
return await self._select_expressions_classic(
|
||||
chat_id, chat_info, max_num, target_message, reply_reason, think_level
|
||||
)
|
||||
|
||||
    async def _select_expressions_classic(
        self,
        chat_id: str,
        chat_info: str,
        max_num: int = 10,
        target_message: Optional[str] = None,
        reply_reason: Optional[str] = None,
        think_level: int = 1,
    ) -> Tuple[List[Dict[str, Any]], List[int]]:
        """Classic selection mode: weighted random pre-selection followed by an LLM pick.

        Args:
            chat_id: chat stream ID
            chat_info: chat content info
            max_num: maximum number of expressions to pick
            target_message: message being replied to, if any
            reply_reason: planner-provided reason for replying
            think_level: thinking level, 0 or 1

        Returns:
            Tuple[List[Dict[str, Any]], List[int]]: chosen expressions and their IDs;
            ([], []) on any failure or when the candidate pool is too small.
        """
        try:
            # think_level == 0: only count > 1 items, random pick, no LLM involvement.
            if think_level == 0:
                return self._select_expressions_simple(chat_id, max_num)

            # think_level == 1: pick high-count expressions first, then sample from all.
            # 1. Fetch all expressions and split into count > 1 vs count <= 1.
            related_chat_ids = self.get_related_chat_ids(chat_id)
            # When expression_checked_only is on, restrict to checked & non-rejected rows.
            base_conditions = (Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
            if global_config.expression.expression_checked_only:
                base_conditions = base_conditions & (Expression.checked)
            style_query = Expression.select().where(base_conditions)

            all_style_exprs = [
                {
                    "id": expr.id,
                    "situation": expr.situation,
                    "style": expr.style,
                    "last_active_time": expr.last_active_time,
                    "source_id": expr.chat_id,
                    "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
                    "count": expr.count if getattr(expr, "count", None) is not None else 1,
                    "checked": expr.checked if getattr(expr, "checked", None) is not None else False,
                }
                for expr in style_query
            ]

            # Split off the expressions seen more than once.
            high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]

            # Thresholds for think_level 1 (level 0 already returned above).
            min_high_count = 10
            min_total_count = 10
            select_high_count = 5
            select_random_count = 5

            # Quantity checks.
            # Not enough high-count expressions: skip the high-count priority pick
            # instead of aborting outright.
            if len(high_count_exprs) < min_high_count:
                logger.info(
                    f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),"
                    f"将跳过高 count 优先选择,仅从全部表达中随机抽样"
                )
                high_count_valid = False
            else:
                high_count_valid = True

            # Too few expressions overall: bail out so the LLM is not fed a tiny pool.
            if len(all_style_exprs) < min_total_count:
                logger.info(
                    f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
                )
                return [], []

            # Pick high-count expressions first (when the quota is met).
            if high_count_valid:
                selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
            else:
                selected_high = []

            # Then draw a weighted random sample from the full pool.
            remaining_num = select_random_count
            selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))

            # Merge the two draws, de-duplicating by id.
            candidate_exprs = selected_high.copy()
            candidate_ids = {expr["id"] for expr in candidate_exprs}
            for expr in selected_random:
                if expr["id"] not in candidate_ids:
                    candidate_exprs.append(expr)
                    candidate_ids.add(expr["id"])

            # Shuffle so the high-count picks are not clustered at the front.
            import random

            random.shuffle(candidate_exprs)

            # 2. Build the index and situation list presented to the LLM.
            all_expressions: List[Dict[str, Any]] = []
            all_situations: List[str] = []

            # Copy each dict so later mutation cannot alias the candidate pool.
            for expr in candidate_exprs:
                expr = expr.copy()
                all_expressions.append(expr)
                all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")

            if not all_expressions:
                logger.warning("没有找到可用的表达方式")
                return [], []

            all_situations_str = "\n".join(all_situations)

            if target_message:
                target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
                target_message_extra_block = "4.考虑你要回复的目标消息"
            else:
                target_message_str = ""
                target_message_extra_block = ""

            chat_context = f"以下是正在进行的聊天内容:{chat_info}"

            # A reply reason replaces the chat context in the prompt.
            if reply_reason:
                reply_reason_block = f"你的回复理由是:{reply_reason}"
                chat_context = ""
            else:
                reply_reason_block = ""

            # 3. Build the prompt (situations only, not full expressions).
            prompt_template = prompt_manager.get_prompt("expression_select")
            prompt_template.add_context("bot_name", global_config.bot.nickname)
            prompt_template.add_context("chat_observe_info", chat_context)
            prompt_template.add_context("all_situations", all_situations_str)
            prompt_template.add_context("max_num", str(max_num))
            prompt_template.add_context("target_message", target_message_str)
            prompt_template.add_context("target_message_extra_block", target_message_extra_block)
            prompt_template.add_context("reply_reason_block", reply_reason_block)
            prompt = await prompt_manager.render_prompt(prompt_template)

            # 4. Call the LLM.
            content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)

            if not content:
                logger.warning("LLM返回空结果")
                return [], []

            # 5. Parse the result; repair_json tolerates malformed model output.
            result = repair_json(content)
            if isinstance(result, str):
                result = json.loads(result)

            if not isinstance(result, dict) or "selected_situations" not in result:
                logger.error("LLM返回格式错误")
                logger.info(f"LLM返回结果: \n{content}")
                return [], []

            selected_indices = result["selected_situations"]

            # Map 1-based indices back to the full expression dicts,
            # silently dropping anything out of range or non-integer.
            valid_expressions: List[Dict[str, Any]] = []
            selected_ids: List[int] = []
            for idx in selected_indices:
                if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
                    expression = all_expressions[idx - 1]  # indices start at 1
                    selected_ids.append(expression["id"])
                    valid_expressions.append(expression)

            # Mark every chosen expression as recently used.
            if valid_expressions:
                self.update_expressions_last_active_time(valid_expressions)

            logger.debug(f"从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
            return valid_expressions, selected_ids

        except Exception as e:
            logger.error(f"classic模式处理表达方式选择时出错: {e}")
            return [], []
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
# Module-level singleton.  Construction is guarded so an initialization
# failure is logged instead of aborting the import of this module; in that
# case `expression_selector` is left undefined.
try:
    expression_selector = ExpressionSelector()
except Exception as e:
    logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
212
src/learners/expression_utils.py
Normal file
212
src/learners/expression_utils.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from json_repair import repair_json
|
||||
from typing import Tuple, Optional, List
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.config.config import model_config
|
||||
from src.config.config import global_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("expression_utils")
|
||||
|
||||
# TODO: 重构完LLM相关内容后,替换成新的模型调用方式
|
||||
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
|
||||
|
||||
|
||||
async def check_expression_suitability(situation: str, style: str) -> Tuple[bool, str, Optional[str]]:
    """
    Run a single LLM evaluation of an expression.

    Args:
        situation: the situation/usage condition
        style: the expression/speech style

    Returns:
        (suitable, reason, error) tuple.  On any failure `suitable` is False
        and `error` carries the error message; this function never raises.
    """
    # Build the evaluation prompt.
    # Base evaluation criteria (kept in Chinese — they are fed to the LLM).
    base_criteria = [
        "表达方式或言语风格是否与使用条件或使用情景匹配",
        "允许部分语法错误或口头化或缺省出现",
        "表达方式不能太过特指,需要具有泛用性",
        "一般不涉及具体的人名或名称",
    ]

    # Append any user-configured extra criteria.
    if custom_criteria := global_config.expression.expression_auto_check_custom_criteria:
        base_criteria.extend(custom_criteria)

    # Render the criteria as a numbered list.
    criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(base_criteria)])

    prompt_template = prompt_manager.get_prompt("expression_evaluation")
    prompt_template.add_context("situation", situation)
    prompt_template.add_context("style", style)
    prompt_template.add_context("criteria_list", criteria_list)

    prompt = await prompt_manager.render_prompt(prompt_template)

    logger.info(f"正在评估表达方式: situation={situation}, style={style}")

    response, _ = await judge_llm.generate_response_async(prompt=prompt, temperature=0.6, max_tokens=1024)

    logger.debug(f"评估结果: {response}")

    try:
        evaluation = json.loads(response)
    except json.JSONDecodeError:
        try:
            response_repaired = repair_json(response)
            evaluation = json.loads(response_repaired)
        except Exception as e:
            # Return the error tuple instead of raising: a raise from this
            # handler would escape the sibling `except Exception` below
            # (sibling handlers never catch each other) and break the
            # documented "never raise, return an error tuple" contract.
            logger.error(f"无法解析LLM响应为JSON: {response}")
            return False, f"无法解析LLM响应为JSON: {e}", str(e)
    except Exception as e:
        return False, f"评估表达方式时发生错误: {e}", str(e)
    try:
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        return False, f"评估结果格式错误: {e}", str(e)
|
||||
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
    """Escape Chinese curly quotes that appear inside JSON string values.

    Walks the text with a small state machine that tracks whether the scan is
    inside a double-quoted JSON string.  Curly quotes (“ ”) found inside a
    string are replaced with an escaped ASCII quote (\\") so the result parses
    as JSON; curly quotes outside strings are left untouched.

    Args:
        text: raw (possibly broken) JSON text

    Returns:
        str: text with in-string curly quotes escaped
    """
    result = []
    in_string = False
    escape_next = False

    # Iterate characters directly; no manual index bookkeeping is needed.
    for char in text:
        if escape_next:
            # Character following a backslash: copy verbatim, keep string state.
            result.append(char)
            escape_next = False
            continue
        if char == "\\":
            result.append(char)
            escape_next = True
            continue
        if char == '"':
            # Unescaped ASCII quote toggles string state (escape_next is
            # guaranteed False here by the continue above, so no extra check).
            in_string = not in_string
            result.append(char)
            continue
        if in_string and char in ("“", "”"):
            result.append('\\"')
        else:
            result.append(char)

    return "".join(result)
|
||||
|
||||
|
||||
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """Parse LLM output into expression-style tuples and jargon tuples.

    Expected JSON shape::

        [
            {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},  # expression
            {"content": "entry", "source_id": "12"},                     # jargon
            ...
        ]

    Returns:
        Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
            expressions as (situation, style, source_id) and
            jargon entries as (content, source_id).
    """
    if not response:
        return [], []

    text = response.strip()

    # Prefer the contents of a ```json fenced block; otherwise strip any
    # generic ``` fences that may wrap the payload.
    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    if fenced:
        text = fenced[1].strip()
    else:
        text = re.sub(r"^```\s*", "", text, flags=re.MULTILINE)
        text = re.sub(r"```\s*$", "", text, flags=re.MULTILINE)
        text = text.strip()

    data = _try_parse(text)
    if data is None:
        # Second chance: escape Chinese curly quotes inside string values.
        data = _try_parse(fix_chinese_quotes_in_json(text))
    if data is None:
        logger.error(f"处理后的 JSON 字符串(前500字符):{text[:500]}")
        return [], []

    # Normalize to a list of dict items.
    if isinstance(data, dict):
        items = [data]
    elif isinstance(data, list):
        items = data
    else:
        logger.error(f"表达风格解析结果类型异常: {type(data)}, 内容: {data}")
        return [], []

    expressions: List[Tuple[str, str, str]] = []  # (situation, style, source_id)
    jargon_entries: List[Tuple[str, str]] = []  # (content, source_id)

    for item in items:
        if not isinstance(item, dict):
            continue

        situation = str(item.get("situation", "")).strip()
        style = str(item.get("style", "")).strip()
        source_id = str(item.get("source_id", "")).strip()

        # An item with situation + style + source_id is an expression entry.
        if situation and style and source_id:
            expressions.append((situation, style, source_id))
            continue

        # Otherwise a content + source_id pair is a jargon entry.
        content = str(item.get("content", "")).strip()
        if content and source_id:
            jargon_entries.append((content, source_id))

    return expressions, jargon_entries
|
||||
|
||||
|
||||
def is_single_char_jargon(content: str) -> bool:
    """Return True when *content* is a single-character jargon candidate.

    A single-character candidate is exactly one CJK ideograph, one ASCII
    letter, or one ASCII digit.

    Args:
        content: entry text

    Returns:
        bool: True for a single qualifying character, False otherwise
    """
    if not content or len(content) != 1:
        return False

    ch = content[0]
    # CJK unified ideographs.
    if "\u4e00" <= ch <= "\u9fff":
        return True
    # ASCII letters and digits.
    return ("a" <= ch <= "z") or ("A" <= ch <= "Z") or ("0" <= ch <= "9")
|
||||
|
||||
|
||||
def _try_parse(text):
|
||||
try:
|
||||
return json.loads(text)
|
||||
except Exception:
|
||||
try:
|
||||
repaired = repair_json(text)
|
||||
return json.loads(repaired)
|
||||
except Exception:
|
||||
return None
|
||||
86
src/learners/jargon_explainer.py
Normal file
86
src/learners/jargon_explainer.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from typing import Optional, Dict, List
|
||||
from sqlmodel import select, func as fn
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("jargon_explainer")
|
||||
|
||||
|
||||
def search_jargon(
    keyword: str,
    chat_id: Optional[str] = None,
    limit: int = 10,
    case_sensitive: bool = False,
    fuzzy: bool = True,
) -> List[Dict[str, str]]:
    """
    Search jargon entries, with optional case-insensitive and fuzzy matching.

    Args:
        keyword: search keyword
        chat_id: optional chat ID (session_id)
            - with all_global on: ignored; only is_global=True records are queried
            - with all_global off: when given, restricts results to this chat (or global) jargon
        limit: maximum number of results, default 10
        case_sensitive: whether matching is case-sensitive, default False
        fuzzy: whether to use fuzzy (LIKE) matching, default True

    Returns:
        List[Dict[str, str]]: dicts with "content" and "meaning" keys
    """
    if not keyword or not keyword.strip():
        return []

    keyword = keyword.strip()

    # Build the match condition.
    if case_sensitive:  # case-sensitive
        search_condition = Jargon.content.contains(keyword) if fuzzy else Jargon.content == keyword  # type: ignore
    else:
        keyword_lower = keyword.lower()
        search_condition = (
            fn.LOWER(Jargon.content).contains(keyword_lower) if fuzzy else fn.LOWER(Jargon.content) == keyword_lower
        )

    # Over-fetch (limit * 2) because the Python-side filtering below may drop rows.
    if global_config.expression.all_global_jargon:
        # all_global on: every record is global; query is_global=True regardless of chat_id.
        query = select(Jargon).where(search_condition, Jargon.is_global).order_by(Jargon.count.desc()).limit(limit * 2)  # type: ignore
    else:
        # all_global off: query everything; chat_id filtering happens in Python below.
        query = select(Jargon).where(search_condition).order_by(Jargon.count.desc()).limit(limit * 2)  # type: ignore

    results: List[Dict[str, str]] = []
    with get_db_session() as session:
        jargons = session.exec(query).all()

        for jargon in jargons:
            # With a chat_id and all_global off, non-global records must list the
            # target chat in their session_id_dict.
            if chat_id and not global_config.expression.all_global_jargon and not jargon.is_global:
                try:  # parse session_id_dict
                    session_id_dict = json.loads(jargon.session_id_dict) if jargon.session_id_dict else {}
                except (json.JSONDecodeError, TypeError):
                    session_id_dict = {}
                    logger.warning(
                        f"解析 session_id_dict 失败,jargon_id={jargon.id},原始数据:{jargon.session_id_dict}"
                    )

                if chat_id not in session_id_dict:
                    continue
            # Only return entries with a meaning.  Guard against meaning being
            # None — the append below already tolerates None via `or ""`, so a
            # bare .strip() here would raise AttributeError on such rows.
            if not (jargon.meaning or "").strip():
                continue

            results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
            # Stop once the requested number of results is reached.
            if len(results) >= limit:
                break

    return results
|
||||
344
src/learners/jargon_explainer_old.py
Normal file
344
src/learners/jargon_explainer_old.py
Normal file
@@ -0,0 +1,344 @@
|
||||
import re
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.learners.jargon_miner_old import search_jargon
|
||||
from src.learners.learner_utils_old import (
|
||||
is_bot_message,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
chat_id_list_contains,
|
||||
)
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
class JargonExplainer:
    """Jargon explainer: identifies and explains jargon found in the chat context before replying."""

    def __init__(self, chat_id: str) -> None:
        # Chat this explainer is scoped to.
        self.chat_id = chat_id
        # LLM used to summarize explanations (reuses the tool_use model set).
        self.llm = LLMRequest(
            model_set=model_config.model_task_config.tool_use,
            request_type="jargon.explain",
        )

    def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
        """
        Extract jargon by matching known jargon strings from the database
        directly against the message texts.

        Args:
            messages: message list

        Returns:
            List[Dict[str, str]]: matched jargon entries, each containing "content"
        """
        start_time = time.time()

        if not messages:
            return []

        # Collect the text of every message, skipping the bot's own messages.
        message_texts: List[str] = []
        for msg in messages:
            if is_bot_message(msg):
                continue

            msg_text = (
                getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
            ).strip()
            if msg_text:
                message_texts.append(msg_text)

        if not message_texts:
            return []

        # Merge all message texts into a single haystack.
        combined_text = " ".join(message_texts)

        # Query every jargon record that has a non-empty meaning.
        query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))

        # Scope the query according to the all_global setting.
        if global_config.expression.all_global_jargon:
            # all_global on: only globally-shared records.
            query = query.where(Jargon.is_global)
        else:
            # all_global off: fetch everything; per-chat filtering happens
            # in Python inside the loop below.
            pass

        # Match high-frequency jargon first.
        query = query.order_by(Jargon.count.desc())

        # Execute the query and run the matching.
        matched_jargon: Dict[str, Dict[str, str]] = {}
        query_time = time.time()

        for jargon in query:
            content = jargon.content or ""
            if not content or not content.strip():
                continue

            # Skip entries containing the bot's own nickname.
            if contains_bot_self_name(content):
                continue

            # Per-chat scoping when all_global is off.
            if not global_config.expression.all_global_jargon:
                if jargon.is_global:
                    # Global jargon is always eligible.
                    pass
                else:
                    # Otherwise the record's chat_id list must include this chat.
                    chat_id_list = parse_chat_id_list(jargon.chat_id)
                    if not chat_id_list_contains(chat_id_list, self.chat_id):
                        continue

            # Look for the jargon in the combined text (case-insensitive).
            pattern = re.escape(content)
            # Use word boundaries for pure ASCII terms to avoid partial
            # matches; CJK text has no word boundaries, so match raw.
            if re.search(r"[\u4e00-\u9fff]", content):
                # Contains CJK: looser matching.
                search_pattern = pattern
            else:
                # Pure ASCII letters/digits: word-boundary match.
                search_pattern = r"\b" + pattern + r"\b"

            if re.search(search_pattern, combined_text, re.IGNORECASE):
                # Record the hit (de-duplicated by content).
                if content not in matched_jargon:
                    matched_jargon[content] = {"content": content}

        match_time = time.time()
        total_time = match_time - start_time
        query_duration = query_time - start_time
        match_duration = match_time - query_time

        logger.debug(
            f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
            f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
        )

        return list(matched_jargon.values())

    async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
        """
        Explain the jargon found in the context.

        Args:
            messages: message list
            chat_context: textual representation of the chat context

        Returns:
            Optional[str]: summarized explanation text, or None when no jargon was found
        """
        if not messages:
            return None

        # Direct matching: look known jargon up in the DB and match it in the messages.
        jargon_entries = self.match_jargon_from_messages(messages)

        if not jargon_entries:
            return None

        # De-duplicate by content.
        unique_jargon: Dict[str, Dict[str, str]] = {}
        for entry in jargon_entries:
            content = entry["content"]
            if content not in unique_jargon:
                unique_jargon[content] = entry

        jargon_list = list(unique_jargon.values())
        logger.info(f"从上下文中提取到 {len(jargon_list)} 个黑话: {[j['content'] for j in jargon_list]}")

        # Look up a meaning for each matched jargon term.
        jargon_explanations: List[str] = []
        for entry in jargon_list:
            content = entry["content"]

            # The lookup scope depends on the all_global setting.
            if global_config.expression.all_global_jargon:
                # Global mode: query all is_global=True records.
                results = search_jargon(
                    keyword=content,
                    chat_id=None,  # no chat filter: look up global jargon
                    limit=1,
                    case_sensitive=False,
                    fuzzy=False,  # exact match
                )
            else:
                # Per-chat mode: prefer this chat's (or global) jargon.
                results = search_jargon(
                    keyword=content,
                    chat_id=self.chat_id,
                    limit=1,
                    case_sensitive=False,
                    fuzzy=False,  # exact match
                )

            if results and len(results) > 0:
                meaning = results[0].get("meaning", "").strip()
                if meaning:
                    jargon_explanations.append(f"- {content}: {meaning}")
                else:
                    logger.info(f"黑话 {content} 没有找到含义")
            else:
                logger.info(f"黑话 {content} 未在数据库中找到")

        if not jargon_explanations:
            logger.info("没有找到任何黑话的含义,跳过解释")
            return None

        # Join all explanations into one text block.
        explanations_text = "\n".join(jargon_explanations)

        # Ask the LLM to summarize the explanations for the current context.
        prompt_of_summarize = prompt_manager.get_prompt("jargon_explainer_summarize")
        prompt_of_summarize.add_context("chat_context", lambda _: chat_context)
        prompt_of_summarize.add_context("jargon_explanations", lambda _: explanations_text)
        summarize_prompt = await prompt_manager.render_prompt(prompt_of_summarize)

        summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3)
        if not summary:
            # Fall back to the raw explanations when summarization fails.
            return f"上下文中的黑话解释:\n{explanations_text}"

        summary = summary.strip()
        if not summary:
            return f"上下文中的黑话解释:\n{explanations_text}"

        return summary
|
||||
|
||||
|
||||
async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
    """Explain jargon appearing in the given context (convenience wrapper).

    Args:
        chat_id: chat ID
        messages: message list
        chat_context: textual representation of the chat context

    Returns:
        Optional[str]: summarized jargon explanation, or None when none was found
    """
    return await JargonExplainer(chat_id).explain_jargon(messages, chat_context)
|
||||
|
||||
|
||||
def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
    """Scan chat text for known jargon and return the terms that occur.

    Args:
        chat_text: chat text to scan
        chat_id: chat ID used for per-chat scoping

    Returns:
        List[str]: matched jargon terms
    """
    if not chat_text or not chat_text.strip():
        return []

    # Only consider jargon records that carry a meaning.
    query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
    if global_config.expression.all_global_jargon:
        query = query.where(Jargon.is_global)

    # Higher-frequency jargon is matched first.
    query = query.order_by(Jargon.count.desc())

    hits: Dict[str, None] = {}

    for record in query:
        term = (record.content or "").strip()
        if not term:
            continue

        # With all_global off, non-global records must list this chat.
        if not global_config.expression.all_global_jargon and not record.is_global:
            if not chat_id_list_contains(parse_chat_id_list(record.chat_id), chat_id):
                continue

        escaped = re.escape(term)
        # CJK terms match raw; pure ASCII terms use word boundaries.
        needle = escaped if re.search(r"[\u4e00-\u9fff]", term) else r"\b" + escaped + r"\b"

        if re.search(needle, chat_text, re.IGNORECASE):
            hits[term] = None

    logger.info(f"匹配到 {len(hits)} 个黑话")

    return list(hits.keys())
|
||||
|
||||
|
||||
async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
    """Look up each concept in the jargon store.

    Args:
        concepts: concept list
        chat_id: chat ID

    Returns:
        str: formatted retrieval result, or "" when nothing matched
    """
    if not concepts:
        return ""

    results = []
    exact_matches = []  # collect exactly-matched concepts for one combined log line
    for concept in concepts:
        concept = concept.strip()
        if not concept:
            continue

        # Try an exact match first.
        jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)

        is_fuzzy_match = False

        # Fall back to fuzzy search when the exact match found nothing.
        if not jargon_results:
            jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
            is_fuzzy_match = True

        if jargon_results:
            # Something was found.
            if is_fuzzy_match:
                # Fuzzy hit: report what was actually matched.
                output_parts = [f"未精确匹配到'{concept}'"]
                for result in jargon_results:
                    found_content = result.get("content", "").strip()
                    meaning = result.get("meaning", "").strip()
                    if found_content and meaning:
                        output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
                results.append("\n".join(output_parts))  # newline-separate each jargon explanation
                logger.info(f"在jargon库中找到匹配(模糊搜索): {concept},找到{len(jargon_results)}条结果")
            else:
                # Exact hit.
                # NOTE(review): assumes search_jargon only returns entries with a
                # non-empty meaning; otherwise output_parts[0] below would raise
                # IndexError — confirm against search_jargon's filtering.
                output_parts = []
                for result in jargon_results:
                    meaning = result.get("meaning", "").strip()
                    if meaning:
                        output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
                # newline-separate each jargon explanation
                results.append("\n".join(output_parts) if len(output_parts) > 1 else output_parts[0])
                exact_matches.append(concept)  # logged in one batch below
        else:
            # No hit: log only, add no placeholder text to the output.
            logger.info(f"在jargon库中未找到匹配: {concept}")

    # One combined log line for all exact matches.
    if exact_matches:
        logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")

    if results:
        return "你了解以下词语可能的含义:\n" + "\n".join(results) + "\n"
    return ""
|
||||
419
src/learners/jargon_miner.py
Normal file
419
src/learners/jargon_miner.py
Normal file
@@ -0,0 +1,419 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Set, TypedDict
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
|
||||
from json_repair import repair_json
|
||||
from sqlmodel import select
|
||||
|
||||
from src.common.data_models.jargon_data_model import MaiJargon
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
|
||||
from .expression_utils import is_single_char_jargon
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
# TODO: 重构完LLM相关内容后,替换成新的模型调用方式
|
||||
llm_extract = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.extract")
|
||||
llm_inference = LLMRequest(model_set=model_config.model_task_config.utils, request_type="jargon.inference")
|
||||
|
||||
|
||||
class JargonEntry(TypedDict):
    """A candidate jargon term together with the raw chat snippets it was seen in."""

    # The jargon term itself (stripped text).
    content: str
    # Set of raw context lines/messages in which the term appeared.
    raw_content: Set[str]
|
||||
|
||||
|
||||
class JargonMeaningEntry(TypedDict):
    """A jargon term paired with its inferred meaning."""

    # The jargon term.
    content: str
    # Human-readable meaning inferred for the term.
    meaning: str
|
||||
|
||||
|
||||
class JargonMiner:
    """Per-session miner that stores candidate jargon terms and asynchronously
    infers their meanings via LLM calls.

    Lifecycle: extracted entries arrive through ``process_extracted_entries``,
    are merged/deduplicated, persisted to the ``Jargon`` table, and — once a
    term's occurrence count crosses one of the inference thresholds — a
    background task runs the three-step LLM inference in ``infer_meaning``.
    """

    def __init__(self, session_id: str, session_name: str) -> None:
        # Identifier of the chat session this miner serves.
        self.session_id = session_id
        # Display name of the session, used in log output.
        self.session_name = session_name

        # Recently-seen jargon cache (LRU via OrderedDict; values unused).
        self.cache_limit = 50
        self.cache: OrderedDict[str, None] = OrderedDict()
        # Lock intended to prevent concurrent extraction runs.
        # NOTE(review): the lock is created here but never acquired in this class — confirm callers use it.
        self._extraction_lock = asyncio.Lock()

    def get_cached_jargons(self) -> List[str]:
        """Return all jargon terms currently held in the LRU cache."""
        return list(self.cache.keys())

    async def infer_meaning(self, jargon_obj: MaiJargon) -> None:
        """Run the three-step meaning inference for one jargon record.

        Step 1 infers a meaning from the term plus its raw chat contexts;
        step 2 infers a meaning from the term alone; step 3 asks the LLM to
        compare both. If the two inferences differ, the term is considered
        real jargon and step 1's meaning is kept. Results (is_jargon, meaning,
        last_inference_count, is_complete) are written back to the DB.
        """
        content = jargon_obj.content
        # Parse the stored raw_content JSON into a list, tolerating scalars and bad JSON.
        raw_content_list = []
        if raw_content_str := jargon_obj.raw_content:
            try:
                raw_content_list = json.loads(raw_content_str)
                if not isinstance(raw_content_list, list):
                    raw_content_list = [raw_content_list] if raw_content_list else []
            except (json.JSONDecodeError, TypeError):
                raw_content_list = [raw_content_str] if raw_content_str else []

        if not raw_content_list:
            logger.warning(f"jargon {content} 没有raw_content,跳过推断")
            return

        # Current occurrence count and the meaning from the previous inference round.
        current_count = jargon_obj.count
        previous_meaning = jargon_obj.meaning

        # Step 1 input: join all raw contexts.
        # NOTE(review): this is joined BEFORE the trimming below, so the
        # random.sample trimming never affects the prompt text — confirm intent.
        raw_content_text = "\n".join(raw_content_list)

        # At counts 24 and 60, randomly drop half of the raw_content items (keep >= 1).
        if current_count in [24, 60] and len(raw_content_list) > 1:
            keep_count = max(1, len(raw_content_list) // 2)
            raw_content_list = random.sample(raw_content_list, keep_count)
            logger.info(
                f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
            )

        # At counts 24, 60 and 100, feed the previous meaning back into the prompt as a hint.
        previous_meaning_section = ""
        previous_meaning_instruction = ""
        if current_count in [24, 60, 100] and previous_meaning:
            previous_meaning_section = f"\n**上一次推断的含义(仅供参考)**\n{previous_meaning}"
            previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"

        # Step 1: infer meaning from term + contexts.
        prompt1_template = prompt_manager.get_prompt("jargon_inference_with_context")
        prompt1_template.add_context("bot_name", global_config.bot.nickname)
        prompt1_template.add_context("content", str(content))
        prompt1_template.add_context("raw_content_list", raw_content_text)
        prompt1_template.add_context("previous_meaning_section", previous_meaning_section)
        prompt1_template.add_context("previous_meaning_instruction", previous_meaning_instruction)
        prompt1 = await prompt_manager.render_prompt(prompt1_template)

        llm_response_1, _ = await llm_inference.generate_response_async(prompt1, temperature=0.3)
        if not llm_response_1:
            logger.warning(f"jargon {content} 推断1失败:无响应")
            return

        # Parse step-1 result.
        inference1 = self._parse_result(llm_response_1)
        if not inference1:
            logger.warning(f"jargon {content} 推断1解析失败")
            return

        no_info = inference1.get("no_info", False)
        meaning1: str = inference1.get("meaning", "").strip()
        if no_info or not meaning1:
            # Not enough information: abort this round but record the count so
            # the same threshold is not retried repeatedly.
            logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新")
            jargon_obj.last_inference_count = jargon_obj.count or 0

            try:
                self._modify_jargon_entry(jargon_obj)
            except Exception as e:
                logger.error(f"jargon {content} 推断1更新last_inference_count失败: {e}")
            return

        # Step 2: infer meaning from the term alone (no context).
        prompt2_template = prompt_manager.get_prompt("jargon_inference_content_only")
        prompt2_template.add_context("content", content)
        prompt2 = await prompt_manager.render_prompt(prompt2_template)

        llm_response_2, _ = await llm_inference.generate_response_async(prompt2, temperature=0.3)
        if not llm_response_2:
            logger.warning(f"jargon {content} 推断2失败:无响应")
            return

        # Parse step-2 result.
        inference2 = self._parse_result(llm_response_2)
        if not inference2:
            logger.warning(f"jargon {content} 推断2解析失败")
            return

        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 推断1提示词: {prompt1}")
            logger.info(f"jargon {content} 推断2提示词: {prompt2}")

        # Step 3: ask the LLM whether the two inferences agree.
        prompt3_template = prompt_manager.get_prompt("jargon_compare_inference")
        prompt3_template.add_context("inference1", json.dumps(inference1, ensure_ascii=False))
        prompt3_template.add_context("inference2", json.dumps(inference2, ensure_ascii=False))
        prompt3 = await prompt_manager.render_prompt(prompt3_template)

        if global_config.debug.show_jargon_prompt:
            logger.info(f"jargon {content} 比较提示词: {prompt3}")

        llm_response_3, _ = await llm_inference.generate_response_async(prompt3, temperature=0.3)
        if not llm_response_3:
            logger.warning(f"jargon {content} 比较失败:无响应")
            return

        comparison_result = self._parse_result(llm_response_3)
        if not comparison_result:
            logger.warning(f"jargon {content} 比较解析失败")
            return

        is_similar = comparison_result.get("is_similar", False)
        # If both inferences agree, the term is plain language; a difference means it is jargon.
        is_jargon = not is_similar

        # Update the record: keep step-1's meaning only for real jargon.
        jargon_obj.is_jargon = is_jargon
        jargon_obj.meaning = inference1.get("meaning", "") if is_jargon else ""
        # Record the count at which this inference ran, so restarts do not repeat it.
        jargon_obj.last_inference_count = jargon_obj.count or 0

        # At count >= 100 mark the record complete; no further inference rounds.
        if (jargon_obj.count or 0) >= 100:
            jargon_obj.is_complete = True

        try:
            self._modify_jargon_entry(jargon_obj)
        except Exception as e:
            logger.error(f"jargon {content} 推断结果更新失败: {e}")
        logger.debug(
            f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
        )

        # Always emit a human-readable summary of the outcome.
        if is_jargon:
            meaning = jargon_obj.meaning or "无详细说明"
            is_global = jargon_obj.is_global  # whether the term is global (not session-scoped)
            if is_global:
                logger.info(f"[黑话]{content}的含义是 {meaning}")
            else:
                logger.info(f"[{self.session_name}]{content}的含义是 {meaning}")
        else:
            logger.info(f"[{self.session_name}]{content} 不是黑话")

    async def process_extracted_entries(
        self, entries: List[JargonEntry], person_name_filter: Optional[Callable[[str], bool]] = None
    ) -> None:
        """Persist jargon entries routed here from the expression learner.

        Merges duplicate contents, filters entries matching person names,
        updates existing DB rows (count / raw contexts / session map) or
        inserts new ones, and schedules meaning inference when a row crosses
        an inference threshold.

        Args:
            entries: Extracted jargon entries.
            person_name_filter: Optional predicate; entries whose content it
                matches are dropped (to avoid treating people's names as jargon).
        """
        if not entries:
            return
        # Merge entries with identical content, unioning their raw contexts.
        merged_entries: Dict[str, JargonEntry] = {}
        for entry in entries:
            content = entry["content"].strip()

            if person_name_filter and person_name_filter(content):
                logger.info(f"条目 '{content}' 包含人物名称,已过滤")
                continue
            raw_list = entry["raw_content"] or set()
            if content in merged_entries:
                merged_entries[content]["raw_content"].update(raw_list)
            else:
                merged_entries[content] = {"content": content, "raw_content": set(raw_list)}

        uniq_entries: List[JargonEntry] = list(merged_entries.values())

        saved = 0
        updated = 0
        for entry in uniq_entries:
            content = entry["content"]
            raw_content_set = entry["raw_content"]
            try:
                with get_db_session(auto_commit=False) as session:
                    jargon_items = session.exec(select(Jargon).filter_by(content=content)).all()
            except Exception as e:
                logger.error(f"查询黑话 '{content}' 失败: {e}")
                continue
            # Find a matching row: with all_global_jargon any content match counts;
            # otherwise the row's session map must contain this session's id.
            matched_jargon: Optional[Jargon] = None
            for item in jargon_items:
                if global_config.expression.all_global_jargon:
                    matched_jargon = item
                    break
                else:
                    if item.session_id_dict:
                        try:
                            session_id_dict = json.loads(item.session_id_dict)
                            if self.session_id in session_id_dict:
                                matched_jargon = item
                                break
                        except Exception as e:
                            logger.error(f"解析Jargon id={item.id} session_id_list失败: {e}")
                            continue
            if matched_jargon:
                # Existing row: bump count and merge raw contexts; maybe trigger inference.
                self._update_jargon(matched_jargon, raw_content_set)
                if self._should_infer_meaning(matched_jargon):
                    asyncio.create_task(self._infer_meaning_by_id(matched_jargon.id))  # type: ignore
                updated += 1
            else:
                # No matching row: create a fresh one with count=1.
                is_global_new = global_config.expression.all_global_jargon
                session_dict_str = json.dumps({self.session_id: 1})
                new_jargon = Jargon(
                    content=content,
                    raw_content=json.dumps(list(raw_content_set), ensure_ascii=False),
                    session_id_dict=session_dict_str,
                    is_global=is_global_new,
                    count=1,
                    meaning="",
                )
                try:
                    with get_db_session() as session:
                        session.add(new_jargon)
                        session.flush()
                    saved += 1
                    self._add_to_cache(content)
                except Exception as e:
                    logger.error(f"保存新黑话 '{content}' 失败: {e}")
                    continue
        # Emit a readable summary whenever anything was extracted.
        if uniq_entries:
            jargon_list = [entry["content"] for entry in uniq_entries]
            jargon_str = ",".join(jargon_list)
            logger.info(f"[{self.session_name}]疑似黑话: {jargon_str}")

        if saved or updated:
            logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,session_id={self.session_id}")

    def _add_to_cache(self, content: str):
        """Add a jargon term to the LRU cache, evicting the oldest when full."""
        content = content.strip()
        # Single-character terms are too noisy to cache.
        if is_single_char_jargon(content):
            return
        if content in self.cache:
            # Already cached: move to the end to mark as most recently used.
            self.cache.move_to_end(content)
        else:
            # New term: insert, then evict the least recently used if over the limit.
            self.cache[content] = None
            if len(self.cache) > self.cache_limit:
                removed_content, _ = self.cache.popitem(last=False)
                logger.debug(f"缓存已满,移除最旧的黑话: {removed_content}")

    def _update_jargon(self, db_jargon: Jargon, raw_content_set: Set[str]) -> None:
        """Update an existing jargon row and write it back to the database.

        Args:
            db_jargon: The matched jargon ORM object.
            raw_content_set: New raw contexts collected in this round.
        """
        db_jargon.count += 1
        existing_raw_content: List[str] = []
        if db_jargon.raw_content:
            try:
                existing_raw_content = json.loads(db_jargon.raw_content)
            except Exception:
                existing_raw_content = []

        # Merge and deduplicate raw contexts.
        merged_list = list(set(existing_raw_content).union(raw_content_set))
        db_jargon.raw_content = json.dumps(merged_list, ensure_ascii=False)
        # Bump this session's hit count in the per-session map.
        session_id_dict: Dict[str, int] = json.loads(db_jargon.session_id_dict)
        session_id_dict[self.session_id] = session_id_dict.get(self.session_id, 0) + 1
        db_jargon.session_id_dict = json.dumps(session_id_dict)

        # With all_global_jargon enabled, force the record to be global.
        if global_config.expression.all_global_jargon:
            db_jargon.is_global = True

        try:
            with get_db_session() as session:
                if db_jargon.id is None:
                    raise ValueError("黑话记录缺少 id,无法更新数据库")
                # Re-fetch the persisted row and copy the updated fields onto it.
                statement = select(Jargon).filter_by(id=db_jargon.id).limit(1)
                if persisted_jargon := session.exec(statement).first():
                    persisted_jargon.count = db_jargon.count
                    persisted_jargon.raw_content = db_jargon.raw_content
                    persisted_jargon.session_id_dict = db_jargon.session_id_dict
                    persisted_jargon.is_global = db_jargon.is_global
                    session.add(persisted_jargon)
                else:
                    logger.warning(f"黑话 ID {db_jargon.id} 在数据库中未找到,无法更新")
        except Exception as e:
            logger.error(f"更新黑话 '{db_jargon.content}' 失败: {e}")

    def _parse_result(self, response: str) -> Optional[Dict[str, str]]:
        """Parse an LLM JSON response into a dict, repairing malformed JSON if needed.

        Returns None when the response cannot be parsed or is not a dict.
        """
        try:
            result = json.loads(response.strip())
        except Exception:
            # Strict parse failed: try json_repair before giving up.
            try:
                repaired = repair_json(response.strip())
                result = json.loads(repaired)
            except Exception as e2:
                logger.error(f"推断结果解析失败: {e2}")
                return None
        if not isinstance(result, dict):
            logger.warning("推断结果格式错误")
            return None
        return result

    def _modify_jargon_entry(self, jargon_obj: MaiJargon) -> None:
        """Persist the inference fields of *jargon_obj* to its Jargon DB row.

        Raises:
            ValueError: If the object carries no item_id to locate the row.
        """
        with get_db_session() as session:
            if not jargon_obj.item_id:
                raise ValueError("jargon_obj must have item_id to update")
            statement = select(Jargon).filter_by(id=jargon_obj.item_id).limit(1)
            if db_record := session.exec(statement).first():
                db_record.is_jargon = jargon_obj.is_jargon
                db_record.meaning = jargon_obj.meaning
                db_record.last_inference_count = jargon_obj.last_inference_count
                db_record.is_complete = jargon_obj.is_complete
                session.add(db_record)

    def _should_infer_meaning(self, jargon_obj: Jargon) -> bool:
        """Decide whether a meaning inference round should run for this record.

        Inference runs when count reaches the next threshold beyond the last
        inference (and count must exceed last_inference_count so restarts do
        not repeat a round). Records flagged is_complete are never re-inferred.

        NOTE(review): the original docstring listed thresholds 3, 6, 10, 20,
        40, 60, 100, which does not match the list below — confirm intent.
        """
        # Fully processed records never infer again.
        if jargon_obj.is_complete:
            return False

        count = jargon_obj.count or 0
        last_inference = jargon_obj.last_inference_count or 0

        # Occurrence-count thresholds at which inference is (re)attempted.
        thresholds = [2, 4, 8, 12, 24, 60, 100]

        if count < thresholds[0]:
            return False
        # No new occurrences since the last inference: nothing to do.
        if count <= last_inference:
            return False

        next_threshold = next(
            (threshold for threshold in thresholds if threshold > last_inference),
            None,
        )
        # No next threshold means we are already past 100: never infer again.
        return False if next_threshold is None else count >= next_threshold

    async def _infer_meaning_by_id(self, jargon_id: int):
        """Load a Jargon row by id and run meaning inference on it."""
        jargon_obj: Optional[MaiJargon] = None
        try:
            with get_db_session() as session:
                statement = select(Jargon).filter_by(id=jargon_id).limit(1)
                if db_record := session.exec(statement).first():
                    jargon_obj = MaiJargon.from_db_instance(db_record)
        except Exception as e:
            logger.error(f"查询Jargon id={jargon_id}失败: {e}")
            return
        if jargon_obj:
            await self.infer_meaning(jargon_obj)
|
||||
134
src/learners/learner_utils.py
Normal file
134
src/learners/learner_utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from json_repair import repair_json
|
||||
from typing import List, Tuple
|
||||
|
||||
import re
|
||||
import json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("learner_utils")
|
||||
|
||||
|
||||
def fix_chinese_quotes_in_json(text):
    """Escape curly (Chinese) double quotes that occur inside JSON string values.

    A small state machine walks the text once: structural double quotes toggle
    the in-string state, backslash escapes pass the next character through
    untouched, and any curly quote seen while inside a string is rewritten as
    an escaped ASCII quote so the result parses as valid JSON.
    """
    out = []
    inside_string = False
    pending_escape = False

    for ch in text:
        if pending_escape:
            # Character following a backslash: copy verbatim, no state change.
            out.append(ch)
            pending_escape = False
        elif ch == "\\":
            # Backslash starts an escape sequence.
            out.append(ch)
            pending_escape = True
        elif ch == '"':
            # Unescaped ASCII quote toggles whether we are inside a string value.
            inside_string = not inside_string
            out.append(ch)
        elif inside_string and ch in ("“", "”"):
            # Curly quote inside a string: replace with an escaped ASCII quote.
            out.append('\\"')
        else:
            out.append(ch)

    return "".join(out)
|
||||
|
||||
|
||||
def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
    """Parse the LLM's expression-style / jargon summary JSON into two lists.

    Expected JSON shape:
    [
        {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},  // expression
        {"content": "词条", "source_id": "12"},                       // jargon
        ...
    ]

    Returns:
        Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
            First list holds expressions as (situation, style, source_id);
            second list holds jargon entries as (content, source_id).
            Both are empty when the response is blank or unparseable.
    """
    if not response:
        return [], []

    raw = response.strip()

    # Prefer the payload of a ```json fenced block; otherwise strip generic fences.
    fence = re.search(r"```json\s*(.*?)\s*```", raw, re.DOTALL)
    if fence:
        raw = fence[1].strip()
    else:
        raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
        raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
        raw = raw.strip()

    def _decode(text):
        # Direct parse for well-formed arrays; json_repair otherwise.
        if text.startswith("[") and text.endswith("]"):
            return json.loads(text)
        repaired = repair_json(text)
        return json.loads(repaired) if isinstance(repaired, str) else repaired

    try:
        parsed = _decode(raw)
    except Exception as parse_error:
        # First attempt failed: escape Chinese quotes inside string values, then retry.
        try:
            parsed = _decode(fix_chinese_quotes_in_json(raw))
        except Exception as fix_error:
            logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
            logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
            logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
            logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
            return [], []

    # Normalize to a list of dicts.
    if isinstance(parsed, dict):
        entries = [parsed]
    elif isinstance(parsed, list):
        entries = parsed
    else:
        logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
        return [], []

    expressions: List[Tuple[str, str, str]] = []  # (situation, style, source_id)
    jargon_entries: List[Tuple[str, str]] = []  # (content, source_id)

    for item in entries:
        if not isinstance(item, dict):
            continue

        situation = str(item.get("situation", "")).strip()
        style = str(item.get("style", "")).strip()
        source_id = str(item.get("source_id", "")).strip()

        if situation and style and source_id:
            # Expression entry: has both situation and style.
            expressions.append((situation, style, source_id))
        elif item.get("content"):
            # Jargon entry: identified by its content field.
            content = str(item.get("content", "")).strip()
            source_id = str(item.get("source_id", "")).strip()
            if content and source_id:
                jargon_entries.append((content, source_id))

    return expressions, jargon_entries
|
||||
445
src/learners/learner_utils_old.py
Normal file
445
src/learners/learner_utils_old.py
Normal file
@@ -0,0 +1,445 @@
|
||||
import random
|
||||
import json
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
logger = get_logger("learner_utils")
|
||||
|
||||
|
||||
def _compute_weights(population: List[Dict]) -> List[float]:
|
||||
"""
|
||||
根据表达的count计算权重,范围限定在1~5之间。
|
||||
count越高,权重越高,但最多为基础权重的5倍。
|
||||
"""
|
||||
if not population:
|
||||
return []
|
||||
|
||||
counts = []
|
||||
for item in population:
|
||||
count = item.get("count", 1)
|
||||
try:
|
||||
count_value = float(count)
|
||||
except (TypeError, ValueError):
|
||||
count_value = 1.0
|
||||
counts.append(max(count_value, 0.0))
|
||||
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
|
||||
if max_count == min_count:
|
||||
weights = [1.0 for _ in counts]
|
||||
else:
|
||||
weights = []
|
||||
for count_value in counts:
|
||||
# 线性映射到[1,5]区间
|
||||
normalized = (count_value - min_count) / (max_count - min_count)
|
||||
weights.append(1.0 + normalized * 4.0) # 1~5
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """Draw up to *k* items from *population* without replacement, weighted by count.

    Args:
        population: Candidate items (dicts, typically carrying a ``count`` key).
        k: Number of items to draw.

    Returns:
        List[Dict]: The drawn items. The whole population (copied) is returned
        when it has at most *k* items; an empty list for empty input or k <= 0.
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        # Nothing to sample — everything is selected.
        return population.copy()

    picked: List[Dict] = []
    remaining = population.copy()

    for _ in range(min(k, len(remaining))):
        weights = _compute_weights(remaining)
        total_weight = sum(weights)
        if total_weight <= 0:
            # Degenerate weights: fall back to a uniform random pick.
            choice = random.randint(0, len(remaining) - 1)
            picked.append(remaining.pop(choice))
            continue

        # Roulette-wheel selection over the cumulative weights.
        cutoff = random.uniform(0, total_weight)
        running = 0.0
        for position, weight in enumerate(weights):
            running += weight
            if cutoff <= running:
                picked.append(remaining.pop(position))
                break

    return picked
|
||||
|
||||
|
||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
    """Normalize a chat_id field into ``[[chat_id, count], ...]`` form.

    Accepts the legacy format (a bare chat-id string) as well as the new
    format (a JSON-encoded list of [chat_id, count] pairs, or an actual list).

    Args:
        chat_id_value: A plain string (old format), a JSON string, or a list.

    Returns:
        List[List[Any]]: Entries of the form [chat_id, count]; empty for
        falsy input.
    """
    if not chat_id_value:
        return []

    if isinstance(chat_id_value, list):
        # Already in the new list format.
        return chat_id_value

    if isinstance(chat_id_value, str):
        try:
            decoded = json.loads(chat_id_value)
        except (json.JSONDecodeError, TypeError):
            # Not JSON at all: treat as a legacy bare chat-id string.
            return [[str(chat_id_value), 1]]
        if isinstance(decoded, list):
            # New format: JSON-encoded list.
            return decoded
        if isinstance(decoded, str):
            # JSON string literal: still the legacy single-id format.
            return [[decoded, 1]]
        # Any other JSON value: fall back to the original string as an id.
        return [[str(chat_id_value), 1]]

    # Unknown type: coerce to string and wrap in the legacy shape.
    return [[str(chat_id_value), 1]]
|
||||
|
||||
|
||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
    """Bump the count for *target_chat_id* in the list, adding it if absent.

    Args:
        chat_id_list: Current entries in ``[[chat_id, count], ...]`` form; mutated in place.
        target_chat_id: The chat id to update or insert.
        increment: Amount added to the count (default 1).

    Returns:
        List[List[Any]]: The same (mutated) list, for convenience.
    """
    entry = _find_chat_id_item(chat_id_list, target_chat_id)
    if entry is None:
        # Unknown chat id: append a fresh [id, count] pair.
        chat_id_list.append([target_chat_id, increment])
        return chat_id_list

    if len(entry) >= 2:
        # Existing count; non-numeric counts are reset to 0 before adding.
        base = entry[1] if isinstance(entry[1], (int, float)) else 0
        entry[1] = base + increment
    else:
        # Malformed single-element entry: attach the count.
        entry.append(increment)
    return chat_id_list
|
||||
|
||||
|
||||
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
|
||||
"""
|
||||
在chat_id列表中查找匹配的项(辅助函数)
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
如果找到则返回匹配的项,否则返回None
|
||||
"""
|
||||
for item in chat_id_list:
|
||||
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
    """Report whether *target_chat_id* appears in a ``[[chat_id, count], ...]`` list.

    Args:
        chat_id_list: Entries in [[chat_id, count], ...] form.
        target_chat_id: The chat id to look for.

    Returns:
        bool: True when a matching entry exists.
    """
    found = _find_chat_id_item(chat_id_list, target_chat_id)
    return found is not None
|
||||
|
||||
|
||||
def contains_bot_self_name(content: str) -> bool:
    """Check whether *content* mentions the bot's nickname or any alias.

    Matching is a case-insensitive substring test against the stripped,
    lower-cased content. Returns False for empty content or when no bot
    configuration is available.
    """
    if not content:
        return False

    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    haystack = content.strip().lower()

    # Collect nickname plus all aliases, normalized the same way as the content.
    raw_names = [str(getattr(bot_config, "nickname", "") or "")]
    raw_names.extend(str(alias or "") for alias in getattr(bot_config, "alias_names", []) or [])

    for raw in raw_names:
        needle = raw.strip().lower()
        if needle and needle in haystack:
            return True
    return False
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
    """Return True when the message was sent by one of the bot's own accounts."""
    if msg is None:
        return False

    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    # The sender id may live directly on the message or on a nested user_info object.
    sender = getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or ""
    sender = str(sender).strip()
    if not sender:
        return False

    # Gather every account id the bot is known to use across platforms.
    accounts = set()
    for attr in ("qq_account", "telegram_account"):
        value = str(getattr(bot_config, attr, "") or "").strip()
        if value:
            accounts.add(value)
    for platform in getattr(bot_config, "platforms", []) or []:
        value = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
        if value:
            accounts.add(value)

    return sender in accounts
|
||||
|
||||
|
||||
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
# """
|
||||
# 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
# """
|
||||
# if not messages or center_index < 0 or center_index >= len(messages):
|
||||
# return None
|
||||
|
||||
# context_start = max(0, center_index - 3)
|
||||
# context_end = min(len(messages), center_index + 1 + 3)
|
||||
# context_messages = messages[context_start:context_end]
|
||||
|
||||
# if not context_messages:
|
||||
# return None
|
||||
|
||||
# try:
|
||||
# paragraph = build_readable_messages(
|
||||
# messages=context_messages,
|
||||
# replace_bot_name=True,
|
||||
# timestamp_mode="relative",
|
||||
# read_mark=0.0,
|
||||
# truncate=False,
|
||||
# show_actions=False,
|
||||
# show_pic=True,
|
||||
# message_id_list=None,
|
||||
# remove_emoji_stickers=False,
|
||||
# pic_single=True,
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"构建上下文段落失败: {e}")
|
||||
# return None
|
||||
|
||||
# paragraph = paragraph.strip()
|
||||
# return paragraph or None
|
||||
|
||||
|
||||
# def is_bot_message(msg: Any) -> bool:
|
||||
# """判断消息是否来自机器人自身"""
|
||||
# if msg is None:
|
||||
# return False
|
||||
|
||||
# bot_config = getattr(global_config, "bot", None)
|
||||
# if not bot_config:
|
||||
# return False
|
||||
|
||||
# platform = (
|
||||
# str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
|
||||
# .strip()
|
||||
# .lower()
|
||||
# )
|
||||
# user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
|
||||
# if not platform or not user_id:
|
||||
# return False
|
||||
|
||||
# platform_accounts = {}
|
||||
# try:
|
||||
# platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
|
||||
# except Exception:
|
||||
# platform_accounts = {}
|
||||
|
||||
# bot_accounts: Dict[str, str] = {}
|
||||
# qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
|
||||
# if qq_account:
|
||||
# bot_accounts["qq"] = qq_account
|
||||
|
||||
# telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
|
||||
# if telegram_account:
|
||||
# bot_accounts["telegram"] = telegram_account
|
||||
|
||||
# for plat, account in platform_accounts.items():
|
||||
# if account and plat not in bot_accounts:
|
||||
# bot_accounts[plat] = account
|
||||
|
||||
# bot_account = bot_accounts.get(platform)
|
||||
# return bool(bot_account and user_id == bot_account)
|
||||
|
||||
|
||||
# def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
# """
|
||||
# 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。
|
||||
|
||||
# 期望的 JSON 结构:
|
||||
# [
|
||||
# {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
|
||||
# {"content": "词条", "source_id": "12"}, // 黑话
|
||||
# ...
|
||||
# ]
|
||||
|
||||
# Returns:
|
||||
# Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
|
||||
# 第一个列表是表达方式 (situation, style, source_id)
|
||||
# 第二个列表是黑话 (content, source_id)
|
||||
# """
|
||||
# if not response:
|
||||
# return [], []
|
||||
|
||||
# raw = response.strip()
|
||||
|
||||
# # 尝试提取 ```json 代码块
|
||||
# json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
# match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
# if match:
|
||||
# raw = match.group(1).strip()
|
||||
# else:
|
||||
# # 去掉可能存在的通用 ``` 包裹
|
||||
# raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
# raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
# raw = raw.strip()
|
||||
|
||||
# parsed = None
|
||||
# expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
|
||||
# jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
|
||||
|
||||
# try:
|
||||
# # 优先尝试直接解析
|
||||
# if raw.startswith("[") and raw.endswith("]"):
|
||||
# parsed = json.loads(raw)
|
||||
# else:
|
||||
# repaired = repair_json(raw)
|
||||
# if isinstance(repaired, str):
|
||||
# parsed = json.loads(repaired)
|
||||
# else:
|
||||
# parsed = repaired
|
||||
# except Exception as parse_error:
|
||||
# # 如果解析失败,尝试修复中文引号问题
|
||||
# # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
# try:
|
||||
|
||||
# def fix_chinese_quotes_in_json(text):
|
||||
# """使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
# result = []
|
||||
# i = 0
|
||||
# in_string = False
|
||||
# escape_next = False
|
||||
|
||||
# while i < len(text):
|
||||
# char = text[i]
|
||||
|
||||
# if escape_next:
|
||||
# # 当前字符是转义字符后的字符,直接添加
|
||||
# result.append(char)
|
||||
# escape_next = False
|
||||
# i += 1
|
||||
# continue
|
||||
|
||||
# if char == "\\":
|
||||
# # 转义字符
|
||||
# result.append(char)
|
||||
# escape_next = True
|
||||
# i += 1
|
||||
# continue
|
||||
|
||||
# if char == '"' and not escape_next:
|
||||
# # 遇到英文引号,切换字符串状态
|
||||
# in_string = not in_string
|
||||
# result.append(char)
|
||||
# i += 1
|
||||
# continue
|
||||
|
||||
# if in_string:
|
||||
# # 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
# if char == '"': # 中文左引号 U+201C
|
||||
# result.append('\\"')
|
||||
# elif char == '"': # 中文右引号 U+201D
|
||||
# result.append('\\"')
|
||||
# else:
|
||||
# result.append(char)
|
||||
# else:
|
||||
# # 不在字符串内,直接添加
|
||||
# result.append(char)
|
||||
|
||||
# i += 1
|
||||
|
||||
# return "".join(result)
|
||||
|
||||
# fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# # 再次尝试解析
|
||||
# if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
# parsed = json.loads(fixed_raw)
|
||||
# else:
|
||||
# repaired = repair_json(fixed_raw)
|
||||
# if isinstance(repaired, str):
|
||||
# parsed = json.loads(repaired)
|
||||
# else:
|
||||
# parsed = repaired
|
||||
# except Exception as fix_error:
|
||||
# logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
# logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
# logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
# logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
# return [], []
|
||||
|
||||
# if isinstance(parsed, dict):
|
||||
# parsed_list = [parsed]
|
||||
# elif isinstance(parsed, list):
|
||||
# parsed_list = parsed
|
||||
# else:
|
||||
# logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
# return [], []
|
||||
|
||||
# for item in parsed_list:
|
||||
# if not isinstance(item, dict):
|
||||
# continue
|
||||
|
||||
# # 检查是否是表达方式条目(有 situation 和 style)
|
||||
# situation = str(item.get("situation", "")).strip()
|
||||
# style = str(item.get("style", "")).strip()
|
||||
# source_id = str(item.get("source_id", "")).strip()
|
||||
|
||||
# if situation and style and source_id:
|
||||
# # 表达方式条目
|
||||
# expressions.append((situation, style, source_id))
|
||||
# elif item.get("content"):
|
||||
# # 黑话条目(有 content 字段)
|
||||
# content = str(item.get("content", "")).strip()
|
||||
# source_id = str(item.get("source_id", "")).strip()
|
||||
# if content and source_id:
|
||||
# jargon_entries.append((content, source_id))
|
||||
|
||||
# return expressions, jargon_entries
|
||||
Reference in New Issue
Block a user