fix: make bot identity platform-aware
This commit is contained in:
@@ -200,7 +200,7 @@ def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> st
|
||||
return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为
|
||||
|
||||
|
||||
def dummy_is_bot_self(user_id: str, platform) -> bool:
|
||||
def dummy_is_bot_self(platform, user_id: str) -> bool:
|
||||
return user_id == "bot_self"
|
||||
|
||||
|
||||
|
||||
187
pytests/utils_test/test_bot_identity_utils.py
Normal file
187
pytests/utils_test/test_bot_identity_utils.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def __init__(self) -> None:
|
||||
self.warning_messages: list[str] = []
|
||||
|
||||
def debug(self, _msg: str) -> None:
|
||||
return
|
||||
|
||||
def info(self, _msg: str) -> None:
|
||||
return
|
||||
|
||||
def warning(self, msg: str) -> None:
|
||||
self.warning_messages.append(msg)
|
||||
|
||||
def error(self, _msg: str) -> None:
|
||||
return
|
||||
|
||||
|
||||
def load_utils_module(monkeypatch, qq_account=123456, platforms=None):
|
||||
logger = DummyLogger()
|
||||
configured_platforms = platforms or []
|
||||
|
||||
def _stub_module(name: str) -> ModuleType:
|
||||
module = ModuleType(name)
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
return module
|
||||
|
||||
for package_name in [
|
||||
"src",
|
||||
"src.chat",
|
||||
"src.chat.message_receive",
|
||||
"src.chat.utils",
|
||||
"src.common",
|
||||
"src.config",
|
||||
"src.llm_models",
|
||||
"src.person_info",
|
||||
]:
|
||||
if package_name not in sys.modules:
|
||||
package_module = ModuleType(package_name)
|
||||
package_module.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, package_name, package_module)
|
||||
|
||||
jieba_module = ModuleType("jieba")
|
||||
jieba_module.cut = lambda text: list(text)
|
||||
monkeypatch.setitem(sys.modules, "jieba", jieba_module)
|
||||
|
||||
logger_module = _stub_module("src.common.logger")
|
||||
logger_module.get_logger = lambda _name: logger
|
||||
|
||||
config_module = _stub_module("src.config.config")
|
||||
config_module.global_config = SimpleNamespace(
|
||||
bot=SimpleNamespace(
|
||||
qq_account=qq_account,
|
||||
platforms=configured_platforms,
|
||||
nickname="MaiBot",
|
||||
alias_names=[],
|
||||
),
|
||||
chat=SimpleNamespace(
|
||||
at_bot_inevitable_reply=1,
|
||||
mentioned_bot_reply=1,
|
||||
),
|
||||
)
|
||||
config_module.model_config = SimpleNamespace()
|
||||
|
||||
message_module = _stub_module("src.chat.message_receive.message")
|
||||
|
||||
class SessionMessage:
|
||||
pass
|
||||
|
||||
message_module.SessionMessage = SessionMessage
|
||||
|
||||
chat_manager_module = _stub_module("src.chat.message_receive.chat_manager")
|
||||
chat_manager_module.chat_manager = SimpleNamespace(get_session_by_session_id=lambda _chat_id: None)
|
||||
|
||||
llm_module = _stub_module("src.llm_models.utils_model")
|
||||
|
||||
class LLMRequest:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
del args, kwargs
|
||||
|
||||
llm_module.LLMRequest = LLMRequest
|
||||
|
||||
person_module = _stub_module("src.person_info.person_info")
|
||||
|
||||
class Person:
|
||||
pass
|
||||
|
||||
person_module.Person = Person
|
||||
|
||||
typo_generator_module = _stub_module("src.chat.utils.typo_generator")
|
||||
|
||||
class ChineseTypoGenerator:
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
del args, kwargs
|
||||
|
||||
def create_typo_sentence(self, sentence: str):
|
||||
return sentence, ""
|
||||
|
||||
typo_generator_module.ChineseTypoGenerator = ChineseTypoGenerator
|
||||
|
||||
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "utils" / "utils.py"
|
||||
spec = importlib.util.spec_from_file_location("src.chat.utils.utils", file_path)
|
||||
utils_module = importlib.util.module_from_spec(spec)
|
||||
utils_module.__package__ = "src.chat.utils"
|
||||
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(utils_module)
|
||||
return utils_module, logger
|
||||
|
||||
|
||||
def test_platform_specific_bot_accounts(monkeypatch):
|
||||
utils_module, _logger = load_utils_module(
|
||||
monkeypatch,
|
||||
qq_account=123456,
|
||||
platforms=[" TG : tg_bot ", "discord: disc_bot"],
|
||||
)
|
||||
|
||||
assert utils_module.get_bot_account("qq") == "123456"
|
||||
assert utils_module.get_bot_account("webui") == "123456"
|
||||
assert utils_module.get_bot_account("telegram") == "tg_bot"
|
||||
assert utils_module.get_bot_account("tg") == "tg_bot"
|
||||
assert utils_module.get_bot_account("discord") == "disc_bot"
|
||||
|
||||
assert utils_module.is_bot_self("qq", "123456")
|
||||
assert utils_module.is_bot_self("webui", "123456")
|
||||
assert utils_module.is_bot_self("telegram", "tg_bot")
|
||||
assert utils_module.is_bot_self(" TG ", "tg_bot")
|
||||
|
||||
|
||||
def test_get_all_bot_accounts_includes_runtime_aliases(monkeypatch):
|
||||
utils_module, _logger = load_utils_module(
|
||||
monkeypatch,
|
||||
qq_account=123456,
|
||||
platforms=["TG:tg_bot", "discord:disc_bot"],
|
||||
)
|
||||
|
||||
assert utils_module.get_all_bot_accounts() == {
|
||||
"qq": "123456",
|
||||
"webui": "123456",
|
||||
"telegram": "tg_bot",
|
||||
"tg": "tg_bot",
|
||||
"discord": "disc_bot",
|
||||
}
|
||||
|
||||
|
||||
def test_unknown_platform_no_longer_falls_back_to_qq(monkeypatch):
|
||||
utils_module, logger = load_utils_module(monkeypatch, qq_account=123456, platforms=[])
|
||||
|
||||
assert utils_module.is_bot_self("unknown_platform", "123456") is False
|
||||
assert logger.warning_messages
|
||||
assert "unknown_platform" in logger.warning_messages[-1]
|
||||
|
||||
|
||||
def test_unconfigured_qq_account_disables_qq_and_webui_identity(monkeypatch):
|
||||
utils_module, _logger = load_utils_module(monkeypatch, qq_account=0, platforms=["telegram:tg_bot"])
|
||||
|
||||
assert utils_module.get_bot_account("qq") == ""
|
||||
assert utils_module.get_bot_account("webui") == ""
|
||||
assert utils_module.is_bot_self("qq", "0") is False
|
||||
assert utils_module.is_bot_self("webui", "0") is False
|
||||
|
||||
|
||||
def test_is_mentioned_bot_in_message_uses_platform_account(monkeypatch):
|
||||
utils_module, _logger = load_utils_module(monkeypatch, qq_account=123456, platforms=["TG:tg_bot"])
|
||||
|
||||
message = SimpleNamespace(
|
||||
processed_plain_text="@tg_bot 你好",
|
||||
platform="telegram",
|
||||
is_mentioned=False,
|
||||
message_segment=None,
|
||||
message_info=SimpleNamespace(
|
||||
additional_config={},
|
||||
user_info=SimpleNamespace(user_id="user_1"),
|
||||
),
|
||||
)
|
||||
|
||||
is_mentioned, is_at, reply_probability = utils_module.is_mentioned_bot_in_message(message)
|
||||
|
||||
assert is_mentioned is True
|
||||
assert is_at is True
|
||||
assert reply_probability == 1.0
|
||||
@@ -155,7 +155,7 @@ class ExpressionLearner:
|
||||
|
||||
for i, msg in enumerate(self._messages_cache):
|
||||
# 跳过机器人自己的消息
|
||||
if is_bot_self(msg.message_info.user_info.user_id, msg.platform):
|
||||
if is_bot_self(msg.platform, msg.message_info.user_info.user_id):
|
||||
continue
|
||||
|
||||
# 获取消息文本
|
||||
@@ -238,7 +238,7 @@ class ExpressionLearner:
|
||||
|
||||
# 检查是否是机器人自己的消息
|
||||
target_msg = self._messages_cache[line_index]
|
||||
if is_bot_self(target_msg.message_info.user_info.user_id, target_msg.platform):
|
||||
if is_bot_self(target_msg.platform, target_msg.message_info.user_info.user_id):
|
||||
logger.info(f"跳过引用机器人自身消息的黑话:content={content}, source_id={source_id}")
|
||||
continue
|
||||
|
||||
@@ -298,7 +298,7 @@ class ExpressionLearner:
|
||||
# 当前行的原始消息
|
||||
current_msg = self._messages_cache[line_index]
|
||||
# 过滤掉从 bot 自己发言中提取到的表达方式
|
||||
if is_bot_self(current_msg.message_info.user_info.user_id, current_msg.platform):
|
||||
if is_bot_self(current_msg.platform, current_msg.message_info.user_info.user_id):
|
||||
continue
|
||||
# 过滤掉无上下文的表达方式
|
||||
context = (current_msg.processed_plain_text or "").strip()
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from maim_message import Seg
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.utils import get_bot_account
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -41,7 +43,7 @@ class DirectMessageSender:
|
||||
|
||||
# 获取麦麦的信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_id=get_bot_account(chat_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
)
|
||||
|
||||
|
||||
@@ -122,7 +122,11 @@ class ActionPlanner:
|
||||
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
|
||||
|
||||
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
|
||||
platform = message.platform or "qq"
|
||||
platform = message.platform or ""
|
||||
if not platform:
|
||||
logger.warning(
|
||||
f"{self.log_prefix}planner: message {message.message_id} has no platform set, bot-self detection will be skipped"
|
||||
)
|
||||
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
|
||||
|
||||
# 替换单独的 <用户名:用户ID> 格式(replace_user_references 已处理回复<和@<格式)
|
||||
@@ -135,7 +139,7 @@ class ActionPlanner:
|
||||
user_id = user_match.group(2)
|
||||
try:
|
||||
# 检查是否是机器人自己
|
||||
if user_id == global_config.bot.qq_account:
|
||||
if is_bot_self(platform, str(user_id)):
|
||||
return f"{global_config.bot.nickname}(你)"
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
return person.person_name or user_name
|
||||
|
||||
@@ -19,7 +19,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.services.message_service import (
|
||||
build_readable_messages,
|
||||
@@ -1122,7 +1122,7 @@ class DefaultReplyer:
|
||||
message_id=message_id,
|
||||
time=thinking_start_time,
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_id=get_bot_account(self.chat_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
additional_config={},
|
||||
|
||||
@@ -18,7 +18,7 @@ from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.utils.utils import get_bot_account, get_chat_type_and_target_info, is_bot_self
|
||||
from src.prompt.prompt_manager import prompt_manager
|
||||
from src.chat.utils.common_utils import TempMethodsExpression
|
||||
from src.services.message_service import (
|
||||
@@ -962,7 +962,7 @@ class PrivateReplyer:
|
||||
message_id=message_id,
|
||||
time=thinking_start_time,
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_id=get_bot_account(self.chat_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
),
|
||||
group_info=None,
|
||||
|
||||
@@ -2106,12 +2106,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
total_replies = [0] * len(time_points)
|
||||
total_online_hours = [0.0] * len(time_points)
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
|
||||
interval_seconds = interval_hours * 3600
|
||||
|
||||
@@ -2148,7 +2143,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
if 0 <= interval_index < len(time_points):
|
||||
total_messages[interval_index] += 1
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
if is_bot_self(message.platform or "", message.user_id or ""):
|
||||
total_replies[interval_index] += 1
|
||||
|
||||
# 查询在线时间记录
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import jieba
|
||||
import json
|
||||
import ast
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
|
||||
import jieba
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
@@ -37,33 +37,64 @@ def parse_platform_accounts(platforms: list[str]) -> dict[str, str]:
|
||||
Returns:
|
||||
字典,键为平台名,值为账号
|
||||
"""
|
||||
result = {}
|
||||
result: dict[str, str] = {}
|
||||
for platform_entry in platforms:
|
||||
if ":" in platform_entry:
|
||||
platform_name, account = platform_entry.split(":", 1)
|
||||
result[platform_name.strip()] = account.strip()
|
||||
normalized_platform = platform_name.lower().strip()
|
||||
account_str = account.strip()
|
||||
if normalized_platform and account_str:
|
||||
result[normalized_platform] = account_str
|
||||
return result
|
||||
|
||||
|
||||
def get_current_platform_account(platform: str, platform_accounts: dict[str, str], qq_account: str) -> str:
|
||||
"""根据当前平台获取对应的账号
|
||||
def _get_configured_qq_account() -> str:
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "")).strip()
|
||||
if qq_account in {"", "0"}:
|
||||
return ""
|
||||
return qq_account
|
||||
|
||||
Args:
|
||||
platform: 当前消息的平台
|
||||
platform_accounts: 从 platforms 列表解析的平台账号映射
|
||||
qq_account: QQ 账号(兼容旧配置)
|
||||
|
||||
Returns:
|
||||
当前平台对应的账号
|
||||
"""
|
||||
if platform == "qq":
|
||||
def get_bot_account(platform: str) -> str:
|
||||
"""根据当前平台获取对应的机器人账号。"""
|
||||
normalized_platform = str(platform or "").strip().lower()
|
||||
if not normalized_platform:
|
||||
return ""
|
||||
|
||||
qq_account = _get_configured_qq_account()
|
||||
if normalized_platform in {"qq", "webui"}:
|
||||
return qq_account
|
||||
elif platform == "telegram":
|
||||
# 优先使用 tg,其次使用 telegram
|
||||
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
if normalized_platform in {"tg", "telegram"}:
|
||||
return platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
else:
|
||||
# 其他平台直接使用平台名作为键
|
||||
return platform_accounts.get(platform, "")
|
||||
|
||||
return platform_accounts.get(normalized_platform, "")
|
||||
|
||||
|
||||
def get_all_bot_accounts() -> dict[str, str]:
|
||||
"""获取所有已配置的机器人运行时身份。"""
|
||||
bot_accounts: dict[str, str] = {}
|
||||
qq_account = _get_configured_qq_account()
|
||||
if qq_account:
|
||||
bot_accounts["qq"] = qq_account
|
||||
bot_accounts["webui"] = qq_account
|
||||
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
|
||||
telegram_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
if telegram_account:
|
||||
bot_accounts["telegram"] = telegram_account
|
||||
bot_accounts["tg"] = telegram_account
|
||||
|
||||
for platform_name, account in platform_accounts.items():
|
||||
if platform_name in {"tg", "telegram"}:
|
||||
continue
|
||||
bot_accounts[platform_name] = account
|
||||
|
||||
return bot_accounts
|
||||
|
||||
|
||||
def is_bot_self(platform: str, user_id: str) -> bool:
|
||||
@@ -78,39 +109,21 @@ def is_bot_self(platform: str, user_id: str) -> bool:
|
||||
Returns:
|
||||
bool: 如果是机器人自己则返回 True,否则返回 False
|
||||
"""
|
||||
if not platform or not user_id:
|
||||
normalized_platform = str(platform or "").strip().lower()
|
||||
if not normalized_platform or not user_id:
|
||||
return False
|
||||
|
||||
# 将 user_id 转为字符串进行比较
|
||||
user_id_str = str(user_id)
|
||||
user_id_str = str(user_id).strip()
|
||||
if not user_id_str:
|
||||
return False
|
||||
|
||||
# 获取机器人的 QQ 账号(主账号)
|
||||
qq_account = str(global_config.bot.qq_account or "")
|
||||
bot_account = get_bot_account(normalized_platform)
|
||||
if bot_account:
|
||||
return user_id_str == bot_account
|
||||
|
||||
# QQ 平台:直接比较 QQ 账号
|
||||
if platform == "qq":
|
||||
return user_id_str == qq_account
|
||||
|
||||
# WebUI 平台:机器人回复时使用的是 QQ 账号,所以也比较 QQ 账号
|
||||
if platform == "webui":
|
||||
return user_id_str == qq_account
|
||||
|
||||
# 获取各平台账号映射
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
|
||||
# Telegram 平台
|
||||
if platform == "telegram":
|
||||
tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
return user_id_str == tg_account if tg_account else False
|
||||
|
||||
# 其他平台:尝试从 platforms 配置中查找
|
||||
platform_account = platform_accounts.get(platform, "")
|
||||
if platform_account:
|
||||
return user_id_str == platform_account
|
||||
|
||||
# 默认情况:与主 QQ 账号比较(兼容性)
|
||||
return user_id_str == qq_account
|
||||
logger.warning(f"平台 {normalized_platform} 未配置机器人账号,无法判断用户 {user_id_str} 是否为机器人自己")
|
||||
return False
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, float]:
|
||||
@@ -118,13 +131,8 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
||||
text = message.processed_plain_text or ""
|
||||
platform = message.platform or ""
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = parse_platform_accounts(platforms_list)
|
||||
qq_account = str(getattr(global_config.bot, "qq_account", "") or "")
|
||||
|
||||
# 获取当前平台对应的账号
|
||||
current_account = get_current_platform_account(platform, platform_accounts, qq_account)
|
||||
current_account = get_bot_account(platform)
|
||||
|
||||
nickname = str(global_config.bot.nickname or "")
|
||||
alias_names = list(getattr(global_config.bot, "alias_names", []) or [])
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import json
|
||||
import traceback
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import and_, func, not_, or_
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -163,7 +162,17 @@ def find_messages(
|
||||
after_time=after_time,
|
||||
)
|
||||
if filter_bot:
|
||||
conditions.append(Messages.user_id != global_config.bot.qq_account)
|
||||
from src.chat.utils.utils import get_all_bot_accounts
|
||||
|
||||
bot_accounts = get_all_bot_accounts()
|
||||
if bot_accounts:
|
||||
bot_identity_predicate = or_(
|
||||
*[
|
||||
and_(Messages.platform == platform_name, Messages.user_id == account)
|
||||
for platform_name, account in bot_accounts.items()
|
||||
]
|
||||
)
|
||||
conditions.append(not_(bot_identity_predicate))
|
||||
if filter_command:
|
||||
conditions.append(Messages.is_command == False) # noqa: E712
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# TODO: 这个函数的实现非常临时,后续需要替换为更完善的实现,比如直接从配置文件中读取机器人自己的 ID,或者通过 API 获取机器人自己的信息等
|
||||
def is_bot_self(user_id: str, platform: str) -> bool:
|
||||
# TODO: 这个兼容包装层后续可以删除,统一直接使用 src.chat.utils.utils.is_bot_self
|
||||
def is_bot_self(platform: str, user_id: str) -> bool:
|
||||
"""
|
||||
判断用户 ID 是否是机器人自己
|
||||
判断用户 ID 是否是机器人自己。
|
||||
|
||||
临时方法,后续会替换为更完善的实现
|
||||
当前仅保留兼容入口,真实实现委托给统一的多平台判断函数。
|
||||
"""
|
||||
return user_id == "bot_self" and platform == "test_platform"
|
||||
from src.chat.utils.utils import is_bot_self as _is_bot_self
|
||||
|
||||
return _is_bot_self(platform, user_id)
|
||||
|
||||
@@ -367,7 +367,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[msg_usr_info.user_id][0]
|
||||
new_message.message_info.user_info.user_nickname = anonymous_name
|
||||
new_message.message_info.user_info.user_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(msg_usr_info.user_id, platform):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(platform, msg_usr_info.user_id):
|
||||
new_message.message_info.user_info.user_nickname = target_bot_name
|
||||
new_message.message_info.user_info.user_cardname = target_bot_name
|
||||
return new_message
|
||||
@@ -437,7 +437,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[user_id][0]
|
||||
component.target_user_nickname = anonymous_name
|
||||
component.target_user_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id, platform):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(platform, user_id):
|
||||
component.target_user_nickname = target_bot_name
|
||||
component.target_user_cardname = target_bot_name
|
||||
return component
|
||||
@@ -473,7 +473,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[user_id][0]
|
||||
comp.user_nickname = anonymous_name
|
||||
comp.user_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id, platform):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(platform, user_id):
|
||||
comp.user_nickname = target_bot_name
|
||||
comp.user_cardname = target_bot_name
|
||||
comp.content = [ # 递归处理转发消息中的组件
|
||||
@@ -512,7 +512,7 @@ class MessageUtils:
|
||||
anonymous_name = anonymize_mapping[user_id][0]
|
||||
component.target_message_sender_nickname = anonymous_name
|
||||
component.target_message_sender_cardname = anonymous_name
|
||||
if replace_bot_name and target_bot_name and is_bot_self(user_id, platform):
|
||||
if replace_bot_name and target_bot_name and is_bot_self(platform, user_id):
|
||||
component.target_message_sender_nickname = target_bot_name
|
||||
component.target_message_sender_cardname = target_bot_name
|
||||
else:
|
||||
|
||||
@@ -256,43 +256,9 @@ class Person:
|
||||
Returns:
|
||||
bool: 如果是机器人自己则返回 True,否则返回 False
|
||||
"""
|
||||
if not platform or not user_id:
|
||||
return False
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
|
||||
# 将 user_id 转为字符串进行比较
|
||||
user_id_str = str(user_id)
|
||||
|
||||
# 获取机器人的 QQ 账号(主账号)
|
||||
qq_account = str(global_config.bot.qq_account or "")
|
||||
|
||||
# QQ 平台:直接比较 QQ 账号
|
||||
if platform == "qq":
|
||||
return user_id_str == qq_account
|
||||
|
||||
# WebUI 平台:机器人回复时使用的是 QQ 账号,所以也比较 QQ 账号
|
||||
if platform == "webui":
|
||||
return user_id_str == qq_account
|
||||
|
||||
# 获取各平台账号映射
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
platform_accounts = {}
|
||||
for platform_entry in platforms_list:
|
||||
if ":" in platform_entry:
|
||||
platform_name, account = platform_entry.split(":", 1)
|
||||
platform_accounts[platform_name.strip()] = account.strip()
|
||||
|
||||
# Telegram 平台
|
||||
if platform == "telegram":
|
||||
tg_account = platform_accounts.get("tg", "") or platform_accounts.get("telegram", "")
|
||||
return user_id_str == tg_account if tg_account else False
|
||||
|
||||
# 其他平台:尝试从 platforms 配置中查找
|
||||
platform_account = platform_accounts.get(platform, "")
|
||||
if platform_account:
|
||||
return user_id_str == platform_account
|
||||
|
||||
# 默认情况:与主 QQ 账号比较(兼容性)
|
||||
return user_id_str == qq_account
|
||||
return is_bot_self(platform, user_id)
|
||||
|
||||
def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
|
||||
# 使用统一的机器人识别函数(支持多平台,包括 WebUI)
|
||||
@@ -731,15 +697,15 @@ class PersonInfoManager:
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
qv_name_prompt = f"你是{bot_name},一个伪装成人类的AI,你不能让别人发现这一点,"
|
||||
qv_name_prompt += f"现在你想给一个用户取一个昵称,用户的qq昵称是{user_nickname},"
|
||||
qv_name_prompt += f"用户的qq群昵称名是{user_cardname},"
|
||||
qv_name_prompt += f"现在你想给一个用户取一个昵称,用户的昵称是{user_nickname},"
|
||||
qv_name_prompt += f"用户的群昵称名是{user_cardname},"
|
||||
if user_avatar:
|
||||
qv_name_prompt += f"用户的qq头像是{user_avatar},"
|
||||
qv_name_prompt += f"用户的头像是{user_avatar},"
|
||||
if old_name:
|
||||
qv_name_prompt += f"你之前叫他{old_name},是因为{old_reason},"
|
||||
|
||||
qv_name_prompt += f"\n其他取名的要求是:{request},不要太浮夸,简短,"
|
||||
qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的qq昵称或群昵称原文,可以稍作修改,优先使用原文。优先使用用户的qq昵称或者群昵称原文。"
|
||||
qv_name_prompt += "\n请根据以上用户信息,想想你叫他什么比较好,不要太浮夸,请最好使用用户的昵称或群昵称原文,可以稍作修改,优先使用原文。优先使用用户的昵称或者群昵称原文。"
|
||||
|
||||
if existing_names_str:
|
||||
qv_name_prompt += f"\n请注意,以下名称已被你尝试过或已知存在,请避免:{existing_names_str}。\n"
|
||||
|
||||
@@ -4,15 +4,17 @@
|
||||
提供发送各种类型消息的核心功能。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo
|
||||
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.utils import get_bot_account
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from src.common.data_models.message_component_data_model import DictComponent, MessageSequence
|
||||
from src.common.logger import get_logger
|
||||
@@ -88,7 +90,7 @@ async def _send_to_target(
|
||||
message_id=message_id,
|
||||
time=current_time,
|
||||
user_info=MaimUserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_id=get_bot_account(target_stream.platform),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=target_stream.platform,
|
||||
),
|
||||
|
||||
@@ -11,11 +11,11 @@ from sqlmodel import col, delete, select
|
||||
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.utils.utils import is_bot_self
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import find_messages
|
||||
from src.common.utils.system_utils import is_bot_self
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.config.config import global_config
|
||||
from src.webui.core import get_token_manager
|
||||
@@ -62,12 +62,7 @@ class ChatHistoryManager:
|
||||
def _message_to_dict(self, msg: SessionMessage, group_id: Optional[str] = None) -> dict[str, Any]:
|
||||
user_info = msg.message_info.user_info
|
||||
user_id = user_info.user_id or ""
|
||||
is_bot = is_bot_self(user_id, msg.platform)
|
||||
|
||||
if not is_bot and group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
is_bot = user_id == str(global_config.bot.qq_account)
|
||||
elif not is_bot:
|
||||
is_bot = not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
is_bot = is_bot_self(msg.platform, user_id)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
@@ -611,4 +606,4 @@ async def dispatch_chat_event(
|
||||
)
|
||||
return current_user_name, next_virtual_config
|
||||
|
||||
return current_user_name, current_virtual_config
|
||||
return current_user_name, current_virtual_config
|
||||
|
||||
Reference in New Issue
Block a user