fix: make bot identity platform-aware
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from maim_message import Seg
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.utils import get_bot_account
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -41,7 +43,7 @@ class DirectMessageSender:
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_id=get_bot_account(chat_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
)
|
||||
|
||||
|
||||
@@ -122,7 +122,11 @@ class ActionPlanner:
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = message.platform or "qq"
|
||||
platform = message.platform or ""
|
||||
if not platform:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}planner: message {message.message_id} has no platform set, bot-self detection will be skipped"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
@@ -135,7 +139,7 @@ class ActionPlanner:
|
||||
user_id = user_match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if user_id == global_config.bot.qq_account:
|
||||
if is_bot_self(platform, str(user_id)):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_name
|
||||
|
||||
@@ -19,7 +19,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages,
|
||||
@@ -1122,7 +1122,7 @@ class DefaultReplyer:
|
||||
message_id=message_id,
|
||||
time=thinking_start_time,
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_id=get_bot_account(self.chat_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
additional_config={},
|
||||
|
||||
@@ -18,7 +18,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.services.message_service import (
|
||||
@@ -962,7 +962,7 @@ class PrivateReplyer:
|
||||
message_id=message_id,
|
||||
time=thinking_start_time,
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_id=get_bot_account(self.chat_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
group_info=None,
|
||||
|
||||
@@ -2106,12 +2106,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
total_replies = [0] * len(time_points)
|
||||
total_online_hours = [0.0] * len(time_points)
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
|
||||
interval_seconds = interval_hours * 3600
|
||||
|
||||
@@ -2148,7 +2143,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
total_messages[interval_index] += 1
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
if is_bot_self(message.platform or "", message.user_id or ""):
|
||||
total_replies[interval_index] += 1
|
||||
|
||||
# 查询在线时间记录
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
|
||||
import jieba
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
@@ -37,33 +37,64 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
Returns:
|
||||
字典,键为平台名,值为账号
|
||||
"""
|
||||
result = {}
|
||||
result: dict[str, str] = {}
|
||||
for platform_entry in platforms:
|
||||
if ":" in platform_entry:
|
||||
platform_name, account = platform_entry.split(":", 1)
|
||||
result[platform_name.strip()] = account.strip()
|
||||
normalized_platform = platform_name.lower().strip()
|
||||
account_str = account.strip()
|
||||
if normalized_platform and account_str:
|
||||
result[normalized_platform] = account_str
|
||||
return result
|
||||
|
||||
|
||||
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
|
||||
"""根据当前平台获取对应的账号
|
||||
def _get_configured_qq_account() -> str:
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "")).strip()
|
||||
if qq_account in {"", "0"}:
|
||||
return ""
|
||||
return qq_account
|
||||
|
||||
Args:
|
||||
platform: 当前消息的平台
|
||||
platform_accounts: 从 platforms 列表解析的平台账号映射
|
||||
qq_account: QQ 账号(兼容旧配置)
|
||||
|
||||
Returns:
|
||||
当前平台对应的账号
|
||||
"""
|
||||
if platform == "qq":
|
||||
def get_bot_account(platform: str) -> str:
|
||||
"""根据当前平台获取对应的机器人账号。"""
|
||||
normalized_platform = str(platform or "").strip().lower()
|
||||
if not normalized_platform:
|
||||
return ""
|
||||
|
||||
qq_account = _get_configured_qq_account()
|
||||
if normalized_platform in {"qq", "webui"}:
|
||||
return qq_account
|
||||
elif platform == "telegram":
|
||||
# 优先使用 tg,其次使用 telegram
|
||||
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
if normalized_platform in {"tg", "telegram"}:
|
||||
return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
else:
|
||||
# 其他平台直接使用平台名作为键
|
||||
return platform_accounts.get(platform, "")
|
||||
|
||||
return platform_accounts.get(normalized_platform, "")
|
||||
|
||||
|
||||
def get_all_bot_accounts() -> dict[str, str]:
|
||||
"""获取所有已配置的机器人运行时身份。"""
|
||||
bot_accounts: dict[str, str] = {}
|
||||
qq_account = _get_configured_qq_account()
|
||||
if qq_account:
|
||||
bot_accounts["qq"] = qq_account
|
||||
bot_accounts["webui"] = qq_account
|
||||
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
|
||||
telegram_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
if telegram_account:
|
||||
bot_accounts["telegram"] = telegram_account
|
||||
bot_accounts["tg"] = telegram_account
|
||||
|
||||
for platform_name, account in platform_accounts.items():
|
||||
if platform_name in {"tg", "telegram"}:
|
||||
continue
|
||||
bot_accounts[platform_name] = account
|
||||
|
||||
return bot_accounts
|
||||
|
||||
|
||||
def is_bot_self(platform: str, user_id: str) -> bool:
|
||||
@@ -78,39 +109,21 @@ def is_bot_self(platform: str, user_id: str) -> bool:
|
||||
Returns:
|
||||
bool: 如果是机器人自己则返回 True,否则返回 False
|
||||
"""
|
||||
if not platform or not user_id:
|
||||
normalized_platform = str(platform or "").strip().lower()
|
||||
if not normalized_platform or not user_id:
|
||||
return False
|
||||
|
||||
# 将 user_id 转为字符串进行比较
|
||||
user_id_str = str(user_id)
|
||||
user_id_str = str(user_id).strip()
|
||||
if not user_id_str:
|
||||
return False
|
||||
|
||||
# 获取机器人的 QQ 账号(主账号)
|
||||
qq_account = str(global_config.bot.qq_account or "")
|
||||
bot_account = get_bot_account(normalized_platform)
|
||||
if bot_account:
|
||||
return user_id_str == bot_account
|
||||
|
||||
# QQ 平台:直接比较 QQ 账号
|
||||
if platform == "qq":
|
||||
return user_id_str == qq_account
|
||||
|
||||
# WebUI 平台:机器人回复时使用的是 QQ 账号,所以也比较 QQ 账号
|
||||
if platform == "webui":
|
||||
return user_id_str == qq_account
|
||||
|
||||
# 获取各平台账号映射
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
|
||||
# Telegram 平台
|
||||
if platform == "telegram":
|
||||
tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
return user_id_str == tg_account if tg_account else False
|
||||
|
||||
# 其他平台:尝试从 platforms 配置中查找
|
||||
platform_account = platform_accounts.get(platform, "")
|
||||
if platform_account:
|
||||
return user_id_str == platform_account
|
||||
|
||||
# 默认情况:与主 QQ 账号比较(兼容性)
|
||||
return user_id_str == qq_account
|
||||
logger.warning(f"平台 {normalized_platform} 未配置机器人账号,无法判断用户 {user_id_str} 是否为机器人自己")
|
||||
return False
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, float]:
|
||||
@@ -118,13 +131,8 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
||||
text = message.processed_plain_text or ""
|
||||
platform = message.platform or ""
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
|
||||
|
||||
# 获取当前平台对应的账号
|
||||
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
|
||||
current_account = get_bot_account(platform)
|
||||
|
||||
nickname = str(global_config.bot.nickname or "")
|
||||
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
|
||||
|
||||
Reference in New Issue
Block a user