merge: sync upstream/r-dev and resolve real conflicts

This commit is contained in:
A-Dawn
2026-03-24 15:36:26 +08:00
114 changed files with 15841 additions and 5236 deletions

View File

@@ -0,0 +1,445 @@
import random
import json
from typing import Optional, List, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger("learner_utils")
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间。
count越高权重越高但最多为基础权重的5倍。
"""
if not population:
return []
counts = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
weights = [1.0 for _ in counts]
else:
weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
weights.append(1.0 + normalized * 4.0) # 1~5
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
"""
解析chat_id字段兼容旧格式字符串和新格式JSON列表
Args:
chat_id_value: 可能是字符串旧格式或JSON字符串新格式
Returns:
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
"""
if not chat_id_value:
return []
# 如果是字符串尝试解析为JSON
if isinstance(chat_id_value, str):
# 尝试解析JSON
try:
parsed = json.loads(chat_id_value)
if isinstance(parsed, list):
# 新格式:已经是列表
return parsed
elif isinstance(parsed, str):
# 解析后还是字符串,说明是旧格式
return [[parsed, 1]]
else:
# 其他类型,当作旧格式处理
return [[str(chat_id_value), 1]]
except (json.JSONDecodeError, TypeError):
# 解析失败,当作旧格式(纯字符串)
return [[str(chat_id_value), 1]]
elif isinstance(chat_id_value, list):
# 已经是列表格式
return chat_id_value
else:
# 其他类型,转换为旧格式
return [[str(chat_id_value), 1]]
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
"""
更新chat_id列表如果target_chat_id已存在则增加计数否则添加新条目
Args:
chat_id_list: 当前的chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要更新或添加的chat_id
increment: 增加的计数默认为1
Returns:
List[List[Any]]: 更新后的chat_id列表
"""
item = _find_chat_id_item(chat_id_list, target_chat_id)
if item is not None:
# 找到匹配的chat_id增加计数
if len(item) >= 2:
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
else:
item.append(increment)
else:
# 未找到,添加新条目
chat_id_list.append([target_chat_id, increment])
return chat_id_list
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
"""
在chat_id列表中查找匹配的项辅助函数
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
如果找到则返回匹配的项否则返回None
"""
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return item
return None
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
"""
检查chat_id列表中是否包含指定的chat_id
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
bool: 如果包含则返回True
"""
return _find_chat_id_item(chat_id_list, target_chat_id) is not None
def contains_bot_self_name(content: str) -> bool:
"""
判断词条是否包含机器人的昵称或别名
"""
if not content:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
return any(name in target for name in candidates)
def is_bot_message(msg: Any) -> bool:
"""判断消息是否来自机器人自身。"""
if msg is None:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
if not user_id:
return False
known_accounts = {
str(getattr(bot_config, "qq_account", "") or "").strip(),
str(getattr(bot_config, "telegram_account", "") or "").strip(),
}
for platform in getattr(bot_config, "platforms", []) or []:
account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
if account:
known_accounts.add(account)
return user_id in {account for account in known_accounts if account}
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
# """
# 构建包含中心消息上下文的段落前3条+后3条使用标准的 readable builder 输出
# """
# if not messages or center_index < 0 or center_index >= len(messages):
# return None
# context_start = max(0, center_index - 3)
# context_end = min(len(messages), center_index + 1 + 3)
# context_messages = messages[context_start:context_end]
# if not context_messages:
# return None
# try:
# paragraph = build_readable_messages(
# messages=context_messages,
# replace_bot_name=True,
# timestamp_mode="relative",
# read_mark=0.0,
# truncate=False,
# show_actions=False,
# show_pic=True,
# message_id_list=None,
# remove_emoji_stickers=False,
# pic_single=True,
# )
# except Exception as e:
# logger.warning(f"构建上下文段落失败: {e}")
# return None
# paragraph = paragraph.strip()
# return paragraph or None
# def is_bot_message(msg: Any) -> bool:
# """判断消息是否来自机器人自身"""
# if msg is None:
# return False
# bot_config = getattr(global_config, "bot", None)
# if not bot_config:
# return False
# platform = (
# str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
# .strip()
# .lower()
# )
# user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
# if not platform or not user_id:
# return False
# platform_accounts = {}
# try:
# platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
# except Exception:
# platform_accounts = {}
# bot_accounts: Dict[str, str] = {}
# qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
# if qq_account:
# bot_accounts["qq"] = qq_account
# telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
# if telegram_account:
# bot_accounts["telegram"] = telegram_account
# for plat, account in platform_accounts.items():
# if account and plat not in bot_accounts:
# bot_accounts[plat] = account
# bot_account = bot_accounts.get(platform)
# return bool(bot_account and user_id == bot_account)
# def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
# """
# 解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表。
# 期望的 JSON 结构:
# [
# {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
# {"content": "词条", "source_id": "12"}, // 黑话
# ...
# ]
# Returns:
# Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
# 第一个列表是表达方式 (situation, style, source_id)
# 第二个列表是黑话 (content, source_id)
# """
# if not response:
# return [], []
# raw = response.strip()
# # 尝试提取 ```json 代码块
# json_block_pattern = r"```json\s*(.*?)\s*```"
# match = re.search(json_block_pattern, raw, re.DOTALL)
# if match:
# raw = match.group(1).strip()
# else:
# # 去掉可能存在的通用 ``` 包裹
# raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
# raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
# raw = raw.strip()
# parsed = None
# expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
# jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
# try:
# # 优先尝试直接解析
# if raw.startswith("[") and raw.endswith("]"):
# parsed = json.loads(raw)
# else:
# repaired = repair_json(raw)
# if isinstance(repaired, str):
# parsed = json.loads(repaired)
# else:
# parsed = repaired
# except Exception as parse_error:
# # 如果解析失败,尝试修复中文引号问题
# # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
# try:
# def fix_chinese_quotes_in_json(text):
# """使用状态机修复 JSON 字符串值中的中文引号"""
# result = []
# i = 0
# in_string = False
# escape_next = False
# while i < len(text):
# char = text[i]
# if escape_next:
# # 当前字符是转义字符后的字符,直接添加
# result.append(char)
# escape_next = False
# i += 1
# continue
# if char == "\\":
# # 转义字符
# result.append(char)
# escape_next = True
# i += 1
# continue
# if char == '"' and not escape_next:
# # 遇到英文引号,切换字符串状态
# in_string = not in_string
# result.append(char)
# i += 1
# continue
# if in_string:
# # 在字符串值内部,将中文引号替换为转义的英文引号
# if char == '"': # 中文左引号 U+201C
# result.append('\\"')
# elif char == '"': # 中文右引号 U+201D
# result.append('\\"')
# else:
# result.append(char)
# else:
# # 不在字符串内,直接添加
# result.append(char)
# i += 1
# return "".join(result)
# fixed_raw = fix_chinese_quotes_in_json(raw)
# # 再次尝试解析
# if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
# parsed = json.loads(fixed_raw)
# else:
# repaired = repair_json(fixed_raw)
# if isinstance(repaired, str):
# parsed = json.loads(repaired)
# else:
# parsed = repaired
# except Exception as fix_error:
# logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
# logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
# logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
# logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
# return [], []
# if isinstance(parsed, dict):
# parsed_list = [parsed]
# elif isinstance(parsed, list):
# parsed_list = parsed
# else:
# logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
# return [], []
# for item in parsed_list:
# if not isinstance(item, dict):
# continue
# # 检查是否是表达方式条目(有 situation 和 style
# situation = str(item.get("situation", "")).strip()
# style = str(item.get("style", "")).strip()
# source_id = str(item.get("source_id", "")).strip()
# if situation and style and source_id:
# # 表达方式条目
# expressions.append((situation, style, source_id))
# elif item.get("content"):
# # 黑话条目(有 content 字段)
# content = str(item.get("content", "")).strip()
# source_id = str(item.get("source_id", "")).strip()
# if content and source_id:
# jargon_entries.append((content, source_id))
# return expressions, jargon_entries