Ruff fix
This commit is contained in:
@@ -4,14 +4,11 @@ import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import numpy as np
|
||||
|
||||
from collections import Counter
|
||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.message_repository import find_messages, count_messages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
@@ -32,10 +29,10 @@ def is_english_letter(char: str) -> bool:
|
||||
|
||||
def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
"""解析 platforms 列表,返回平台到账号的映射
|
||||
|
||||
|
||||
Args:
|
||||
platforms: 格式为 ["platform:account"] 的列表,如 ["tg:123456789", "wx:wxid123"]
|
||||
|
||||
|
||||
Returns:
|
||||
字典,键为平台名,值为账号
|
||||
"""
|
||||
@@ -49,12 +46,12 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
|
||||
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
|
||||
"""根据当前平台获取对应的账号
|
||||
|
||||
|
||||
Args:
|
||||
platform: 当前消息的平台
|
||||
platform_accounts: 从 platforms 列表解析的平台账号映射
|
||||
qq_account: QQ 账号(兼容旧配置)
|
||||
|
||||
|
||||
Returns:
|
||||
当前平台对应的账号
|
||||
"""
|
||||
@@ -72,12 +69,12 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
"""检查消息是否提到了机器人(统一多平台实现)"""
|
||||
text = message.processed_plain_text or ""
|
||||
platform = getattr(message.message_info, "platform", "") or ""
|
||||
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
|
||||
|
||||
|
||||
# 获取当前平台对应的账号
|
||||
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
|
||||
|
||||
@@ -146,7 +143,9 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float
|
||||
elif current_account:
|
||||
if re.search(rf"\[回复 (.+?)\({re.escape(current_account)}\):(.+?)\],说:", text):
|
||||
is_mentioned = True
|
||||
elif re.search(rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text):
|
||||
elif re.search(
|
||||
rf"\[回复<(.+?)(?=:{re.escape(current_account)}>)\:{re.escape(current_account)}>:(.+?)\],说:", text
|
||||
):
|
||||
is_mentioned = True
|
||||
|
||||
# 6) 名称/别名 提及(去除 @/回复标记后再匹配)
|
||||
@@ -185,7 +184,6 @@ async def get_embedding(text, request_type="embedding") -> Optional[List[float]]
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
"""将文本分割成句子,并根据概率合并
|
||||
1. 识别分割点(, , 。 ; 空格),但如果分割点左右都是英文字母则不分割。
|
||||
@@ -227,7 +225,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
|
||||
prev_char = text[i - 1]
|
||||
next_char = text[i + 1]
|
||||
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
|
||||
if char == ' ':
|
||||
if char == " ":
|
||||
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
|
||||
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
|
||||
if prev_is_alnum and next_is_alnum:
|
||||
@@ -340,7 +338,7 @@ def _get_random_default_reply() -> str:
|
||||
"不知道",
|
||||
"不晓得",
|
||||
"懒得说",
|
||||
"()"
|
||||
"()",
|
||||
]
|
||||
return random.choice(default_replies)
|
||||
|
||||
@@ -469,7 +467,6 @@ def calculate_typing_time(
|
||||
return total_time # 加上回车时间
|
||||
|
||||
|
||||
|
||||
def truncate_message(message: str, max_length=20) -> str:
|
||||
"""截断消息,使其不超过指定长度"""
|
||||
return f"{message[:max_length]}..." if len(message) > max_length else message
|
||||
@@ -546,7 +543,6 @@ def get_western_ratio(paragraph):
|
||||
return western_count / len(alnum_chars)
|
||||
|
||||
|
||||
|
||||
def translate_timestamp_to_human_readable(timestamp: float, mode: str = "normal") -> str:
|
||||
# sourcery skip: merge-comparisons, merge-duplicate-blocks, switch
|
||||
"""将时间戳转换为人类可读的时间格式
|
||||
|
||||
Reference in New Issue
Block a user