import json
import math
import random
from typing import Any, Dict, List, Optional

from src.common.logger import get_logger
from src.config.config import global_config

logger = get_logger("learner_utils")


def _compute_weights(population: List[Dict]) -> List[float]:
    """
    Compute a sampling weight for each item from its ``count`` field.

    Counts are mapped linearly onto [1.0, 5.0]: the smallest count gets 1.0
    and the largest gets 5.0, so an item is at most 5x as likely to be drawn
    as the lowest-count item. Missing, non-numeric or non-finite counts are
    treated as 1; negative counts are clamped to 0.

    Args:
        population: Items (dicts) that may carry a numeric ``count``.

    Returns:
        One weight per item, in the same order; empty list for empty input.
    """
    if not population:
        return []

    counts: List[float] = []
    for item in population:
        raw_count = item.get("count", 1)
        try:
            count_value = float(raw_count)
        except (TypeError, ValueError):
            count_value = 1.0
        # float() accepts "nan"/"inf"; such values would poison min/max and the
        # normalization below (yielding NaN weights that can never be drawn),
        # so treat them like any other unusable count.
        if not math.isfinite(count_value):
            count_value = 1.0
        counts.append(max(count_value, 0.0))

    min_count = min(counts)
    max_count = max(counts)
    if max_count == min_count:
        # All counts equal: uniform weights (also avoids dividing by zero).
        return [1.0 for _ in counts]

    span = max_count - min_count
    # Linear map of [min_count, max_count] onto [1, 5].
    return [1.0 + (count_value - min_count) / span * 4.0 for count_value in counts]


def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
    """
    Draw up to ``k`` distinct items, weighted by their ``count`` field.

    Sampling is without replacement: after each draw the chosen item is
    removed and the weights of the remaining items are recomputed with
    ``_compute_weights`` (i.e. re-normalized over what is left).

    Args:
        population: Candidate items; the input list itself is not modified.
        k: Number of items to draw.

    Returns:
        The drawn items in draw order. Empty if ``population`` is empty or
        ``k <= 0``; a shallow copy of ``population`` if it has at most ``k``
        items.
    """
    if not population or k <= 0:
        return []

    if len(population) <= k:
        return population.copy()

    selected: List[Dict] = []
    population_copy = population.copy()

    for _ in range(min(k, len(population_copy))):
        weights = _compute_weights(population_copy)
        total_weight = sum(weights)

        # `not (x > 0)` also catches a NaN total, which `x <= 0` would miss.
        if not (total_weight > 0):
            # Fall back to a uniform random pick.
            idx = random.randint(0, len(population_copy) - 1)
            selected.append(population_copy.pop(idx))
            continue

        threshold = random.uniform(0, total_weight)
        cumulative = 0.0
        # Default to the last item: `sum()` (compensated summation on 3.12+)
        # and the naive running total below can differ by a few ulps, and
        # uniform() may return its upper bound, so the threshold can land just
        # past the final cumulative value. Without this default the round would
        # select nothing and fewer than k items would be returned.
        chosen = len(population_copy) - 1
        for idx, weight in enumerate(weights):
            cumulative += weight
            if threshold <= cumulative:
                chosen = idx
                break
        selected.append(population_copy.pop(chosen))

    return selected


def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
    """
    Parse a ``chat_id`` field, accepting both the legacy and the new format.

    Legacy format is a bare chat_id string; the new format is a JSON list
    (or an already-decoded list) of ``[chat_id, count]`` pairs.

    Args:
        chat_id_value: A plain string (legacy), a JSON string (new format),
            a list, or any other value (stringified as a legacy id).

    Returns:
        A list shaped like ``[[chat_id, count], ...]``; ``[]`` for falsy
        input. A list input (or a JSON-decoded list) is returned as-is,
        without validating its elements; a list input is returned by
        reference, not copied.
    """
    if not chat_id_value:
        return []

    if isinstance(chat_id_value, str):
        try:
            parsed = json.loads(chat_id_value)
        except (json.JSONDecodeError, TypeError):
            # Not JSON at all: a bare legacy chat_id string.
            return [[str(chat_id_value), 1]]

        if isinstance(parsed, list):
            # New format.
            return parsed
        if isinstance(parsed, str):
            # A JSON-encoded string is still a legacy single id.
            return [[parsed, 1]]
        # Any other JSON value (number, object, null, ...): keep the raw text
        # as a legacy id.
        return [[str(chat_id_value), 1]]

    if isinstance(chat_id_value, list):
        # Already decoded into the list format.
        return chat_id_value

    # Any other type: stringify it as a legacy id.
    return [[str(chat_id_value), 1]]


def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
    """
    Increment the count for ``target_chat_id``, or append a new entry for it.

    The list is modified in place and also returned for convenience.

    Args:
        chat_id_list: Current list, shaped like ``[[chat_id, count], ...]``.
        target_chat_id: chat_id to update or add (compared as strings).
        increment: Amount to add to the count; defaults to 1.

    Returns:
        The same (mutated) ``chat_id_list`` object.
    """
    item = _find_chat_id_item(chat_id_list, target_chat_id)
    if item is not None:
        if len(item) >= 2:
            # A non-numeric existing count is treated as 0.
            item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
        else:
            # Entry has an id but no count slot yet.
            item.append(increment)
    else:
        chat_id_list.append([target_chat_id, increment])
    return chat_id_list


def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
    """
    Return the first entry whose chat_id matches ``target_chat_id``.

    Entries that are not non-empty lists are skipped; ids are compared by
    their ``str()`` form so e.g. ``123`` and ``"123"`` match.

    Args:
        chat_id_list: List shaped like ``[[chat_id, count], ...]``.
        target_chat_id: chat_id to look for.

    Returns:
        The matching entry (the list object itself), or None if absent.
    """
    for item in chat_id_list:
        if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
            return item
    return None


def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
    """
    Check whether ``chat_id_list`` has an entry for ``target_chat_id``.

    Args:
        chat_id_list: List shaped like ``[[chat_id, count], ...]``.
        target_chat_id: chat_id to look for (compared as strings).

    Returns:
        True if a matching entry exists.
    """
    return _find_chat_id_item(chat_id_list, target_chat_id) is not None


def contains_bot_self_name(content: str) -> bool:
    """
    Return True if ``content`` contains the bot's nickname or any alias.

    Matching is a case-insensitive substring check after stripping
    surrounding whitespace. Returns False for empty content or when no
    ``bot`` section is configured.
    """
    if not content:
        return False

    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    target = content.strip().lower()
    nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
    alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]

    # Ignore empty names so "" does not match every string.
    candidates = [name for name in [nickname, *alias_names] if name]
    return any(name in target for name in candidates)


def is_bot_message(msg: Any) -> bool:
    """
    Return True if ``msg`` was sent by one of the bot's own accounts.

    The sender id is read from ``msg.user_id`` or ``msg.user_info.user_id``
    and compared against ``qq_account``, ``telegram_account`` and each
    ``platforms`` entry's ``account``/``id`` attribute from the bot config.
    """
    if msg is None:
        return False

    bot_config = getattr(global_config, "bot", None)
    if not bot_config:
        return False

    user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
    if not user_id:
        return False

    known_accounts = {
        str(getattr(bot_config, "qq_account", "") or "").strip(),
        str(getattr(bot_config, "telegram_account", "") or "").strip(),
    }

    # NOTE(review): platform entries without an `account`/`id` attribute
    # (e.g. plain strings) contribute nothing here; the disabled legacy
    # version parsed `platforms` via parse_platform_accounts — confirm the
    # config format. Also note the sender's platform is not compared, so an
    # id matching the bot account on a different platform counts as the bot.
    for platform in getattr(bot_config, "platforms", []) or []:
        account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip()
        if account:
            known_accounts.add(account)

    return user_id in {account for account in known_accounts if account}


# Disabled legacy helpers (commented out).
# def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
#     """
#     构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
#     """
#     if not messages or center_index < 0 or center_index >= len(messages):
#         return None
#     context_start = max(0, center_index - 3)
#     context_end = min(len(messages), center_index + 1 + 3)
#     context_messages = messages[context_start:context_end]
#     if not context_messages:
#         return None
#     try:
#         paragraph = build_readable_messages(
#             messages=context_messages,
#             replace_bot_name=True,
#             timestamp_mode="relative",
#             read_mark=0.0,
#             truncate=False,
#             show_actions=False,
#             show_pic=True,
# message_id_list=None, # remove_emoji_stickers=False, # pic_single=True, # ) # except Exception as e: # logger.warning(f"构建上下文段落失败: {e}") # return None # paragraph = paragraph.strip() # return paragraph or None # def is_bot_message(msg: Any) -> bool: # """判断消息是否来自机器人自身""" # if msg is None: # return False # bot_config = getattr(global_config, "bot", None) # if not bot_config: # return False # platform = ( # str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "") # .strip() # .lower() # ) # user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() # if not platform or not user_id: # return False # platform_accounts = {} # try: # platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or []) # except Exception: # platform_accounts = {} # bot_accounts: Dict[str, str] = {} # qq_account = str(getattr(bot_config, "qq_account", "") or "").strip() # if qq_account: # bot_accounts["qq"] = qq_account # telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip() # if telegram_account: # bot_accounts["telegram"] = telegram_account # for plat, account in platform_accounts.items(): # if account and plat not in bot_accounts: # bot_accounts[plat] = account # bot_account = bot_accounts.get(platform) # return bool(bot_account and user_id == bot_account) # def parse_expression_response(response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: # """ # 解析 LLM 返回的表达风格总结和黑话 JSON,提取两个列表。 # 期望的 JSON 结构: # [ # {"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式 # {"content": "词条", "source_id": "12"}, // 黑话 # ... 
# ] # Returns: # Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]: # 第一个列表是表达方式 (situation, style, source_id) # 第二个列表是黑话 (content, source_id) # """ # if not response: # return [], [] # raw = response.strip() # # 尝试提取 ```json 代码块 # json_block_pattern = r"```json\s*(.*?)\s*```" # match = re.search(json_block_pattern, raw, re.DOTALL) # if match: # raw = match.group(1).strip() # else: # # 去掉可能存在的通用 ``` 包裹 # raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE) # raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE) # raw = raw.strip() # parsed = None # expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id) # jargon_entries: List[Tuple[str, str]] = [] # (content, source_id) # try: # # 优先尝试直接解析 # if raw.startswith("[") and raw.endswith("]"): # parsed = json.loads(raw) # else: # repaired = repair_json(raw) # if isinstance(repaired, str): # parsed = json.loads(repaired) # else: # parsed = repaired # except Exception as parse_error: # # 如果解析失败,尝试修复中文引号问题 # # 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号 # try: # def fix_chinese_quotes_in_json(text): # """使用状态机修复 JSON 字符串值中的中文引号""" # result = [] # i = 0 # in_string = False # escape_next = False # while i < len(text): # char = text[i] # if escape_next: # # 当前字符是转义字符后的字符,直接添加 # result.append(char) # escape_next = False # i += 1 # continue # if char == "\\": # # 转义字符 # result.append(char) # escape_next = True # i += 1 # continue # if char == '"' and not escape_next: # # 遇到英文引号,切换字符串状态 # in_string = not in_string # result.append(char) # i += 1 # continue # if in_string: # # 在字符串值内部,将中文引号替换为转义的英文引号 # if char == '"': # 中文左引号 U+201C # result.append('\\"') # elif char == '"': # 中文右引号 U+201D # result.append('\\"') # else: # result.append(char) # else: # # 不在字符串内,直接添加 # result.append(char) # i += 1 # return "".join(result) # fixed_raw = fix_chinese_quotes_in_json(raw) # # 再次尝试解析 # if fixed_raw.startswith("[") and fixed_raw.endswith("]"): # parsed = json.loads(fixed_raw) # else: # repaired = repair_json(fixed_raw) 
# if isinstance(repaired, str): # parsed = json.loads(repaired) # else: # parsed = repaired # except Exception as fix_error: # logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}") # logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}") # logger.error(f"解析表达风格 JSON 失败,原始响应:{response}") # logger.error(f"处理后的 JSON 字符串(前500字符):{raw[:500]}") # return [], [] # if isinstance(parsed, dict): # parsed_list = [parsed] # elif isinstance(parsed, list): # parsed_list = parsed # else: # logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}") # return [], [] # for item in parsed_list: # if not isinstance(item, dict): # continue # # 检查是否是表达方式条目(有 situation 和 style) # situation = str(item.get("situation", "")).strip() # style = str(item.get("style", "")).strip() # source_id = str(item.get("source_id", "")).strip() # if situation and style and source_id: # # 表达方式条目 # expressions.append((situation, style, source_id)) # elif item.get("content"): # # 黑话条目(有 content 字段) # content = str(item.get("content", "")).strip() # source_id = str(item.get("source_id", "")).strip() # if content and source_id: # jargon_entries.append((content, source_id)) # return expressions, jargon_entries