feat:复用jargon和expression的部分代码,代码层面合并,合并配置项
缓解bot重复学习自身表达的问题 缓解单字黑话推断时消耗过高的问题 修复count过高时推断过长的问题 移除表达方式学习强度配置
This commit is contained in:
474
src/bw_learner/expression_learner.py
Normal file
474
src/bw_learner/expression_learner.py
Normal file
@@ -0,0 +1,474 @@
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_anonymous_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.bw_learner.learner_utils import filter_message_content, is_bot_message
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
# MAX_EXPRESSION_COUNT = 300
|
||||
|
||||
logger = get_logger("expressor")
|
||||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """{chat_str}
|
||||
你的名字是{bot_name},现在请你请从上面这段群聊中用户的语言风格和说话方式
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要总结SELF的发言
|
||||
3. 不要涉及具体的人名,也不要涉及具体名词
|
||||
4. 思考有没有特殊的梗,一并总结成语言风格
|
||||
5. 例子仅供参考,请严格根据群聊内容总结!!!
|
||||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
请严格以 JSON 数组的形式输出结果,每个元素为一个对象,结构如下(注意字段名):
|
||||
[
|
||||
{{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}},
|
||||
{{"situation": "CCCC", "style": "DDDD", "source_id": "7"}}
|
||||
{{"situation": "对某件事表示十分惊叹", "style": "使用 我嘞个xxxx", "source_id": "[消息编号]"}},
|
||||
{{"situation": "表示讽刺的赞同,不讲道理", "style": "对对对", "source_id": "[消息编号]"}},
|
||||
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}},
|
||||
]
|
||||
|
||||
其中:
|
||||
- situation:表示“在什么情境下”的简短概括(不超过20个字)
|
||||
- style:表示对应的语言风格或常用表达(不超过20个字)
|
||||
- source_id:该表达方式对应的“来源行编号”,即上方聊天记录中方括号里的数字(例如 [3]),请只输出数字本身,不要包含方括号
|
||||
|
||||
现在请你输出 JSON:
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils, request_type="expression.learner"
|
||||
)
|
||||
self.summary_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 学习锁,防止并发执行学习任务
|
||||
self._learning_lock = asyncio.Lock()
|
||||
|
||||
async def learn_and_store(
|
||||
self,
|
||||
messages: List[Any],
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
学习并存储表达方式
|
||||
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
num: 学习数量
|
||||
timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
random_msg = messages
|
||||
|
||||
# 学习用(开启行编号,便于溯源)
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"learn_style_prompt",
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_str=random_msg_str,
|
||||
)
|
||||
|
||||
# print(f"random_msg_str:{random_msg_str}")
|
||||
# logger.info(f"学习{type_str}的prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
|
||||
# 解析 LLM 返回的表达方式列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
return None
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
# 直接根据 source_id 在 random_msg 中溯源,获取 context
|
||||
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
|
||||
|
||||
for situation, style, source_id in expressions:
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
# 无效的来源行编号,跳过
|
||||
continue
|
||||
|
||||
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(random_msg):
|
||||
# 超出范围,跳过
|
||||
continue
|
||||
|
||||
# 当前行的原始内容
|
||||
current_msg = random_msg[line_index]
|
||||
|
||||
# 过滤掉从bot自己发言中提取到的表达方式
|
||||
if is_bot_message(current_msg):
|
||||
continue
|
||||
|
||||
context = filter_message_content(current_msg.processed_plain_text or "")
|
||||
if not context:
|
||||
continue
|
||||
|
||||
filtered_expressions.append((situation, style, context))
|
||||
|
||||
|
||||
learnt_expressions = filtered_expressions
|
||||
|
||||
if learnt_expressions is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
|
||||
# 展示学到的表达方式
|
||||
learnt_expressions_str = ""
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
_context,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 存储到数据库 Expression 表
|
||||
for (
|
||||
situation,
|
||||
style,
|
||||
context,
|
||||
) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析 LLM 返回的表达风格总结 JSON,提取 (situation, style, source_id) 元组列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},
|
||||
...
|
||||
]
|
||||
"""
|
||||
if not response:
|
||||
return []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
# 尝试提取 ```json 代码块
|
||||
json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
if match:
|
||||
raw = match.group(1).strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as parse_error:
|
||||
# 如果解析失败,尝试修复中文引号问题
|
||||
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
|
||||
try:
|
||||
def fix_chinese_quotes_in_json(text):
|
||||
"""使用状态机修复 JSON 字符串值中的中文引号"""
|
||||
result = []
|
||||
i = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
|
||||
while i < len(text):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
# 当前字符是转义字符后的字符,直接添加
|
||||
result.append(char)
|
||||
escape_next = False
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '\\':
|
||||
# 转义字符
|
||||
result.append(char)
|
||||
escape_next = True
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
# 遇到英文引号,切换字符串状态
|
||||
in_string = not in_string
|
||||
result.append(char)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if in_string:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
if char == '"': # 中文左引号 U+201C
|
||||
result.append('\\"')
|
||||
elif char == '"': # 中文右引号 U+201D
|
||||
result.append('\\"')
|
||||
else:
|
||||
result.append(char)
|
||||
else:
|
||||
# 不在字符串内,直接添加
|
||||
result.append(char)
|
||||
|
||||
i += 1
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
fixed_raw = fix_chinese_quotes_in_json(raw)
|
||||
|
||||
# 再次尝试解析
|
||||
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
|
||||
parsed = json.loads(fixed_raw)
|
||||
else:
|
||||
repaired = repair_json(fixed_raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception as fix_error:
|
||||
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
|
||||
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}")
|
||||
return []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if not situation or not style or not source_id:
|
||||
# 三个字段必须同时存在
|
||||
continue
|
||||
expressions.append((situation, style, source_id))
|
||||
|
||||
return expressions
|
||||
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
过滤掉style与机器人名称/昵称重复的表达
|
||||
"""
|
||||
banned_names = set()
|
||||
bot_nickname = (global_config.bot.nickname or "").strip()
|
||||
if bot_nickname:
|
||||
banned_names.add(bot_nickname)
|
||||
|
||||
alias_names = global_config.bot.alias_names or []
|
||||
for alias in alias_names:
|
||||
alias = alias.strip()
|
||||
if alias:
|
||||
banned_names.add(alias)
|
||||
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
filtered: List[Tuple[str, str, str]] = []
|
||||
removed_count = 0
|
||||
for situation, style, source_id in expressions:
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() not in banned_casefold:
|
||||
filtered.append((situation, style, source_id))
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
if removed_count:
|
||||
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
|
||||
|
||||
return filtered
|
||||
|
||||
async def _upsert_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||
|
||||
if expr_obj:
|
||||
await self._update_existing_expression(
|
||||
expr_obj=expr_obj,
|
||||
situation=situation,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
return
|
||||
|
||||
await self._create_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
async def _create_expression_record(
|
||||
self,
|
||||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = [situation]
|
||||
formatted_situation = await self._compose_situation_text(content_list, 1, situation)
|
||||
|
||||
Expression.create(
|
||||
situation=formatted_situation,
|
||||
style=style,
|
||||
content_list=json.dumps(content_list, ensure_ascii=False),
|
||||
count=1,
|
||||
last_active_time=current_time,
|
||||
chat_id=self.chat_id,
|
||||
create_date=current_time,
|
||||
context=context,
|
||||
)
|
||||
|
||||
async def _update_existing_expression(
|
||||
self,
|
||||
expr_obj: Expression,
|
||||
situation: str,
|
||||
context: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = self._parse_content_list(expr_obj.content_list)
|
||||
content_list.append(situation)
|
||||
|
||||
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
|
||||
expr_obj.count = (expr_obj.count or 0) + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.context = context
|
||||
|
||||
new_situation = await self._compose_situation_text(
|
||||
content_list=content_list,
|
||||
count=expr_obj.count,
|
||||
fallback=expr_obj.situation,
|
||||
)
|
||||
expr_obj.situation = new_situation
|
||||
|
||||
expr_obj.save()
|
||||
|
||||
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
|
||||
if not stored_list:
|
||||
return []
|
||||
try:
|
||||
data = json.loads(stored_list)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
|
||||
|
||||
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
|
||||
sanitized = [c.strip() for c in content_list if c.strip()]
|
||||
summary = await self._summarize_situations(sanitized)
|
||||
if summary:
|
||||
return summary
|
||||
return "/".join(sanitized) if sanitized else fallback
|
||||
|
||||
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
|
||||
if not situations:
|
||||
return None
|
||||
|
||||
prompt = (
|
||||
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
|
||||
"长度不超过20个字,保留共同特点:\n"
|
||||
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
|
||||
summary = summary.strip()
|
||||
if summary:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"概括表达情境失败: {e}")
|
||||
return None
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
class ExpressionLearnerManager:
|
||||
def __init__(self):
|
||||
self.expression_learners = {}
|
||||
|
||||
self._ensure_expression_directories()
|
||||
|
||||
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
|
||||
if chat_id not in self.expression_learners:
|
||||
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
|
||||
return self.expression_learners[chat_id]
|
||||
|
||||
def _ensure_expression_directories(self):
|
||||
"""
|
||||
确保表达方式相关的目录结构存在
|
||||
"""
|
||||
base_dir = os.path.join("data", "expression")
|
||||
directories_to_create = [
|
||||
base_dir,
|
||||
os.path.join(base_dir, "learnt_style"),
|
||||
os.path.join(base_dir, "learnt_grammar"),
|
||||
]
|
||||
|
||||
for directory in directories_to_create:
|
||||
try:
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.debug(f"确保目录存在: {directory}")
|
||||
except Exception as e:
|
||||
logger.error(f"创建目录失败 {directory}: {e}")
|
||||
|
||||
|
||||
expression_learner_manager = ExpressionLearnerManager()
|
||||
252
src/bw_learner/expression_reflector.py
Normal file
252
src/bw_learner/expression_reflector.py
Normal file
@@ -0,0 +1,252 @@
|
||||
import random
|
||||
import time
|
||||
from typing import Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("expression_reflector")
|
||||
|
||||
|
||||
class ExpressionReflector:
|
||||
"""表达反思器,管理单个聊天流的表达反思提问"""
|
||||
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.last_ask_time: float = 0.0
|
||||
|
||||
async def check_and_ask(self) -> bool:
|
||||
"""
|
||||
检查是否需要提问表达反思,如果需要则提问
|
||||
|
||||
Returns:
|
||||
bool: 是否执行了提问
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
|
||||
|
||||
if not global_config.expression.reflect:
|
||||
logger.debug("[Expression Reflection] 表达反思功能未启用,跳过")
|
||||
return False
|
||||
|
||||
operator_config = global_config.expression.reflect_operator_id
|
||||
if not operator_config:
|
||||
logger.debug("[Expression Reflection] Operator ID 未配置,跳过")
|
||||
return False
|
||||
|
||||
# 检查是否在允许列表中
|
||||
allow_reflect = global_config.expression.allow_reflect
|
||||
if allow_reflect:
|
||||
# 将 allow_reflect 中的 platform:id:type 格式转换为 chat_id 列表
|
||||
allow_reflect_chat_ids = []
|
||||
for stream_config in allow_reflect:
|
||||
parsed_chat_id = global_config.expression._parse_stream_config_to_chat_id(stream_config)
|
||||
if parsed_chat_id:
|
||||
allow_reflect_chat_ids.append(parsed_chat_id)
|
||||
else:
|
||||
logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}")
|
||||
|
||||
if self.chat_id not in allow_reflect_chat_ids:
|
||||
logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过")
|
||||
return False
|
||||
|
||||
# 检查上一次提问时间
|
||||
current_time = time.time()
|
||||
time_since_last_ask = current_time - self.last_ask_time
|
||||
|
||||
# 5-10分钟间隔,随机选择
|
||||
min_interval = 10 * 60 # 5分钟
|
||||
max_interval = 15 * 60 # 10分钟
|
||||
interval = random.uniform(min_interval, max_interval)
|
||||
|
||||
logger.info(
|
||||
f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)"
|
||||
)
|
||||
|
||||
if time_since_last_ask < interval:
|
||||
remaining_time = interval - time_since_last_ask
|
||||
logger.info(
|
||||
f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过"
|
||||
)
|
||||
return False
|
||||
|
||||
# 检查是否已经有针对该 Operator 的 Tracker 在运行
|
||||
logger.info(f"[Expression Reflection] 检查 Operator {operator_config} 是否已有活跃的 Tracker")
|
||||
if await _check_tracker_exists(operator_config):
|
||||
logger.info(f"[Expression Reflection] Operator {operator_config} 已有活跃的 Tracker,跳过本次提问")
|
||||
return False
|
||||
|
||||
# 获取未检查的表达
|
||||
try:
|
||||
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||
expressions = (
|
||||
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||
)
|
||||
|
||||
expr_list = list(expressions)
|
||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||
|
||||
if not expr_list:
|
||||
logger.info("[Expression Reflection] 没有可用的表达,跳过")
|
||||
return False
|
||||
|
||||
target_expr: Expression = random.choice(expr_list)
|
||||
logger.info(
|
||||
f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}"
|
||||
)
|
||||
|
||||
# 生成询问文本
|
||||
ask_text = _generate_ask_text(target_expr)
|
||||
if not ask_text:
|
||||
logger.warning("[Expression Reflection] 生成询问文本失败,跳过")
|
||||
return False
|
||||
|
||||
logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问")
|
||||
# 发送给 Operator
|
||||
await _send_to_operator(operator_config, ask_text, target_expr)
|
||||
|
||||
# 更新上一次提问时间
|
||||
self.last_ask_time = current_time
|
||||
logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
|
||||
class ExpressionReflectorManager:
|
||||
"""表达反思管理器,管理多个聊天流的表达反思实例"""
|
||||
|
||||
def __init__(self):
|
||||
self.reflectors: Dict[str, ExpressionReflector] = {}
|
||||
|
||||
def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector:
|
||||
"""获取或创建指定聊天流的表达反思实例"""
|
||||
if chat_id not in self.reflectors:
|
||||
self.reflectors[chat_id] = ExpressionReflector(chat_id)
|
||||
return self.reflectors[chat_id]
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
expression_reflector_manager = ExpressionReflectorManager()
|
||||
|
||||
|
||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||
from src.express.reflect_tracker import reflect_tracker_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = None
|
||||
|
||||
# 尝试解析配置字符串 "platform:id:type"
|
||||
parts = operator_config.split(":")
|
||||
if len(parts) == 3:
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
user_info = None
|
||||
group_info = None
|
||||
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
if stream_type == "group":
|
||||
group_info = GroupInfo(group_id=id_str, platform=platform)
|
||||
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||
elif stream_type == "private":
|
||||
user_info = UserInfo(user_id=id_str, platform=platform, user_nickname="Operator")
|
||||
else:
|
||||
return False
|
||||
|
||||
if user_info:
|
||||
try:
|
||||
chat_stream = await chat_manager.get_or_create_stream(platform, user_info, group_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get or create chat stream for checking tracker: {e}")
|
||||
return False
|
||||
else:
|
||||
chat_stream = chat_manager.get_stream(operator_config)
|
||||
|
||||
if not chat_stream:
|
||||
return False
|
||||
|
||||
return reflect_tracker_manager.get_tracker(chat_stream.stream_id) is not None
|
||||
|
||||
|
||||
def _generate_ask_text(expr: Expression) -> Optional[str]:
|
||||
try:
|
||||
ask_text = (
|
||||
f"我正在学习新的表达方式,请帮我看看这个是否合适?\n\n"
|
||||
f"**学习到的表达信息**\n"
|
||||
f"- 情景 (Situation): {expr.situation}\n"
|
||||
f"- 风格 (Style): {expr.style}\n"
|
||||
)
|
||||
return ask_text
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate ask text: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = None
|
||||
|
||||
# 尝试解析配置字符串 "platform:id:type"
|
||||
parts = operator_config.split(":")
|
||||
if len(parts) == 3:
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
|
||||
user_info = None
|
||||
group_info = None
|
||||
|
||||
from maim_message import UserInfo, GroupInfo
|
||||
|
||||
if stream_type == "group":
|
||||
group_info = GroupInfo(group_id=id_str, platform=platform)
|
||||
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||
elif stream_type == "private":
|
||||
user_info = UserInfo(user_id=id_str, platform=platform, user_nickname="Operator")
|
||||
else:
|
||||
logger.warning(f"Unknown stream type in operator config: {stream_type}")
|
||||
return
|
||||
|
||||
if user_info:
|
||||
try:
|
||||
chat_stream = await chat_manager.get_or_create_stream(platform, user_info, group_info)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get or create chat stream for operator {operator_config}: {e}")
|
||||
return
|
||||
else:
|
||||
chat_stream = chat_manager.get_stream(operator_config)
|
||||
|
||||
if not chat_stream:
|
||||
logger.warning(f"Could not find or create chat stream for operator: {operator_config}")
|
||||
return
|
||||
|
||||
stream_id = chat_stream.stream_id
|
||||
|
||||
# 注册 Tracker
|
||||
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||
|
||||
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||
|
||||
# 发送消息
|
||||
await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True)
|
||||
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")
|
||||
334
src/bw_learner/expression_selector.py
Normal file
334
src/bw_learner/expression_selector.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.bw_learner.learner_utils import weighted_sample
|
||||
|
||||
logger = get_logger("expression_selector")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
expression_evaluation_prompt = """{chat_observe_info}
|
||||
|
||||
你的名字是{bot_name}{target_message}
|
||||
{reply_reason_block}
|
||||
|
||||
以下是可选的表达情境:
|
||||
{all_situations}
|
||||
|
||||
请你分析聊天内容的语境、情绪、话题类型,从上述情境中选择最适合当前聊天情境的,最多{max_num}个情境。
|
||||
考虑因素包括:
|
||||
1.聊天的情绪氛围(轻松、严肃、幽默等)
|
||||
2.话题类型(日常、技术、游戏、情感等)
|
||||
3.情境与当前语境的匹配度
|
||||
{target_message_extra_block}
|
||||
|
||||
请以JSON格式输出,只需要输出选中的情境编号:
|
||||
例如:
|
||||
{{
|
||||
"selected_situations": [2, 3, 5, 7, 19]
|
||||
}}
|
||||
|
||||
请严格按照JSON格式输出,不要包含其他内容:
|
||||
"""
|
||||
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
|
||||
|
||||
|
||||
class ExpressionSelector:
|
||||
def __init__(self):
|
||||
self.llm_model = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
|
||||
)
|
||||
|
||||
def can_use_expression_for_chat(self, chat_id: str) -> bool:
|
||||
"""
|
||||
检查指定聊天流是否允许使用表达
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许使用表达
|
||||
"""
|
||||
try:
|
||||
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(chat_id)
|
||||
return use_expression
|
||||
except Exception as e:
|
||||
logger.error(f"检查表达使用权限失败: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
"""解析'platform:id:type'为chat_id(与get_stream_id一致)"""
|
||||
try:
|
||||
parts = stream_config_str.split(":")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
platform = parts[0]
|
||||
id_str = parts[1]
|
||||
stream_type = parts[2]
|
||||
is_group = stream_type == "group"
|
||||
if is_group:
|
||||
components = [platform, str(id_str)]
|
||||
else:
|
||||
components = [platform, str(id_str), "private"]
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_related_chat_ids(self, chat_id: str) -> List[str]:
|
||||
"""根据expression_groups配置,获取与当前chat_id相关的所有chat_id(包括自身)"""
|
||||
groups = global_config.expression.expression_groups
|
||||
|
||||
# 检查是否存在全局共享组(包含"*"的组)
|
||||
global_group_exists = any("*" in group for group in groups)
|
||||
|
||||
if global_group_exists:
|
||||
# 如果存在全局共享组,则返回所有可用的chat_id
|
||||
all_chat_ids = set()
|
||||
for group in groups:
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
all_chat_ids.add(chat_id_candidate)
|
||||
return list(all_chat_ids) if all_chat_ids else [chat_id]
|
||||
|
||||
# 否则使用现有的组逻辑
|
||||
for group in groups:
|
||||
group_chat_ids = []
|
||||
for stream_config_str in group:
|
||||
if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
|
||||
group_chat_ids.append(chat_id_candidate)
|
||||
if chat_id in group_chat_ids:
|
||||
return group_chat_ids
|
||||
return [chat_id]
|
||||
|
||||
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择表达方式
|
||||
|
||||
Args:
|
||||
chat_id: 聊天室ID
|
||||
total_num: 需要选择的数量
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 随机选择的表达方式列表
|
||||
"""
|
||||
try:
|
||||
# 支持多chat_id合并抽选
|
||||
related_chat_ids = self.get_related_chat_ids(chat_id)
|
||||
|
||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||
style_query = Expression.select().where(
|
||||
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||
)
|
||||
|
||||
style_exprs = [
|
||||
{
|
||||
"id": expr.id,
|
||||
"situation": expr.situation,
|
||||
"style": expr.style,
|
||||
"last_active_time": expr.last_active_time,
|
||||
"source_id": expr.chat_id,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
"count": expr.count if getattr(expr, "count", None) is not None else 1,
|
||||
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
|
||||
}
|
||||
for expr in style_query
|
||||
]
|
||||
|
||||
# 随机抽样
|
||||
if style_exprs:
|
||||
selected_style = weighted_sample(style_exprs, total_num)
|
||||
else:
|
||||
selected_style = []
|
||||
|
||||
return selected_style
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"随机选择表达方式失败: {e}")
|
||||
return []
|
||||
|
||||
async def select_suitable_expressions(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
选择适合的表达方式(使用classic模式:随机选择+LLM选择)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
if not self.can_use_expression_for_chat(chat_id):
|
||||
logger.debug(f"聊天流 {chat_id} 不允许使用表达,返回空列表")
|
||||
return [], []
|
||||
|
||||
# 使用classic模式(随机选择+LLM选择)
|
||||
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
|
||||
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason)
|
||||
|
||||
async def _select_expressions_classic(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_info: str,
|
||||
max_num: int = 10,
|
||||
target_message: Optional[str] = None,
|
||||
reply_reason: Optional[str] = None,
|
||||
) -> Tuple[List[Dict[str, Any]], List[int]]:
|
||||
"""
|
||||
classic模式:随机选择+LLM选择
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
chat_info: 聊天内容信息
|
||||
max_num: 最大选择数量
|
||||
target_message: 目标消息内容
|
||||
reply_reason: planner给出的回复理由
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
|
||||
"""
|
||||
try:
|
||||
# 1. 使用随机抽样选择表达方式
|
||||
style_exprs = self._random_expressions(chat_id, 20)
|
||||
|
||||
if len(style_exprs) < 10:
|
||||
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
|
||||
return [], []
|
||||
|
||||
# 2. 构建所有表达方式的索引和情境列表
|
||||
all_expressions: List[Dict[str, Any]] = []
|
||||
all_situations: List[str] = []
|
||||
|
||||
# 添加style表达方式
|
||||
for expr in style_exprs:
|
||||
expr = expr.copy()
|
||||
all_expressions.append(expr)
|
||||
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
|
||||
|
||||
if not all_expressions:
|
||||
logger.warning("没有找到可用的表达方式")
|
||||
return [], []
|
||||
|
||||
all_situations_str = "\n".join(all_situations)
|
||||
|
||||
if target_message:
|
||||
target_message_str = f",现在你想要对这条消息进行回复:“{target_message}”"
|
||||
target_message_extra_block = "4.考虑你要回复的目标消息"
|
||||
else:
|
||||
target_message_str = ""
|
||||
target_message_extra_block = ""
|
||||
|
||||
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
|
||||
|
||||
# 构建reply_reason块
|
||||
if reply_reason:
|
||||
reply_reason_block = f"你的回复理由是:{reply_reason}"
|
||||
chat_context = ""
|
||||
else:
|
||||
reply_reason_block = ""
|
||||
|
||||
# 3. 构建prompt(只包含情境,不包含完整的表达方式)
|
||||
prompt = (await global_prompt_manager.get_prompt_async("expression_evaluation_prompt")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_observe_info=chat_context,
|
||||
all_situations=all_situations_str,
|
||||
max_num=max_num,
|
||||
target_message=target_message_str,
|
||||
target_message_extra_block=target_message_extra_block,
|
||||
reply_reason_block=reply_reason_block,
|
||||
)
|
||||
|
||||
# 4. 调用LLM
|
||||
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# print(prompt)
|
||||
|
||||
if not content:
|
||||
logger.warning("LLM返回空结果")
|
||||
return [], []
|
||||
|
||||
# 5. 解析结果
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
|
||||
if not isinstance(result, dict) or "selected_situations" not in result:
|
||||
logger.error("LLM返回格式错误")
|
||||
logger.info(f"LLM返回结果: \n{content}")
|
||||
return [], []
|
||||
|
||||
selected_indices = result["selected_situations"]
|
||||
|
||||
# 根据索引获取完整的表达方式
|
||||
valid_expressions: List[Dict[str, Any]] = []
|
||||
selected_ids = []
|
||||
for idx in selected_indices:
|
||||
if isinstance(idx, int) and 1 <= idx <= len(all_expressions):
|
||||
expression = all_expressions[idx - 1] # 索引从1开始
|
||||
selected_ids.append(expression["id"])
|
||||
valid_expressions.append(expression)
|
||||
|
||||
# 对选中的所有表达方式,更新last_active_time
|
||||
if valid_expressions:
|
||||
self.update_expressions_last_active_time(valid_expressions)
|
||||
|
||||
logger.debug(f"从{len(all_expressions)}个情境中选择了{len(valid_expressions)}个")
|
||||
return valid_expressions, selected_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"classic模式处理表达方式选择时出错: {e}")
|
||||
return [], []
|
||||
|
||||
def update_expressions_last_active_time(self, expressions_to_update: List[Dict[str, Any]]):
|
||||
"""对一批表达方式更新last_active_time"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
updates_by_key = {}
|
||||
for expr in expressions_to_update:
|
||||
source_id: str = expr.get("source_id") # type: ignore
|
||||
situation: str = expr.get("situation") # type: ignore
|
||||
style: str = expr.get("style") # type: ignore
|
||||
if not source_id or not situation or not style:
|
||||
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
|
||||
continue
|
||||
key = (source_id, situation, style)
|
||||
if key not in updates_by_key:
|
||||
updates_by_key[key] = expr
|
||||
for chat_id, situation, style in updates_by_key:
|
||||
query = Expression.select().where(
|
||||
(Expression.chat_id == chat_id) & (Expression.situation == situation) & (Expression.style == style)
|
||||
)
|
||||
if query.exists():
|
||||
expr_obj = query.get()
|
||||
expr_obj.last_active_time = time.time()
|
||||
expr_obj.save()
|
||||
logger.debug("表达方式激活: 更新last_active_time in db")
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
try:
|
||||
expression_selector = ExpressionSelector()
|
||||
except Exception as e:
|
||||
logger.error(f"ExpressionSelector初始化失败: {e}")
|
||||
360
src/bw_learner/jargon_explainer.py
Normal file
360
src/bw_learner/jargon_explainer.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import re
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.bw_learner.jargon_miner import search_jargon
|
||||
from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def _init_explainer_prompts() -> None:
|
||||
"""初始化黑话解释器相关的prompt"""
|
||||
# Prompt:概括黑话解释结果
|
||||
summarize_prompt_str = """上下文聊天内容:
|
||||
{chat_context}
|
||||
|
||||
在上下文中提取到的黑话及其含义:
|
||||
{jargon_explanations}
|
||||
|
||||
请根据上述信息,对黑话解释进行概括和整理。
|
||||
- 如果上下文中有黑话出现,请简要说明这些黑话在上下文中的使用情况
|
||||
- 将所有黑话解释整理成简洁、易读的一段话
|
||||
- 输出格式要自然,适合作为回复参考信息
|
||||
请输出概括后的黑话解释(直接输出一段平文本,不要标题,无特殊格式或markdown格式,不要使用JSON格式):
|
||||
"""
|
||||
Prompt(summarize_prompt_str, "jargon_explainer_summarize_prompt")
|
||||
|
||||
|
||||
_init_explainer_prompts()
|
||||
|
||||
|
||||
class JargonExplainer:
|
||||
"""黑话解释器,用于在回复前识别和解释上下文中的黑话"""
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.explain",
|
||||
)
|
||||
|
||||
def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
通过直接匹配数据库中的jargon字符串来提取黑话
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: 提取到的黑话列表,每个元素包含content
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# 收集所有消息的文本内容
|
||||
message_texts: List[str] = []
|
||||
for msg in messages:
|
||||
# 跳过机器人自己的消息
|
||||
if is_bot_message(msg):
|
||||
continue
|
||||
|
||||
msg_text = (
|
||||
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||
).strip()
|
||||
if msg_text:
|
||||
message_texts.append(msg_text)
|
||||
|
||||
if not message_texts:
|
||||
return []
|
||||
|
||||
# 合并所有消息文本
|
||||
combined_text = " ".join(message_texts)
|
||||
|
||||
# 查询所有有meaning的jargon记录
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:只查询is_global=True的记录
|
||||
query = query.where(Jargon.is_global)
|
||||
else:
|
||||
# 关闭all_global:查询is_global=True或chat_id列表包含当前chat_id的记录
|
||||
# 这里先查询所有,然后在Python层面过滤
|
||||
pass
|
||||
|
||||
# 按count降序排序,优先匹配出现频率高的
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
# 执行查询并匹配
|
||||
matched_jargon: Dict[str, Dict[str, str]] = {}
|
||||
query_time = time.time()
|
||||
|
||||
for jargon in query:
|
||||
content = jargon.content or ""
|
||||
if not content or not content.strip():
|
||||
continue
|
||||
|
||||
# 跳过包含机器人昵称的词条
|
||||
if contains_bot_self_name(content):
|
||||
continue
|
||||
|
||||
# 检查chat_id(如果all_global=False)
|
||||
if not global_config.expression.all_global_jargon:
|
||||
if jargon.is_global:
|
||||
# 全局黑话,包含
|
||||
pass
|
||||
else:
|
||||
# 检查chat_id列表是否包含当前chat_id
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
if not chat_id_list_contains(chat_id_list, self.chat_id):
|
||||
continue
|
||||
|
||||
# 在文本中查找匹配(大小写不敏感)
|
||||
pattern = re.escape(content)
|
||||
# 使用单词边界或中文字符边界来匹配,避免部分匹配
|
||||
# 对于中文,使用Unicode字符类;对于英文,使用单词边界
|
||||
if re.search(r"[\u4e00-\u9fff]", content):
|
||||
# 包含中文,使用更宽松的匹配
|
||||
search_pattern = pattern
|
||||
else:
|
||||
# 纯英文/数字,使用单词边界
|
||||
search_pattern = r"\b" + pattern + r"\b"
|
||||
|
||||
if re.search(search_pattern, combined_text, re.IGNORECASE):
|
||||
# 找到匹配,记录(去重)
|
||||
if content not in matched_jargon:
|
||||
matched_jargon[content] = {"content": content}
|
||||
|
||||
match_time = time.time()
|
||||
total_time = match_time - start_time
|
||||
query_duration = query_time - start_time
|
||||
match_duration = match_time - query_time
|
||||
|
||||
logger.debug(
|
||||
f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
|
||||
f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
|
||||
)
|
||||
|
||||
return list(matched_jargon.values())
|
||||
|
||||
async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||
"""
|
||||
解释上下文中的黑话
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
chat_context: 聊天上下文的文本表示
|
||||
|
||||
Returns:
|
||||
Optional[str]: 黑话解释的概括文本,如果没有黑话则返回None
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 直接匹配方式:从数据库中查询jargon并在消息中匹配
|
||||
jargon_entries = self.match_jargon_from_messages(messages)
|
||||
|
||||
if not jargon_entries:
|
||||
return None
|
||||
|
||||
# 去重(按content)
|
||||
unique_jargon: Dict[str, Dict[str, str]] = {}
|
||||
for entry in jargon_entries:
|
||||
content = entry["content"]
|
||||
if content not in unique_jargon:
|
||||
unique_jargon[content] = entry
|
||||
|
||||
jargon_list = list(unique_jargon.values())
|
||||
logger.info(f"从上下文中提取到 {len(jargon_list)} 个黑话: {[j['content'] for j in jargon_list]}")
|
||||
|
||||
# 查询每个黑话的含义
|
||||
jargon_explanations: List[str] = []
|
||||
for entry in jargon_list:
|
||||
content = entry["content"]
|
||||
|
||||
# 根据是否开启全局黑话,决定查询方式
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启全局黑话:查询所有is_global=True的记录
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
chat_id=None, # 不指定chat_id,查询全局黑话
|
||||
limit=1,
|
||||
case_sensitive=False,
|
||||
fuzzy=False, # 精确匹配
|
||||
)
|
||||
else:
|
||||
# 关闭全局黑话:优先查询当前聊天或全局的黑话
|
||||
results = search_jargon(
|
||||
keyword=content,
|
||||
chat_id=self.chat_id,
|
||||
limit=1,
|
||||
case_sensitive=False,
|
||||
fuzzy=False, # 精确匹配
|
||||
)
|
||||
|
||||
if results and len(results) > 0:
|
||||
meaning = results[0].get("meaning", "").strip()
|
||||
if meaning:
|
||||
jargon_explanations.append(f"- {content}: {meaning}")
|
||||
else:
|
||||
logger.info(f"黑话 {content} 没有找到含义")
|
||||
else:
|
||||
logger.info(f"黑话 {content} 未在数据库中找到")
|
||||
|
||||
if not jargon_explanations:
|
||||
logger.info("没有找到任何黑话的含义,跳过解释")
|
||||
return None
|
||||
|
||||
# 拼接所有黑话解释
|
||||
explanations_text = "\n".join(jargon_explanations)
|
||||
|
||||
# 使用LLM概括黑话解释
|
||||
summarize_prompt = await global_prompt_manager.format_prompt(
|
||||
"jargon_explainer_summarize_prompt",
|
||||
chat_context=chat_context,
|
||||
jargon_explanations=explanations_text,
|
||||
)
|
||||
|
||||
summary, _ = await self.llm.generate_response_async(summarize_prompt, temperature=0.3)
|
||||
if not summary:
|
||||
# 如果LLM概括失败,直接返回原始解释
|
||||
return f"上下文中的黑话解释:\n{explanations_text}"
|
||||
|
||||
summary = summary.strip()
|
||||
if not summary:
|
||||
return f"上下文中的黑话解释:\n{explanations_text}"
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||
"""
|
||||
解释上下文中的黑话(便捷函数)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
messages: 消息列表
|
||||
chat_context: 聊天上下文的文本表示
|
||||
|
||||
Returns:
|
||||
Optional[str]: 黑话解释的概括文本,如果没有黑话则返回None
|
||||
"""
|
||||
explainer = JargonExplainer(chat_id)
|
||||
return await explainer.explain_jargon(messages, chat_context)
|
||||
|
||||
|
||||
def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||
"""直接在聊天文本中匹配已知的jargon,返回出现过的黑话列表
|
||||
|
||||
Args:
|
||||
chat_text: 要匹配的聊天文本
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
List[str]: 匹配到的黑话列表
|
||||
"""
|
||||
if not chat_text or not chat_text.strip():
|
||||
return []
|
||||
|
||||
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||
if global_config.expression.all_global_jargon:
|
||||
query = query.where(Jargon.is_global)
|
||||
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
matched: Dict[str, None] = {}
|
||||
|
||||
for jargon in query:
|
||||
content = (jargon.content or "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if not global_config.expression.all_global_jargon and not jargon.is_global:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
|
||||
pattern = re.escape(content)
|
||||
if re.search(r"[\u4e00-\u9fff]", content):
|
||||
search_pattern = pattern
|
||||
else:
|
||||
search_pattern = r"\b" + pattern + r"\b"
|
||||
|
||||
if re.search(search_pattern, chat_text, re.IGNORECASE):
|
||||
matched[content] = None
|
||||
|
||||
logger.info(f"匹配到 {len(matched)} 个黑话")
|
||||
|
||||
return list(matched.keys())
|
||||
|
||||
|
||||
async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> str:
|
||||
"""对概念列表进行jargon检索
|
||||
|
||||
Args:
|
||||
concepts: 概念列表
|
||||
chat_id: 聊天ID
|
||||
|
||||
Returns:
|
||||
str: 检索结果字符串
|
||||
"""
|
||||
if not concepts:
|
||||
return ""
|
||||
|
||||
results = []
|
||||
exact_matches = [] # 收集所有精确匹配的概念
|
||||
for concept in concepts:
|
||||
concept = concept.strip()
|
||||
if not concept:
|
||||
continue
|
||||
|
||||
# 先尝试精确匹配
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=False)
|
||||
|
||||
is_fuzzy_match = False
|
||||
|
||||
# 如果精确匹配未找到,尝试模糊搜索
|
||||
if not jargon_results:
|
||||
jargon_results = search_jargon(keyword=concept, chat_id=chat_id, limit=10, case_sensitive=False, fuzzy=True)
|
||||
is_fuzzy_match = True
|
||||
|
||||
if jargon_results:
|
||||
# 找到结果
|
||||
if is_fuzzy_match:
|
||||
# 模糊匹配
|
||||
output_parts = [f"未精确匹配到'{concept}'"]
|
||||
for result in jargon_results:
|
||||
found_content = result.get("content", "").strip()
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if found_content and meaning:
|
||||
output_parts.append(f"找到 '{found_content}' 的含义为:{meaning}")
|
||||
results.append(",".join(output_parts))
|
||||
logger.info(f"在jargon库中找到匹配(模糊搜索): {concept},找到{len(jargon_results)}条结果")
|
||||
else:
|
||||
# 精确匹配
|
||||
output_parts = []
|
||||
for result in jargon_results:
|
||||
meaning = result.get("meaning", "").strip()
|
||||
if meaning:
|
||||
output_parts.append(f"'{concept}' 为黑话或者网络简写,含义为:{meaning}")
|
||||
results.append(";".join(output_parts) if len(output_parts) > 1 else output_parts[0])
|
||||
exact_matches.append(concept) # 收集精确匹配的概念,稍后统一打印
|
||||
else:
|
||||
# 未找到,不返回占位信息,只记录日志
|
||||
logger.info(f"在jargon库中未找到匹配: {concept}")
|
||||
|
||||
# 合并所有精确匹配的日志
|
||||
if exact_matches:
|
||||
logger.info(f"找到黑话: {', '.join(exact_matches)},共找到{len(exact_matches)}条结果")
|
||||
|
||||
if results:
|
||||
return "【概念检索结果】\n" + "\n".join(results) + "\n"
|
||||
return ""
|
||||
822
src/bw_learner/jargon_miner.py
Normal file
822
src/bw_learner/jargon_miner.py
Normal file
@@ -0,0 +1,822 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
from collections import OrderedDict
|
||||
from typing import List, Dict, Optional, Any
|
||||
from json_repair import repair_json
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.bw_learner.learner_utils import (
|
||||
is_bot_message,
|
||||
build_context_paragraph,
|
||||
contains_bot_self_name,
|
||||
parse_chat_id_list,
|
||||
chat_id_list_contains,
|
||||
update_chat_id_list,
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger("jargon")
|
||||
|
||||
|
||||
def _is_single_char_jargon(content: str) -> bool:
|
||||
"""
|
||||
判断是否是单字黑话(单个汉字、英文或数字)
|
||||
|
||||
Args:
|
||||
content: 词条内容
|
||||
|
||||
Returns:
|
||||
bool: 如果是单字黑话返回True,否则返回False
|
||||
"""
|
||||
if not content or len(content) != 1:
|
||||
return False
|
||||
|
||||
char = content[0]
|
||||
# 判断是否是单个汉字、单个英文字母或单个数字
|
||||
return (
|
||||
'\u4e00' <= char <= '\u9fff' or # 汉字
|
||||
'a' <= char <= 'z' or # 小写字母
|
||||
'A' <= char <= 'Z' or # 大写字母
|
||||
'0' <= char <= '9' # 数字
|
||||
)
|
||||
|
||||
|
||||
def _init_prompt() -> None:
|
||||
prompt_str = """
|
||||
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
||||
{chat_str}
|
||||
|
||||
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。
|
||||
- 必须为对话中真实出现过的短词或短语
|
||||
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语
|
||||
- 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等)
|
||||
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
|
||||
|
||||
黑话必须为以下几种类型:
|
||||
- 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl
|
||||
- 英文词语的缩写,用英文字母概括一个词汇或含义,例如:CPU、GPU、API
|
||||
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如:社死、内卷
|
||||
|
||||
以 JSON 数组输出,元素为对象(严格按以下结构):
|
||||
请你提取出可能的黑话,最多30个黑话,请尽量提取所有
|
||||
[
|
||||
{{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
|
||||
{{"content": "词条2", "msg_id": "m15"}}
|
||||
]
|
||||
|
||||
现在请输出:
|
||||
"""
|
||||
Prompt(prompt_str, "extract_jargon_prompt")
|
||||
|
||||
|
||||
def _init_inference_prompts() -> None:
|
||||
"""初始化含义推断相关的prompt"""
|
||||
# Prompt 1: 基于raw_content和content推断
|
||||
prompt1_str = """
|
||||
**词条内容**
|
||||
{content}
|
||||
**词条出现的上下文。其中的{bot_name}的发言内容是你自己的发言**
|
||||
{raw_content_list}
|
||||
{previous_meaning_section}
|
||||
|
||||
请根据上下文,推断"{content}"这个词条的含义。
|
||||
- 如果这是一个黑话、俚语或网络用语,请推断其含义
|
||||
- 如果含义明确(常规词汇),也请说明
|
||||
- {bot_name} 的发言内容可能包含错误,请不要参考其发言内容
|
||||
- 如果上下文信息不足,无法推断含义,请设置 no_info 为 true
|
||||
{previous_meaning_instruction}
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)",
|
||||
"no_info": false
|
||||
}}
|
||||
注意:如果信息不足无法推断,请设置 "no_info": true,此时 meaning 可以为空字符串
|
||||
"""
|
||||
Prompt(prompt1_str, "jargon_inference_with_context_prompt")
|
||||
|
||||
# Prompt 2: 仅基于content推断
|
||||
prompt2_str = """
|
||||
**词条内容**
|
||||
{content}
|
||||
|
||||
请仅根据这个词条本身,推断其含义。
|
||||
- 如果这是一个黑话、俚语或网络用语,请推断其含义
|
||||
- 如果含义明确(常规词汇),也请说明
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
"meaning": "详细含义说明(包含使用场景、来源、具体解释等)"
|
||||
}}
|
||||
"""
|
||||
Prompt(prompt2_str, "jargon_inference_content_only_prompt")
|
||||
|
||||
# Prompt 3: 比较两个推断结果
|
||||
prompt3_str = """
|
||||
**推断结果1(基于上下文)**
|
||||
{inference1}
|
||||
|
||||
**推断结果2(仅基于词条)**
|
||||
{inference2}
|
||||
|
||||
请比较这两个推断结果,判断它们是否相同或类似。
|
||||
- 如果两个推断结果的"含义"相同或类似,说明这个词条不是黑话(含义明确)
|
||||
- 如果两个推断结果有差异,说明这个词条可能是黑话(需要上下文才能理解)
|
||||
|
||||
以 JSON 格式输出:
|
||||
{{
|
||||
"is_similar": true/false,
|
||||
"reason": "判断理由"
|
||||
}}
|
||||
"""
|
||||
Prompt(prompt3_str, "jargon_compare_inference_prompt")
|
||||
|
||||
|
||||
_init_prompt()
|
||||
_init_inference_prompts()
|
||||
|
||||
|
||||
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||
"""
|
||||
判断是否需要进行含义推断
|
||||
在 count 达到 3,6, 10, 20, 40, 60, 100 时进行推断
|
||||
并且count必须大于last_inference_count,避免重启后重复判定
|
||||
如果is_complete为True,不再进行推断
|
||||
"""
|
||||
# 如果已完成所有推断,不再推断
|
||||
if jargon_obj.is_complete:
|
||||
return False
|
||||
|
||||
count = jargon_obj.count or 0
|
||||
last_inference = jargon_obj.last_inference_count or 0
|
||||
|
||||
# 阈值列表:3,6, 10, 20, 40, 60, 100
|
||||
thresholds = [2, 4, 8, 12, 24, 60, 100]
|
||||
|
||||
if count < thresholds[0]:
|
||||
return False
|
||||
|
||||
# 如果count没有超过上次判定值,不需要判定
|
||||
if count <= last_inference:
|
||||
return False
|
||||
|
||||
# 找到第一个大于last_inference的阈值
|
||||
next_threshold = None
|
||||
for threshold in thresholds:
|
||||
if threshold > last_inference:
|
||||
next_threshold = threshold
|
||||
break
|
||||
|
||||
# 如果没有找到下一个阈值,说明已经超过100,不应该再推断
|
||||
if next_threshold is None:
|
||||
return False
|
||||
|
||||
# 检查count是否达到或超过这个阈值
|
||||
return count >= next_threshold
|
||||
|
||||
|
||||
class JargonMiner:
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
|
||||
self.llm = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.extract",
|
||||
)
|
||||
|
||||
self.llm_inference = LLMRequest(
|
||||
model_set=model_config.model_task_config.utils,
|
||||
request_type="jargon.inference",
|
||||
)
|
||||
|
||||
# 初始化stream_name作为类属性,避免重复提取
|
||||
chat_manager = get_chat_manager()
|
||||
stream_name = chat_manager.get_stream_name(self.chat_id)
|
||||
self.stream_name = stream_name if stream_name else self.chat_id
|
||||
self.cache_limit = 50
|
||||
self.cache: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
# 黑话提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
def _add_to_cache(self, content: str) -> None:
|
||||
"""将提取到的黑话加入缓存,保持LRU语义"""
|
||||
if not content:
|
||||
return
|
||||
|
||||
key = content.strip()
|
||||
if not key:
|
||||
return
|
||||
|
||||
# 单字黑话(单个汉字、英文或数字)不记录到缓存
|
||||
if _is_single_char_jargon(key):
|
||||
return
|
||||
|
||||
if key in self.cache:
|
||||
self.cache.move_to_end(key)
|
||||
else:
|
||||
self.cache[key] = None
|
||||
if len(self.cache) > self.cache_limit:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
def _collect_cached_entries(self, messages: List[Any]) -> List[Dict[str, List[str]]]:
|
||||
"""检查缓存中的黑话是否出现在当前消息窗口,生成对应上下文"""
|
||||
if not self.cache or not messages:
|
||||
return []
|
||||
|
||||
cached_entries: List[Dict[str, List[str]]] = []
|
||||
processed_pairs = set()
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
msg_text = (
|
||||
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||
).strip()
|
||||
if not msg_text or is_bot_message(msg):
|
||||
continue
|
||||
|
||||
for content in self.cache.keys():
|
||||
if not content:
|
||||
continue
|
||||
if (content, idx) in processed_pairs:
|
||||
continue
|
||||
if content in msg_text:
|
||||
paragraph = build_context_paragraph(messages, idx)
|
||||
if not paragraph:
|
||||
continue
|
||||
cached_entries.append({"content": content, "raw_content": [paragraph]})
|
||||
processed_pairs.add((content, idx))
|
||||
|
||||
return cached_entries
|
||||
|
||||
async def _infer_meaning_by_id(self, jargon_id: int) -> None:
|
||||
"""通过ID加载对象并推断"""
|
||||
try:
|
||||
jargon_obj = Jargon.get_by_id(jargon_id)
|
||||
# 再次检查is_complete,因为可能在异步任务执行时已被标记为完成
|
||||
if jargon_obj.is_complete:
|
||||
logger.debug(f"jargon {jargon_obj.content} 已完成所有推断,跳过")
|
||||
return
|
||||
await self.infer_meaning(jargon_obj)
|
||||
except Exception as e:
|
||||
logger.error(f"通过ID推断jargon失败: {e}")
|
||||
|
||||
async def infer_meaning(self, jargon_obj: Jargon) -> None:
|
||||
"""
|
||||
对jargon进行含义推断
|
||||
"""
|
||||
try:
|
||||
content = jargon_obj.content
|
||||
raw_content_str = jargon_obj.raw_content or ""
|
||||
|
||||
# 解析raw_content列表
|
||||
raw_content_list = []
|
||||
if raw_content_str:
|
||||
try:
|
||||
raw_content_list = (
|
||||
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
|
||||
)
|
||||
if not isinstance(raw_content_list, list):
|
||||
raw_content_list = [raw_content_list] if raw_content_list else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
raw_content_list = [raw_content_str] if raw_content_str else []
|
||||
|
||||
if not raw_content_list:
|
||||
logger.warning(f"jargon {content} 没有raw_content,跳过推断")
|
||||
return
|
||||
|
||||
# 获取当前count和上一次的meaning
|
||||
current_count = jargon_obj.count or 0
|
||||
previous_meaning = jargon_obj.meaning or ""
|
||||
|
||||
# 当count为24, 60时,随机移除一半的raw_content项目
|
||||
if current_count in [24, 60] and len(raw_content_list) > 1:
|
||||
# 计算要保留的数量(至少保留1个)
|
||||
keep_count = max(1, len(raw_content_list) // 2)
|
||||
raw_content_list = random.sample(raw_content_list, keep_count)
|
||||
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
|
||||
|
||||
# 步骤1: 基于raw_content和content推断
|
||||
raw_content_text = "\n".join(raw_content_list)
|
||||
|
||||
# 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考
|
||||
previous_meaning_section = ""
|
||||
previous_meaning_instruction = ""
|
||||
if current_count in [24, 60, 100] and previous_meaning:
|
||||
previous_meaning_section = f"""
|
||||
**上一次推断的含义(仅供参考)**
|
||||
{previous_meaning}
|
||||
"""
|
||||
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
|
||||
|
||||
prompt1 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_with_context_prompt",
|
||||
content=content,
|
||||
bot_name=global_config.bot.nickname,
|
||||
raw_content_list=raw_content_text,
|
||||
previous_meaning_section=previous_meaning_section,
|
||||
previous_meaning_instruction=previous_meaning_instruction,
|
||||
)
|
||||
|
||||
response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3)
|
||||
if not response1:
|
||||
logger.warning(f"jargon {content} 推断1失败:无响应")
|
||||
return
|
||||
|
||||
# 解析推断1结果
|
||||
inference1 = None
|
||||
try:
|
||||
resp1 = response1.strip()
|
||||
if resp1.startswith("{") and resp1.endswith("}"):
|
||||
inference1 = json.loads(resp1)
|
||||
else:
|
||||
repaired = repair_json(resp1)
|
||||
inference1 = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
if not isinstance(inference1, dict):
|
||||
logger.warning(f"jargon {content} 推断1结果格式错误")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 推断1解析失败: {e}")
|
||||
return
|
||||
|
||||
# 检查推断1是否表示信息不足无法推断
|
||||
no_info = inference1.get("no_info", False)
|
||||
meaning1 = inference1.get("meaning", "").strip()
|
||||
if no_info or not meaning1:
|
||||
logger.info(f"jargon {content} 推断1表示信息不足无法推断,放弃本次推断,待下次更新")
|
||||
# 更新最后一次判定的count值,避免在同一阈值重复尝试
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
jargon_obj.save()
|
||||
return
|
||||
|
||||
# 步骤2: 仅基于content推断
|
||||
prompt2 = await global_prompt_manager.format_prompt(
|
||||
"jargon_inference_content_only_prompt",
|
||||
content=content,
|
||||
)
|
||||
|
||||
response2, _ = await self.llm_inference.generate_response_async(prompt2, temperature=0.3)
|
||||
if not response2:
|
||||
logger.warning(f"jargon {content} 推断2失败:无响应")
|
||||
return
|
||||
|
||||
# 解析推断2结果
|
||||
inference2 = None
|
||||
try:
|
||||
resp2 = response2.strip()
|
||||
if resp2.startswith("{") and resp2.endswith("}"):
|
||||
inference2 = json.loads(resp2)
|
||||
else:
|
||||
repaired = repair_json(resp2)
|
||||
inference2 = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
if not isinstance(inference2, dict):
|
||||
logger.warning(f"jargon {content} 推断2结果格式错误")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 推断2解析失败: {e}")
|
||||
return
|
||||
|
||||
# logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
# logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
# logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
# logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.info(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.info(f"jargon {content} 推断1结果: {response1}")
|
||||
else:
|
||||
logger.debug(f"jargon {content} 推断2提示词: {prompt2}")
|
||||
logger.debug(f"jargon {content} 推断2结果: {response2}")
|
||||
logger.debug(f"jargon {content} 推断1提示词: {prompt1}")
|
||||
logger.debug(f"jargon {content} 推断1结果: {response1}")
|
||||
|
||||
# 步骤3: 比较两个推断结果
|
||||
prompt3 = await global_prompt_manager.format_prompt(
|
||||
"jargon_compare_inference_prompt",
|
||||
inference1=json.dumps(inference1, ensure_ascii=False),
|
||||
inference2=json.dumps(inference2, ensure_ascii=False),
|
||||
)
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon {content} 比较提示词: {prompt3}")
|
||||
|
||||
response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3)
|
||||
if not response3:
|
||||
logger.warning(f"jargon {content} 比较失败:无响应")
|
||||
return
|
||||
|
||||
# 解析比较结果
|
||||
comparison = None
|
||||
try:
|
||||
resp3 = response3.strip()
|
||||
if resp3.startswith("{") and resp3.endswith("}"):
|
||||
comparison = json.loads(resp3)
|
||||
else:
|
||||
repaired = repair_json(resp3)
|
||||
comparison = json.loads(repaired) if isinstance(repaired, str) else repaired
|
||||
if not isinstance(comparison, dict):
|
||||
logger.warning(f"jargon {content} 比较结果格式错误")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"jargon {content} 比较解析失败: {e}")
|
||||
return
|
||||
|
||||
# 判断是否为黑话
|
||||
is_similar = comparison.get("is_similar", False)
|
||||
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
|
||||
|
||||
# 更新数据库记录
|
||||
jargon_obj.is_jargon = is_jargon
|
||||
if is_jargon:
|
||||
# 是黑话,使用推断1的结果(基于上下文,更准确)
|
||||
jargon_obj.meaning = inference1.get("meaning", "")
|
||||
else:
|
||||
# 不是黑话,清空含义,不再存储任何内容
|
||||
jargon_obj.meaning = ""
|
||||
|
||||
# 更新最后一次判定的count值,避免重启后重复判定
|
||||
jargon_obj.last_inference_count = jargon_obj.count or 0
|
||||
|
||||
# 如果count>=100,标记为完成,不再进行推断
|
||||
if (jargon_obj.count or 0) >= 100:
|
||||
jargon_obj.is_complete = True
|
||||
|
||||
jargon_obj.save()
|
||||
logger.debug(
|
||||
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
|
||||
)
|
||||
|
||||
# 固定输出推断结果,格式化为可读形式
|
||||
if is_jargon:
|
||||
# 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx
|
||||
meaning = jargon_obj.meaning or "无详细说明"
|
||||
is_global = jargon_obj.is_global
|
||||
if is_global:
|
||||
logger.info(f"[黑话]{content}的含义是 {meaning}")
|
||||
else:
|
||||
logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}")
|
||||
else:
|
||||
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
|
||||
logger.info(f"[{self.stream_name}]{content} 不是黑话")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"jargon推断失败: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def run_once(self, messages: List[Any]) -> None:
|
||||
"""
|
||||
运行一次黑话提取
|
||||
|
||||
Args:
|
||||
messages: 外部传入的消息列表(必需)
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
try:
|
||||
if not messages:
|
||||
return
|
||||
|
||||
# 按时间排序,确保编号与上下文一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
chat_str, message_id_list = build_readable_messages_with_id(
|
||||
messages=messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
pic_single=True,
|
||||
)
|
||||
if not chat_str.strip():
|
||||
return
|
||||
|
||||
msg_id_to_index: Dict[str, int] = {}
|
||||
for idx, (msg_id, _msg) in enumerate(message_id_list or []):
|
||||
if not msg_id:
|
||||
continue
|
||||
msg_id_to_index[msg_id] = idx
|
||||
if not msg_id_to_index:
|
||||
logger.warning("未能生成消息ID映射,跳过本次提取")
|
||||
return
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"extract_jargon_prompt",
|
||||
bot_name=global_config.bot.nickname,
|
||||
chat_str=chat_str,
|
||||
)
|
||||
|
||||
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
|
||||
if not response:
|
||||
return
|
||||
|
||||
if global_config.debug.show_jargon_prompt:
|
||||
logger.info(f"jargon提取提示词: {prompt}")
|
||||
logger.info(f"jargon提取结果: {response}")
|
||||
|
||||
# 解析为JSON
|
||||
entries: List[dict] = []
|
||||
try:
|
||||
resp = response.strip()
|
||||
parsed = None
|
||||
if resp.startswith("[") and resp.endswith("]"):
|
||||
parsed = json.loads(resp)
|
||||
else:
|
||||
repaired = repair_json(resp)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
if not isinstance(parsed, list):
|
||||
return
|
||||
|
||||
for item in parsed:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
content = str(item.get("content", "")).strip()
|
||||
msg_id_value = item.get("msg_id")
|
||||
|
||||
if not content:
|
||||
continue
|
||||
|
||||
if contains_bot_self_name(content):
|
||||
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
|
||||
continue
|
||||
|
||||
msg_id_str = str(msg_id_value or "").strip()
|
||||
if not msg_id_str:
|
||||
logger.warning(f"解析jargon失败:msg_id缺失,content={content}")
|
||||
continue
|
||||
|
||||
msg_index = msg_id_to_index.get(msg_id_str)
|
||||
if msg_index is None:
|
||||
logger.warning(f"解析jargon失败:msg_id未找到,content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
target_msg = messages[msg_index]
|
||||
if is_bot_message(target_msg):
|
||||
logger.info(f"解析阶段跳过引用机器人自身消息的词条: content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
context_paragraph = build_context_paragraph(messages, msg_index)
|
||||
if not context_paragraph:
|
||||
logger.warning(f"解析jargon失败:上下文为空,content={content}, msg_id={msg_id_str}")
|
||||
continue
|
||||
|
||||
entries.append({"content": content, "raw_content": [context_paragraph]})
|
||||
cached_entries = self._collect_cached_entries(messages)
|
||||
if cached_entries:
|
||||
entries.extend(cached_entries)
|
||||
except Exception as e:
|
||||
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
|
||||
return
|
||||
|
||||
if not entries:
|
||||
return
|
||||
|
||||
# 去重并合并raw_content(按 content 聚合)
|
||||
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
|
||||
for entry in entries:
|
||||
content_key = entry["content"]
|
||||
raw_list = entry.get("raw_content", []) or []
|
||||
if content_key in merged_entries:
|
||||
merged_entries[content_key]["raw_content"].extend(raw_list)
|
||||
else:
|
||||
merged_entries[content_key] = {
|
||||
"content": content_key,
|
||||
"raw_content": list(raw_list),
|
||||
}
|
||||
|
||||
uniq_entries = []
|
||||
for merged_entry in merged_entries.values():
|
||||
raw_content_list = merged_entry["raw_content"]
|
||||
if raw_content_list:
|
||||
merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
|
||||
uniq_entries.append(merged_entry)
|
||||
|
||||
saved = 0
|
||||
updated = 0
|
||||
for entry in uniq_entries:
|
||||
content = entry["content"]
|
||||
raw_content_list = entry["raw_content"] # 已经是列表
|
||||
|
||||
try:
|
||||
# 查询所有content匹配的记录
|
||||
query = Jargon.select().where(Jargon.content == content)
|
||||
|
||||
# 查找匹配的记录
|
||||
matched_obj = None
|
||||
for obj in query:
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有content匹配的记录都可以
|
||||
matched_obj = obj
|
||||
break
|
||||
else:
|
||||
# 关闭all_global:需要检查chat_id列表是否包含目标chat_id
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
if chat_id_list_contains(chat_id_list, self.chat_id):
|
||||
matched_obj = obj
|
||||
break
|
||||
|
||||
if matched_obj:
|
||||
obj = matched_obj
|
||||
try:
|
||||
obj.count = (obj.count or 0) + 1
|
||||
except Exception:
|
||||
obj.count = 1
|
||||
|
||||
# 合并raw_content列表:读取现有列表,追加新值,去重
|
||||
existing_raw_content = []
|
||||
if obj.raw_content:
|
||||
try:
|
||||
existing_raw_content = (
|
||||
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
|
||||
)
|
||||
if not isinstance(existing_raw_content, list):
|
||||
existing_raw_content = [existing_raw_content] if existing_raw_content else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
existing_raw_content = [obj.raw_content] if obj.raw_content else []
|
||||
|
||||
# 合并并去重
|
||||
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
|
||||
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
|
||||
|
||||
# 更新chat_id列表:增加当前chat_id的计数
|
||||
chat_id_list = parse_chat_id_list(obj.chat_id)
|
||||
updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
|
||||
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
|
||||
|
||||
# 开启all_global时,确保记录标记为is_global=True
|
||||
if global_config.expression.all_global_jargon:
|
||||
obj.is_global = True
|
||||
# 关闭all_global时,保持原有is_global不变(不修改)
|
||||
|
||||
obj.save()
|
||||
|
||||
# 检查是否需要推断(达到阈值且超过上次判定值)
|
||||
if _should_infer_meaning(obj):
|
||||
# 异步触发推断,不阻塞主流程
|
||||
# 重新加载对象以确保数据最新
|
||||
jargon_id = obj.id
|
||||
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
|
||||
|
||||
updated += 1
|
||||
else:
|
||||
# 没找到匹配记录,创建新记录
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:新记录默认为is_global=True
|
||||
is_global_new = True
|
||||
else:
|
||||
# 关闭all_global:新记录is_global=False
|
||||
is_global_new = False
|
||||
|
||||
# 使用新格式创建chat_id列表:[[chat_id, count]]
|
||||
chat_id_list = [[self.chat_id, 1]]
|
||||
chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
|
||||
|
||||
Jargon.create(
|
||||
content=content,
|
||||
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
|
||||
chat_id=chat_id_json,
|
||||
is_global=is_global_new,
|
||||
count=1,
|
||||
)
|
||||
saved += 1
|
||||
except Exception as e:
|
||||
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
|
||||
continue
|
||||
finally:
|
||||
self._add_to_cache(content)
|
||||
|
||||
# 固定输出提取的jargon结果,格式化为可读形式(只要有提取结果就输出)
|
||||
if uniq_entries:
|
||||
# 收集所有提取的jargon内容
|
||||
jargon_list = [entry["content"] for entry in uniq_entries]
|
||||
jargon_str = ",".join(jargon_list)
|
||||
|
||||
# 输出格式化的结果(使用logger.info会自动应用jargon模块的颜色)
|
||||
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
|
||||
|
||||
if saved or updated:
|
||||
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated} 条,chat_id={self.chat_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"JargonMiner 运行失败: {e}")
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
|
||||
class JargonMinerManager:
|
||||
def __init__(self) -> None:
|
||||
self._miners: dict[str, JargonMiner] = {}
|
||||
|
||||
def get_miner(self, chat_id: str) -> JargonMiner:
|
||||
if chat_id not in self._miners:
|
||||
self._miners[chat_id] = JargonMiner(chat_id)
|
||||
return self._miners[chat_id]
|
||||
|
||||
|
||||
miner_manager = JargonMinerManager()
|
||||
|
||||
|
||||
|
||||
|
||||
def search_jargon(
|
||||
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
搜索jargon,支持大小写不敏感和模糊搜索
|
||||
|
||||
Args:
|
||||
keyword: 搜索关键词
|
||||
chat_id: 可选的聊天ID
|
||||
- 如果开启了all_global:此参数被忽略,查询所有is_global=True的记录
|
||||
- 如果关闭了all_global:如果提供则优先搜索该聊天或global的jargon
|
||||
limit: 返回结果数量限制,默认10
|
||||
case_sensitive: 是否大小写敏感,默认False(不敏感)
|
||||
fuzzy: 是否模糊搜索,默认True(使用LIKE匹配)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: 包含content, meaning的字典列表
|
||||
"""
|
||||
if not keyword or not keyword.strip():
|
||||
return []
|
||||
|
||||
keyword = keyword.strip()
|
||||
|
||||
# 构建查询(选择所有需要的字段,以便后续过滤)
|
||||
query = Jargon.select()
|
||||
|
||||
# 构建搜索条件
|
||||
if case_sensitive:
|
||||
# 大小写敏感
|
||||
if fuzzy:
|
||||
# 模糊搜索
|
||||
search_condition = Jargon.content.contains(keyword)
|
||||
else:
|
||||
# 精确匹配
|
||||
search_condition = Jargon.content == keyword
|
||||
else:
|
||||
# 大小写不敏感
|
||||
if fuzzy:
|
||||
# 模糊搜索(使用LOWER函数)
|
||||
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
|
||||
else:
|
||||
# 精确匹配(使用LOWER函数)
|
||||
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
|
||||
|
||||
query = query.where(search_condition)
|
||||
|
||||
# 根据all_global配置决定查询逻辑
|
||||
if global_config.expression.all_global_jargon:
|
||||
# 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id)
|
||||
query = query.where(Jargon.is_global)
|
||||
# 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式
|
||||
|
||||
# 注意:meaning的过滤移到Python层面,因为我们需要先过滤chat_id
|
||||
|
||||
# 按count降序排序,优先返回出现频率高的
|
||||
query = query.order_by(Jargon.count.desc())
|
||||
|
||||
# 限制结果数量(先多取一些,因为后面可能过滤)
|
||||
query = query.limit(limit * 2)
|
||||
|
||||
# 执行查询并返回结果,过滤chat_id
|
||||
results = []
|
||||
for jargon in query:
|
||||
# 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id
|
||||
if chat_id and not global_config.expression.all_global_jargon:
|
||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||
# 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含
|
||||
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
|
||||
continue
|
||||
|
||||
# 只返回有meaning的记录
|
||||
if not jargon.meaning or jargon.meaning.strip() == "":
|
||||
continue
|
||||
|
||||
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||
|
||||
# 达到限制数量后停止
|
||||
if len(results) >= limit:
|
||||
break
|
||||
|
||||
return results
|
||||
348
src/bw_learner/learner_utils.py
Normal file
348
src/bw_learner/learner_utils.py
Normal file
@@ -0,0 +1,348 @@
|
||||
import re
|
||||
import difflib
|
||||
import random
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages,
|
||||
)
|
||||
from src.chat.utils.utils import parse_platform_accounts
|
||||
|
||||
|
||||
logger = get_logger("learner_utils")
|
||||
|
||||
|
||||
def filter_message_content(content: Optional[str]) -> str:
|
||||
"""
|
||||
过滤消息内容,移除回复、@、图片等格式
|
||||
|
||||
Args:
|
||||
content: 原始消息内容
|
||||
|
||||
Returns:
|
||||
str: 过滤后的内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
|
||||
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
|
||||
# 移除@<...>格式的内容
|
||||
content = re.sub(r"@<[^>]*>", "", content)
|
||||
# 移除[picid:...]格式的图片ID
|
||||
content = re.sub(r"\[picid:[^\]]*\]", "", content)
|
||||
# 移除[表情包:...]格式的内容
|
||||
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
|
||||
|
||||
return content.strip()
|
||||
|
||||
|
||||
def calculate_similarity(text1: str, text2: str) -> float:
|
||||
"""
|
||||
计算两个文本的相似度,返回0-1之间的值
|
||||
使用SequenceMatcher计算相似度
|
||||
|
||||
Args:
|
||||
text1: 第一个文本
|
||||
text2: 第二个文本
|
||||
|
||||
Returns:
|
||||
float: 相似度值,范围0-1
|
||||
"""
|
||||
return difflib.SequenceMatcher(None, text1, text2).ratio()
|
||||
|
||||
|
||||
def format_create_date(timestamp: float) -> str:
|
||||
"""
|
||||
将时间戳格式化为可读的日期字符串
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
str: 格式化后的日期字符串
|
||||
"""
|
||||
try:
|
||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError):
|
||||
return "未知时间"
|
||||
|
||||
|
||||
def _compute_weights(population: List[Dict]) -> List[float]:
|
||||
"""
|
||||
根据表达的count计算权重,范围限定在1~5之间。
|
||||
count越高,权重越高,但最多为基础权重的5倍。
|
||||
如果表达已checked,权重会再乘以3倍。
|
||||
"""
|
||||
if not population:
|
||||
return []
|
||||
|
||||
counts = []
|
||||
checked_flags = []
|
||||
for item in population:
|
||||
count = item.get("count", 1)
|
||||
try:
|
||||
count_value = float(count)
|
||||
except (TypeError, ValueError):
|
||||
count_value = 1.0
|
||||
counts.append(max(count_value, 0.0))
|
||||
# 获取checked状态
|
||||
checked = item.get("checked", False)
|
||||
checked_flags.append(bool(checked))
|
||||
|
||||
min_count = min(counts)
|
||||
max_count = max(counts)
|
||||
|
||||
if max_count == min_count:
|
||||
base_weights = [1.0 for _ in counts]
|
||||
else:
|
||||
base_weights = []
|
||||
for count_value in counts:
|
||||
# 线性映射到[1,5]区间
|
||||
normalized = (count_value - min_count) / (max_count - min_count)
|
||||
base_weights.append(1.0 + normalized * 4.0) # 1~5
|
||||
|
||||
# 如果checked,权重乘以3
|
||||
weights = []
|
||||
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
|
||||
if checked:
|
||||
weights.append(base_weight * 3.0)
|
||||
else:
|
||||
weights.append(base_weight)
|
||||
return weights
|
||||
|
||||
|
||||
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
|
||||
"""
|
||||
随机抽样函数
|
||||
|
||||
Args:
|
||||
population: 总体数据列表
|
||||
k: 需要抽取的数量
|
||||
|
||||
Returns:
|
||||
List[Dict]: 抽取的数据列表
|
||||
"""
|
||||
if not population or k <= 0:
|
||||
return []
|
||||
|
||||
if len(population) <= k:
|
||||
return population.copy()
|
||||
|
||||
selected: List[Dict] = []
|
||||
population_copy = population.copy()
|
||||
|
||||
for _ in range(min(k, len(population_copy))):
|
||||
weights = _compute_weights(population_copy)
|
||||
total_weight = sum(weights)
|
||||
if total_weight <= 0:
|
||||
# 回退到均匀随机
|
||||
idx = random.randint(0, len(population_copy) - 1)
|
||||
selected.append(population_copy.pop(idx))
|
||||
continue
|
||||
|
||||
threshold = random.uniform(0, total_weight)
|
||||
cumulative = 0.0
|
||||
for idx, weight in enumerate(weights):
|
||||
cumulative += weight
|
||||
if threshold <= cumulative:
|
||||
selected.append(population_copy.pop(idx))
|
||||
break
|
||||
|
||||
return selected
|
||||
|
||||
|
||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||
"""
|
||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||
|
||||
Args:
|
||||
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
||||
"""
|
||||
if not chat_id_value:
|
||||
return []
|
||||
|
||||
# 如果是字符串,尝试解析为JSON
|
||||
if isinstance(chat_id_value, str):
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
parsed = json.loads(chat_id_value)
|
||||
if isinstance(parsed, list):
|
||||
# 新格式:已经是列表
|
||||
return parsed
|
||||
elif isinstance(parsed, str):
|
||||
# 解析后还是字符串,说明是旧格式
|
||||
return [[parsed, 1]]
|
||||
else:
|
||||
# 其他类型,当作旧格式处理
|
||||
return [[str(chat_id_value), 1]]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 解析失败,当作旧格式(纯字符串)
|
||||
return [[str(chat_id_value), 1]]
|
||||
elif isinstance(chat_id_value, list):
|
||||
# 已经是列表格式
|
||||
return chat_id_value
|
||||
else:
|
||||
# 其他类型,转换为旧格式
|
||||
return [[str(chat_id_value), 1]]
|
||||
|
||||
|
||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
||||
"""
|
||||
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
||||
|
||||
Args:
|
||||
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要更新或添加的chat_id
|
||||
increment: 增加的计数,默认为1
|
||||
|
||||
Returns:
|
||||
List[List[Any]]: 更新后的chat_id列表
|
||||
"""
|
||||
item = _find_chat_id_item(chat_id_list, target_chat_id)
|
||||
if item is not None:
|
||||
# 找到匹配的chat_id,增加计数
|
||||
if len(item) >= 2:
|
||||
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
|
||||
else:
|
||||
item.append(increment)
|
||||
else:
|
||||
# 未找到,添加新条目
|
||||
chat_id_list.append([target_chat_id, increment])
|
||||
|
||||
return chat_id_list
|
||||
|
||||
|
||||
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
|
||||
"""
|
||||
在chat_id列表中查找匹配的项(辅助函数)
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
如果找到则返回匹配的项,否则返回None
|
||||
"""
|
||||
for item in chat_id_list:
|
||||
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
||||
"""
|
||||
检查chat_id列表中是否包含指定的chat_id
|
||||
|
||||
Args:
|
||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||
target_chat_id: 要查找的chat_id
|
||||
|
||||
Returns:
|
||||
bool: 如果包含则返回True
|
||||
"""
|
||||
return _find_chat_id_item(chat_id_list, target_chat_id) is not None
|
||||
|
||||
|
||||
def contains_bot_self_name(content: str) -> bool:
|
||||
"""
|
||||
判断词条是否包含机器人的昵称或别名
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
target = content.strip().lower()
|
||||
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
|
||||
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
|
||||
|
||||
candidates = [name for name in [nickname, *alias_names] if name]
|
||||
|
||||
return any(name in target for name in candidates)
|
||||
|
||||
|
||||
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
|
||||
"""
|
||||
构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
|
||||
"""
|
||||
if not messages or center_index < 0 or center_index >= len(messages):
|
||||
return None
|
||||
|
||||
context_start = max(0, center_index - 3)
|
||||
context_end = min(len(messages), center_index + 1 + 3)
|
||||
context_messages = messages[context_start:context_end]
|
||||
|
||||
if not context_messages:
|
||||
return None
|
||||
|
||||
try:
|
||||
paragraph = build_readable_messages(
|
||||
messages=context_messages,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
truncate=False,
|
||||
show_actions=False,
|
||||
show_pic=True,
|
||||
message_id_list=None,
|
||||
remove_emoji_stickers=False,
|
||||
pic_single=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"构建上下文段落失败: {e}")
|
||||
return None
|
||||
|
||||
paragraph = paragraph.strip()
|
||||
return paragraph or None
|
||||
|
||||
|
||||
def is_bot_message(msg: Any) -> bool:
|
||||
"""判断消息是否来自机器人自身"""
|
||||
if msg is None:
|
||||
return False
|
||||
|
||||
bot_config = getattr(global_config, "bot", None)
|
||||
if not bot_config:
|
||||
return False
|
||||
|
||||
platform = (
|
||||
str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||
|
||||
if not platform or not user_id:
|
||||
return False
|
||||
|
||||
platform_accounts = {}
|
||||
try:
|
||||
platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
|
||||
except Exception:
|
||||
platform_accounts = {}
|
||||
|
||||
bot_accounts: Dict[str, str] = {}
|
||||
qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
|
||||
if qq_account:
|
||||
bot_accounts["qq"] = qq_account
|
||||
|
||||
telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
|
||||
if telegram_account:
|
||||
bot_accounts["telegram"] = telegram_account
|
||||
|
||||
for plat, account in platform_accounts.items():
|
||||
if account and plat not in bot_accounts:
|
||||
bot_accounts[plat] = account
|
||||
|
||||
bot_account = bot_accounts.get(platform)
|
||||
return bool(bot_account and user_id == bot_account)
|
||||
217
src/bw_learner/message_recorder.py
Normal file
217
src/bw_learner/message_recorder.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.bw_learner.expression_learner import expression_learner_manager
|
||||
from src.bw_learner.jargon_miner import miner_manager
|
||||
|
||||
logger = get_logger("bw_learner")
|
||||
|
||||
|
||||
class MessageRecorder:
|
||||
"""
|
||||
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str) -> None:
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
||||
# 维护每个chat的上次提取时间
|
||||
self.last_extraction_time: float = time.time()
|
||||
|
||||
# 提取锁,防止并发执行
|
||||
self._extraction_lock = asyncio.Lock()
|
||||
|
||||
# 获取 expression 和 jargon 的配置参数
|
||||
self._init_parameters()
|
||||
|
||||
# 获取 expression_learner 和 jargon_miner 实例
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
|
||||
self.jargon_miner = miner_manager.get_miner(chat_id)
|
||||
|
||||
def _init_parameters(self) -> None:
|
||||
"""初始化提取参数"""
|
||||
# 获取 expression 配置
|
||||
_, self.enable_expression_learning, self.enable_jargon_learning = (
|
||||
global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
)
|
||||
self.min_messages_for_extraction = 30
|
||||
self.min_extraction_interval = 60
|
||||
|
||||
logger.debug(
|
||||
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
|
||||
f"min_messages={self.min_messages_for_extraction}, "
|
||||
f"min_interval={self.min_extraction_interval}"
|
||||
)
|
||||
|
||||
def should_trigger_extraction(self) -> bool:
|
||||
"""
|
||||
检查是否应该触发消息提取
|
||||
|
||||
Returns:
|
||||
bool: 是否应该触发提取
|
||||
"""
|
||||
# 检查时间间隔
|
||||
time_diff = time.time() - self.last_extraction_time
|
||||
if time_diff < self.min_extraction_interval:
|
||||
return False
|
||||
|
||||
# 检查消息数量
|
||||
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_extraction_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def extract_and_distribute(self) -> None:
|
||||
"""
|
||||
提取消息并分发给 expression_learner 和 jargon_miner
|
||||
"""
|
||||
# 使用异步锁防止并发执行
|
||||
async with self._extraction_lock:
|
||||
# 在锁内检查,避免并发触发
|
||||
if not self.should_trigger_extraction():
|
||||
return
|
||||
|
||||
# 检查 chat_stream 是否存在
|
||||
if not self.chat_stream:
|
||||
return
|
||||
|
||||
# 记录本次提取的时间窗口,避免重复提取
|
||||
extraction_start_time = self.last_extraction_time
|
||||
extraction_end_time = time.time()
|
||||
|
||||
# 立即更新提取时间,防止并发触发
|
||||
self.last_extraction_time = extraction_end_time
|
||||
|
||||
try:
|
||||
logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
|
||||
|
||||
# 拉取提取窗口内的消息
|
||||
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=extraction_start_time,
|
||||
timestamp_end=extraction_end_time,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
|
||||
return
|
||||
|
||||
# 按时间排序,确保顺序一致
|
||||
messages = sorted(messages, key=lambda msg: msg.time or 0)
|
||||
|
||||
logger.info(
|
||||
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
|
||||
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
|
||||
)
|
||||
|
||||
|
||||
# 分别触发 expression_learner 和 jargon_miner 的处理
|
||||
# 传递提取的消息,避免它们重复获取
|
||||
# 触发 expression 学习(如果启用)
|
||||
if self.enable_expression_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
|
||||
)
|
||||
|
||||
# 触发 jargon 提取(如果启用),传递消息
|
||||
if self.enable_jargon_learning:
|
||||
asyncio.create_task(
|
||||
self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
# 即使失败也保持时间戳更新,避免频繁重试
|
||||
|
||||
async def _trigger_expression_learning(
|
||||
self,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 expression 学习,使用指定的消息列表
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
messages: 消息列表
|
||||
"""
|
||||
try:
|
||||
# 传递消息给 ExpressionLearner(必需参数)
|
||||
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
|
||||
|
||||
if learnt_style:
|
||||
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
|
||||
else:
|
||||
logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
async def _trigger_jargon_extraction(
|
||||
self,
|
||||
timestamp_start: float,
|
||||
timestamp_end: float,
|
||||
messages: List[Any]
|
||||
) -> None:
|
||||
"""
|
||||
触发 jargon 提取,使用指定的消息列表
|
||||
|
||||
Args:
|
||||
timestamp_start: 开始时间戳
|
||||
timestamp_end: 结束时间戳
|
||||
messages: 消息列表
|
||||
"""
|
||||
try:
|
||||
# 传递消息给 JargonMiner,避免它重复获取
|
||||
await self.jargon_miner.run_once(messages=messages)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
class MessageRecorderManager:
|
||||
"""MessageRecorder 管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._recorders: dict[str, MessageRecorder] = {}
|
||||
|
||||
def get_recorder(self, chat_id: str) -> MessageRecorder:
|
||||
"""获取或创建指定 chat_id 的 MessageRecorder"""
|
||||
if chat_id not in self._recorders:
|
||||
self._recorders[chat_id] = MessageRecorder(chat_id)
|
||||
return self._recorders[chat_id]
|
||||
|
||||
|
||||
# 全局管理器实例
|
||||
recorder_manager = MessageRecorderManager()
|
||||
|
||||
|
||||
async def extract_and_distribute_messages(chat_id: str) -> None:
|
||||
"""
|
||||
统一的消息提取和分发入口函数
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
"""
|
||||
recorder = recorder_manager.get_recorder(chat_id)
|
||||
await recorder.extract_and_distribute()
|
||||
|
||||
199
src/bw_learner/reflect_tracker.py
Normal file
199
src/bw_learner/reflect_tracker.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import time
|
||||
from typing import Optional, Dict, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = get_logger("reflect_tracker")
|
||||
|
||||
|
||||
class ReflectTracker:
|
||||
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
||||
self.chat_stream = chat_stream
|
||||
self.expression = expression
|
||||
self.created_time = created_time
|
||||
# self.message_count = 0 # Replaced by checking message list length
|
||||
self.last_check_msg_count = 0
|
||||
self.max_message_count = 30
|
||||
self.max_duration = 15 * 60 # 15 minutes
|
||||
|
||||
# LLM for judging response
|
||||
self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
|
||||
|
||||
self._init_prompts()
|
||||
|
||||
def _init_prompts(self):
|
||||
judge_prompt = """
|
||||
你是一个表达反思助手。Bot之前询问了表达方式是否合适。
|
||||
你需要根据提供的上下文对话,判断是否对该表达方式做出了肯定或否定的评价。
|
||||
|
||||
**询问内容**
|
||||
情景: {situation}
|
||||
风格: {style}
|
||||
|
||||
**上下文对话**
|
||||
{context_block}
|
||||
|
||||
**判断要求**
|
||||
1. 判断对话中是否包含对上述询问的回答。
|
||||
2. 如果是,判断是肯定(Approve)还是否定(Reject),或者是提供了修改意见。
|
||||
3. 如果不是回答,或者是无关内容,请返回 "Ignore"。
|
||||
4. 如果是否定并提供了修改意见,请提取修正后的情景和风格。
|
||||
|
||||
请输出JSON格式:
|
||||
```json
|
||||
{{
|
||||
"judgment": "Approve" | "Reject" | "Ignore",
|
||||
"corrected_situation": "...", // 如果有修改意见,提取修正后的情景,否则留空
|
||||
"corrected_style": "..." // 如果有修改意见,提取修正后的风格,否则留空
|
||||
}}
|
||||
```
|
||||
"""
|
||||
Prompt(judge_prompt, "reflect_judge_prompt")
|
||||
|
||||
async def trigger_tracker(self) -> bool:
|
||||
"""
|
||||
触发追踪检查
|
||||
Returns: True if resolved (should destroy tracker), False otherwise
|
||||
"""
|
||||
# Check timeout
|
||||
if time.time() - self.created_time > self.max_duration:
|
||||
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (duration).")
|
||||
return True
|
||||
|
||||
# Fetch messages since creation
|
||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
timestamp_start=self.created_time,
|
||||
timestamp_end=time.time(),
|
||||
)
|
||||
|
||||
current_msg_count = len(msg_list)
|
||||
|
||||
# Check message limit
|
||||
if current_msg_count > self.max_message_count:
|
||||
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (message count).")
|
||||
return True
|
||||
|
||||
# If no new messages since last check, skip
|
||||
if current_msg_count <= self.last_check_msg_count:
|
||||
return False
|
||||
|
||||
self.last_check_msg_count = current_msg_count
|
||||
|
||||
# Build context block
|
||||
# Use simple readable format
|
||||
context_block = build_readable_messages(
|
||||
msg_list,
|
||||
replace_bot_name=True,
|
||||
timestamp_mode="relative",
|
||||
read_mark=0.0,
|
||||
show_actions=False,
|
||||
)
|
||||
|
||||
# LLM Judge
|
||||
try:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"reflect_judge_prompt",
|
||||
situation=self.expression.situation,
|
||||
style=self.expression.style,
|
||||
context_block=context_block,
|
||||
)
|
||||
|
||||
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
|
||||
|
||||
response, _ = await self.judge_model.generate_response_async(prompt, temperature=0.1)
|
||||
|
||||
logger.info(f"ReflectTracker LLM Response: {response}")
|
||||
|
||||
# Parse JSON
|
||||
import json
|
||||
import re
|
||||
from json_repair import repair_json
|
||||
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
if not matches:
|
||||
# Try to parse raw response if no code block
|
||||
matches = [response]
|
||||
|
||||
json_obj = json.loads(repair_json(matches[0]))
|
||||
|
||||
judgment = json_obj.get("judgment")
|
||||
|
||||
if judgment == "Approve":
|
||||
self.expression.checked = True
|
||||
self.expression.rejected = False
|
||||
self.expression.save()
|
||||
logger.info(f"Expression {self.expression.id} approved by operator.")
|
||||
return True
|
||||
|
||||
elif judgment == "Reject":
|
||||
self.expression.checked = True
|
||||
corrected_situation = json_obj.get("corrected_situation")
|
||||
corrected_style = json_obj.get("corrected_style")
|
||||
|
||||
# 检查是否有更新
|
||||
has_update = bool(corrected_situation or corrected_style)
|
||||
|
||||
if corrected_situation:
|
||||
self.expression.situation = corrected_situation
|
||||
if corrected_style:
|
||||
self.expression.style = corrected_style
|
||||
|
||||
# 如果拒绝但未更新,标记为 rejected=1
|
||||
if not has_update:
|
||||
self.expression.rejected = True
|
||||
else:
|
||||
self.expression.rejected = False
|
||||
|
||||
self.expression.save()
|
||||
|
||||
if has_update:
|
||||
logger.info(
|
||||
f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1."
|
||||
)
|
||||
return True
|
||||
|
||||
elif judgment == "Ignore":
|
||||
logger.info(f"ReflectTracker for expr {self.expression.id} judged as Ignore.")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in ReflectTracker check: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Global manager for trackers
|
||||
class ReflectTrackerManager:
|
||||
def __init__(self):
|
||||
self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker
|
||||
|
||||
def add_tracker(self, chat_id: str, tracker: ReflectTracker):
|
||||
self.trackers[chat_id] = tracker
|
||||
|
||||
def get_tracker(self, chat_id: str) -> Optional[ReflectTracker]:
|
||||
return self.trackers.get(chat_id)
|
||||
|
||||
def remove_tracker(self, chat_id: str):
|
||||
if chat_id in self.trackers:
|
||||
del self.trackers[chat_id]
|
||||
|
||||
|
||||
reflect_tracker_manager = ReflectTrackerManager()
|
||||
Reference in New Issue
Block a user