Files
mai-bot/src/learners/learner_utils_old.py

446 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import math
import random
from typing import Optional, List, Dict, Any

from src.common.logger import get_logger
from src.config.config import global_config
# Module-level logger for this file, registered under the name "learner_utils".
logger = get_logger("learner_utils")
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间。
count越高权重越高但最多为基础权重的5倍。
"""
if not population:
return []
counts = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
weights = [1.0 for _ in counts]
else:
weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
weights.append(1.0 + normalized * 4.0) # 1~5
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Draw ``k`` distinct items from ``population`` without replacement,
    favouring items with a higher ``count``.

    Weights come from ``_compute_weights`` and are recomputed after every draw
    over the items that remain, so the 1~5 weight range is re-normalised each
    round.

    Args:
        population: Candidate items (dicts). The list itself is not modified.
        k: Number of items to draw.

    Returns:
        A new list of ``min(k, len(population))`` items. If ``population`` has
        at most ``k`` items, a shallow copy in the original order is returned;
        empty input or ``k <= 0`` yields ``[]``.
    """
    if not population or k <= 0:
        return []
    if len(population) <= k:
        return population.copy()

    selected: List[Dict] = []
    remaining = population.copy()
    # len(remaining) > k here, so exactly k draws are possible.
    for _ in range(k):
        weights = _compute_weights(remaining)
        total_weight = sum(weights)
        if not 0 < total_weight < float("inf"):
            # Degenerate weights (zero, NaN or infinite total): fall back to a
            # uniform draw so we still return exactly k items.
            selected.append(remaining.pop(random.randrange(len(remaining))))
            continue

        threshold = random.uniform(0, total_weight)
        # Default to the last item: random.uniform() may return its upper bound,
        # and sum() (compensated summation on Python 3.12+) can differ slightly
        # from the naive running total below, so the threshold may land just
        # above the final cumulative value. Without this default that draw
        # would be skipped and fewer than k items returned.
        chosen = len(remaining) - 1
        cumulative = 0.0
        for idx, weight in enumerate(weights):
            cumulative += weight
            if threshold <= cumulative:
                chosen = idx
                break
        selected.append(remaining.pop(chosen))
    return selected
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
    """
    Normalise a stored ``chat_id`` field into ``[[chat_id, count], ...]``.

    Accepts both storage formats:
      * new format: a JSON string encoding a list (or an already-decoded list),
        returned as-is;
      * legacy format: a bare chat id, wrapped as ``[[chat_id, 1]]``.

    Args:
        chat_id_value: Raw field value (JSON string, plain string, list, or
            anything else).

    Returns:
        The list of entries. A falsy input yields ``[]``. A list input is
        returned as the same object (not copied).
    """
    if not chat_id_value:
        return []
    if isinstance(chat_id_value, list):
        return chat_id_value

    # Fallback shape for anything that is not a decodable JSON list/string.
    legacy_entry = [[str(chat_id_value), 1]]
    if not isinstance(chat_id_value, str):
        return legacy_entry

    try:
        decoded = json.loads(chat_id_value)
    except (json.JSONDecodeError, TypeError):
        # Not JSON at all: a plain legacy chat id string.
        return legacy_entry

    if isinstance(decoded, list):
        return decoded
    if isinstance(decoded, str):
        # A JSON-quoted string is still a single legacy chat id.
        return [[decoded, 1]]
    return legacy_entry
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
    """
    Bump the count for ``target_chat_id``, adding an entry if it is missing.

    The list is modified in place and the same list object is returned.

    Args:
        chat_id_list: Entries shaped like ``[[chat_id, count], ...]``.
        target_chat_id: Chat id whose count should be increased.
        increment: Amount added to the count (default 1).

    Returns:
        ``chat_id_list`` after the update.
    """
    entry = _find_chat_id_item(chat_id_list, target_chat_id)
    if entry is None:
        # Unknown chat id: start a new entry at ``increment``.
        chat_id_list.append([target_chat_id, increment])
        return chat_id_list

    if len(entry) < 2:
        # Entry has an id but no count slot yet.
        entry.append(increment)
    else:
        current = entry[1]
        # A non-numeric stored count is treated as 0.
        base = current if isinstance(current, (int, float)) else 0
        entry[1] = base + increment
    return chat_id_list
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
"""
在chat_id列表中查找匹配的项辅助函数
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
如果找到则返回匹配的项否则返回None
"""
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return item
return None
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
    """
    Report whether ``chat_id_list`` has an entry for ``target_chat_id``.

    Lookup is delegated to ``_find_chat_id_item``, which compares ids as
    strings.

    Args:
        chat_id_list: Entries shaped like ``[[chat_id, count], ...]``.
        target_chat_id: Chat id to look for.

    Returns:
        True if a matching entry exists, otherwise False.
    """
    entry = _find_chat_id_item(chat_id_list, target_chat_id)
    return entry is not None
def contains_bot_self_name(content: str) -> bool:
    """
    Return True if ``content`` mentions the bot's nickname or any alias.

    Matching is a case-insensitive substring test against
    ``global_config.bot.nickname`` and ``global_config.bot.alias_names``;
    blank names are ignored.

    Args:
        content: Text of the candidate entry.

    Returns:
        False for empty ``content`` or when no bot config is present;
        otherwise whether any configured name occurs in ``content``.
    """
    if not content:
        return False
    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    haystack = content.strip().lower()

    def _normalize(name: Any) -> str:
        return str(name or "").strip().lower()

    names = [_normalize(getattr(bot_config, "nickname", ""))]
    names.extend(_normalize(alias) for alias in getattr(bot_config, "alias_names", []) or [])
    return any(name in haystack for name in names if name)
def is_bot_message(msg: Any) -> bool:
    """
    Return True if ``msg`` appears to be sent by the bot itself.

    The sender id is read from ``msg.user_id``, falling back to
    ``msg.user_info.user_id``. It is compared against the configured
    ``qq_account``, ``telegram_account`` and the ``account`` (or ``id``)
    attribute of each entry in ``bot.platforms``.

    NOTE(review): the comparison ignores the message's platform, so a user
    whose id equals a bot account on a *different* platform is also treated
    as the bot. Likewise, if ``platforms`` entries are plain strings rather
    than objects, the attribute lookups yield nothing. Confirm both are intended.

    Args:
        msg: Message object; may be None.

    Returns:
        False for None, missing bot config or an empty sender id; otherwise
        whether the sender id matches a known bot account.
    """
    if msg is None:
        return False
    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    raw_id = getattr(msg, "user_id", "")
    if not raw_id:
        raw_id = getattr(getattr(msg, "user_info", None), "user_id", "")
    user_id = str(raw_id or "").strip()
    if not user_id:
        return False

    candidates = [
        getattr(bot_config, "qq_account", ""),
        getattr(bot_config, "telegram_account", ""),
    ]
    for platform in getattr(bot_config, "platforms", []) or []:
        candidates.append(getattr(platform, "account", "") or getattr(platform, "id", ""))

    bot_accounts = {str(candidate or "").strip() for candidate in candidates}
    bot_accounts.discard("")
    return user_id in bot_accounts
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
# """
# 构建包含中心消息上下文的段落前3条+后3条使用标准的 readable builder 输出
# """
# if not messages or center_index < 0 or center_index >= len(messages):
# return None
# context_start = max(0, center_index - 3)
# context_end = min(len(messages), center_index + 1 + 3)
# context_messages = messages[context_start:context_end]
# if not context_messages:
# return None
# try:
# paragraph = build_readable_messages(
# messages=context_messages,
# replace_bot_name=True,
# timestamp_mode="relative",
# read_mark=0.0,
# truncate=False,
# show_actions=False,
# show_pic=True,
# message_id_list=None,
# remove_emoji_stickers=False,
# pic_single=True,
# )
# except Exception as e:
# logger.warning(f"构建上下文段落失败: {e}")
# return None
# paragraph = paragraph.strip()
# return paragraph or None
# def is_bot_message(msg: Any) -> bool:
# """判断消息是否来自机器人自身"""
# if msg is None:
# return False
# bot_config = getattr(global_config, "bot", None)
# if not bot_config:
# return False
# platform = (
# str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
# .strip()
# .lower()
# )
# user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
# if not platform or not user_id:
# return False
# platform_accounts = {}
# try:
# platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
# except Exception:
# platform_accounts = {}
# bot_accounts: Dict[str, str] = {}
# qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
# if qq_account:
# bot_accounts["qq"] = qq_account
# telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
# if telegram_account:
# bot_accounts["telegram"] = telegram_account
# for plat, account in platform_accounts.items():
# if account and plat not in bot_accounts:
# bot_accounts[plat] = account
# bot_account = bot_accounts.get(platform)
# return bool(bot_account and user_id == bot_account)
# def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
# """
# 解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表。
# 期望的 JSON 结构:
# [
# {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
# {"content": "词条", "source_id": "12"}, // 黑话
# ...
# ]
# Returns:
# Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
# 第一个列表是表达方式 (situation, style, source_id)
# 第二个列表是黑话 (content, source_id)
# """
# if not response:
# return [], []
# raw = response.strip()
# # 尝试提取 ```json 代码块
# json_block_pattern = r"```json\s*(.*?)\s*```"
# match = re.search(json_block_pattern, raw, re.DOTALL)
# if match:
# raw = match.group(1).strip()
# else:
# # 去掉可能存在的通用 ``` 包裹
# raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
# raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
# raw = raw.strip()
# parsed = None
# expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
# jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
# try:
# # 优先尝试直接解析
# if raw.startswith("[") and raw.endswith("]"):
# parsed = json.loads(raw)
# else:
# repaired = repair_json(raw)
# if isinstance(repaired, str):
# parsed = json.loads(repaired)
# else:
# parsed = repaired
# except Exception as parse_error:
# # 如果解析失败,尝试修复中文引号问题
# # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
# try:
# def fix_chinese_quotes_in_json(text):
# """使用状态机修复 JSON 字符串值中的中文引号"""
# result = []
# i = 0
# in_string = False
# escape_next = False
# while i < len(text):
# char = text[i]
# if escape_next:
# # 当前字符是转义字符后的字符,直接添加
# result.append(char)
# escape_next = False
# i += 1
# continue
# if char == "\\":
# # 转义字符
# result.append(char)
# escape_next = True
# i += 1
# continue
# if char == '"' and not escape_next:
# # 遇到英文引号,切换字符串状态
# in_string = not in_string
# result.append(char)
# i += 1
# continue
# if in_string:
# # 在字符串值内部,将中文引号替换为转义的英文引号
# if char == "\u201c": # 中文左引号 U+201C
# result.append('\\"')
# elif char == "\u201d": # 中文右引号 U+201D
# result.append('\\"')
# else:
# result.append(char)
# else:
# # 不在字符串内,直接添加
# result.append(char)
# i += 1
# return "".join(result)
# fixed_raw = fix_chinese_quotes_in_json(raw)
# # 再次尝试解析
# if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
# parsed = json.loads(fixed_raw)
# else:
# repaired = repair_json(fixed_raw)
# if isinstance(repaired, str):
# parsed = json.loads(repaired)
# else:
# parsed = repaired
# except Exception as fix_error:
# logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
# logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
# logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
# logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
# return [], []
# if isinstance(parsed, dict):
# parsed_list = [parsed]
# elif isinstance(parsed, list):
# parsed_list = parsed
# else:
# logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
# return [], []
# for item in parsed_list:
# if not isinstance(item, dict):
# continue
# # 检查是否是表达方式条目(有 situation 和 style
# situation = str(item.get("situation", "")).strip()
# style = str(item.get("style", "")).strip()
# source_id = str(item.get("source_id", "")).strip()
# if situation and style and source_id:
# # 表达方式条目
# expressions.append((situation, style, source_id))
# elif item.get("content"):
# # 黑话条目(有 content 字段)
# content = str(item.get("content", "")).strip()
# source_id = str(item.get("source_id", "")).strip()
# if content and source_id:
# jargon_entries.append((content, source_id))
# return expressions, jargon_entries