chore: import deployable mai-bot source tree

This commit is contained in:
2026-05-11 00:51:12 +00:00
parent 4813699b3e
commit 7a54015f94
1009 changed files with 312999 additions and 16 deletions

View File

@@ -0,0 +1,367 @@
import sys
from dataclasses import dataclass, field
import pytest
import importlib
import importlib.util
from types import ModuleType
from pathlib import Path
from datetime import datetime
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from src.common.data_models.message_component_data_model import MessageSequence
from src.chat.message_receive.message import (
SessionMessage,
TextComponent,
ImageComponent,
AtComponent,
)
class DummyLogger:
    """In-memory stand-in for the project logger.

    Each logging call prints ``"<LEVEL>: <msg>"`` to stdout and appends the
    same string to ``logging_record`` so a test can inspect what was logged.
    """

    def __init__(self) -> None:
        self.logging_record = []

    def _emit(self, level, msg):
        """Print one formatted entry and keep a copy of it."""
        entry = f"{level}: {msg}"
        print(entry)
        self.logging_record.append(entry)

    def debug(self, msg):
        """Record ``msg`` at DEBUG level."""
        self._emit("DEBUG", msg)

    def info(self, msg):
        """Record ``msg`` at INFO level."""
        self._emit("INFO", msg)

    def warning(self, msg):
        """Record ``msg`` at WARNING level."""
        self._emit("WARNING", msg)

    def error(self, msg):
        """Record ``msg`` at ERROR level."""
        self._emit("ERROR", msg)

    def critical(self, msg):
        """Record ``msg`` at CRITICAL level."""
        self._emit("CRITICAL", msg)
def get_logger(name):
    """Return a new ``DummyLogger``; ``name`` is accepted but not used."""
    stub_logger = DummyLogger()
    return stub_logger
class DummyDBSession:
    """No-op database session that always reports an empty result set.

    It acts as its own context manager and as its own query result, so
    ``session.exec(stmt).first()`` and ``session.exec(stmt).all()`` both work.
    """

    # --- context-manager protocol -------------------------------------
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised inside the ``with`` body propagates.
        return None

    # --- session / result API ------------------------------------------
    def exec(self, statement):
        """Ignore ``statement`` and return this object as the result."""
        return self

    def first(self):
        """Return ``None``: there is never a first row."""
        return None

    def all(self):
        """Return a new empty list: there are never any rows."""
        return []

    def commit(self):
        """Do nothing."""
        return None
def get_db_session():
    """Return a new ``DummyDBSession`` in place of a real database session."""
    session = DummyDBSession()
    return session
def get_manual_db_session():
    """Return a new ``DummyDBSession`` in place of a manually managed session."""
    session = DummyDBSession()
    return session
class DummySelect:
    """Chainable placeholder for a query object.

    It stores the selected model; every builder method ignores its arguments
    and returns the same object so call chains keep working.
    """

    def __init__(self, model):
        self.model = model

    def filter_by(self, **kwargs):
        """Ignore the keyword filters and return ``self``."""
        return self

    def where(self, condition):
        """Ignore ``condition`` and return ``self``."""
        return self

    def limit(self, n):
        """Ignore ``n`` and return ``self``."""
        return self
def select(model):
    """Wrap ``model`` in a ``DummySelect`` placeholder query."""
    query = DummySelect(model)
    return query
async def dummy_get_voice_text(binary_data):
    """Stand-in for voice-to-text; ignores ``binary_data`` and yields ``None``."""
    # Change this value if a test needs a simulated transcription.
    transcription = None
    return transcription
class DummyPersonUtils:
    """Stand-in for ``PersonUtils`` that never finds a person."""

    @staticmethod
    def get_person_info_by_user_id_and_platform(user_id, platform):
        """Ignore both arguments and return ``None``."""
        # Change this value if a test needs simulated user information.
        person_info = None
        return person_info
class DummyConfig:
    """Minimal configuration object exposing only ``message_receive``."""

    class MessageReceiveConfig:
        """Message-receive settings with empty ban lists."""

        # Both ban lists are empty sets defined at class level.
        ban_words = set()
        ban_msgs_regex = set()

    # One settings instance shared through the class attribute.
    message_receive = MessageReceiveConfig()
@dataclass
class UserInfo:
    """Sender identity attached to a test message."""

    # Required identity fields.
    user_id: str
    user_nickname: str
    # Optional card name; defaults to None.
    user_cardname: Optional[str] = None
@dataclass
class GroupInfo:
    """Group identity attached to a test message."""

    # Both fields are required.
    group_id: str
    group_name: str
@dataclass
class MessageInfo:
    """Metadata bundle assigned to ``message_info`` on test messages."""

    # The sender is required.
    user_info: UserInfo
    # No group unless one is supplied.
    group_info: Optional[GroupInfo] = None
    # Each instance gets its own empty dict.
    additional_config: dict = field(default_factory=dict)
def setup_mocks(monkeypatch):
    """Register stub modules in ``sys.modules`` for the code under test.

    Each stub is an empty ``ModuleType`` installed through
    ``monkeypatch.setitem`` and then given the attributes listed below.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.
    """

    def _stub_module(name: str) -> ModuleType:
        """Create an empty module called ``name`` and register it."""
        stub = ModuleType(name)
        monkeypatch.setitem(sys.modules, name, stub)
        return stub

    # Module name -> attributes to place on the stub. Add more attributes or
    # replace the ``None`` placeholders when a test needs them.
    stub_layout = {
        "src.common.logger": {"get_logger": get_logger},
        "src.common.database.database": {
            "get_db_session": get_db_session,
            "get_manual_db_session": get_manual_db_session,
        },
        "src.common.database.database_model": {"Messages": None},
        "src.emoji_system.emoji_manager": {"emoji_manager": None},
        "src.chat.image_system.image_manager": {"image_manager": None},
        "src.common.utils.utils_voice": {"get_voice_text": dummy_get_voice_text},
        "src.common.utils.utils_person": {"PersonUtils": DummyPersonUtils},
        "src.config.config": {"global_config": DummyConfig()},
    }

    for module_name, attributes in stub_layout.items():
        stub = _stub_module(module_name)
        for attribute_name, value in attributes.items():
            setattr(stub, attribute_name, value)
def load_message_via_file(monkeypatch):
    """Load ``src/chat/message_receive/message.py`` straight from its file.

    The stub modules are installed first. The loaded module is registered in
    ``sys.modules`` as ``"message_module"`` and its ``select`` attribute is
    replaced with the local dummy. The component classes are then copied into
    this test module's globals so the tests can use them as bare names.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.

    Returns:
        The loaded message module.
    """
    setup_mocks(monkeypatch)

    repo_root = Path(__file__).parent.parent.parent
    source_path = repo_root / "src" / "chat" / "message_receive" / "message.py"
    module_spec = importlib.util.spec_from_file_location("message", source_path)
    message_module = importlib.util.module_from_spec(module_spec)
    monkeypatch.setitem(sys.modules, "message_module", message_module)
    module_spec.loader.exec_module(message_module)
    message_module.select = select

    # Classes read from the freshly loaded message module.
    message_class_names = (
        "SessionMessage",
        "TextComponent",
        "ImageComponent",
        "EmojiComponent",
        "VoiceComponent",
        "AtComponent",
        "ReplyComponent",
        "ForwardNodeComponent",
    )
    exported = {name: getattr(message_module, name) for name in message_class_names}

    # These two come from the data-model module, which must already be in
    # sys.modules at this point (presumably imported by message.py - verify).
    data_model_module = sys.modules["src.common.data_models.message_component_data_model"]
    exported["MessageSequence"] = data_model_module.MessageSequence
    exported["ForwardComponent"] = data_model_module.ForwardComponent

    globals().update(exported)
    return message_module
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
    """Return a fixed fake short ID made of ``length`` ``"X"`` characters.

    ``original_id`` and ``salt`` are ignored; only ``length`` affects the result.
    """
    fake_short_id = length * "X"
    return fake_short_id
def dummy_is_bot_self(platform, user_id: str) -> bool:
    """Report whether ``user_id`` is the literal ``"bot_self"``; ``platform`` is ignored."""
    bot_user_id = "bot_self"
    return user_id == bot_user_id
def load_utils_via_file(monkeypatch):
    """Load ``src/common/utils/utils_message.py`` straight from its file.

    Beyond the shared stubs, this installs a fake ``math_utils`` module and
    makes sure the parent packages exist in ``sys.modules`` so the relative
    imports inside ``utils_message`` resolve.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture.

    Returns:
        The loaded module, with ``is_bot_self`` replaced by the local dummy.
    """
    setup_mocks(monkeypatch)

    # Fake math_utils, consumed by ``from .math_utils import number_to_short_id``.
    fake_math_utils = ModuleType("src.common.utils.math_utils")
    fake_math_utils.number_to_short_id = dummy_number_to_short_id
    timestamp_formats = {
        "NORMAL": "%Y-%m-%d %H:%M:%S",
        "NORMAL_NO_YMD": "%H:%M:%S",
        "RELATIVE": "relative",
    }
    fake_math_utils.TimestampMode = type("TimestampMode", (), timestamp_formats)
    # Always renders the same fixed timestamp string.
    fake_math_utils.translate_timestamp_to_human_readable = lambda timestamp, mode: "2024-01-01 12:00:00"
    monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", fake_math_utils)

    # Create any missing parent package so relative imports can resolve.
    for package_name in ("src", "src.common", "src.common.utils"):
        if package_name in sys.modules:
            continue
        package_stub = ModuleType(package_name)
        package_stub.__path__ = []
        monkeypatch.setitem(sys.modules, package_name, package_stub)

    repo_root = Path(__file__).parent.parent.parent
    source_path = repo_root / "src" / "common" / "utils" / "utils_message.py"
    module_spec = importlib.util.spec_from_file_location("src.common.utils.utils_message", source_path)
    utils_module = importlib.util.module_from_spec(module_spec)
    # Setting the package is what lets the relative imports work.
    utils_module.__package__ = "src.common.utils"
    for registered_name in ("src.common.utils.utils_message", "message_utils_module"):
        monkeypatch.setitem(sys.modules, registered_name, utils_module)
    module_spec.loader.exec_module(utils_module)

    utils_module.is_bot_self = dummy_is_bot_self
    return utils_module
@pytest.mark.asyncio
async def test_message_utils(monkeypatch):
    """Smoke test: both modules load from file without raising."""
    for loader in (load_message_via_file, load_utils_via_file):
        loader(monkeypatch)
@pytest.mark.asyncio
async def test_build_readable_message_basic(monkeypatch):
    """Basic case: a single text message rendered with line numbers."""
    load_message_via_file(monkeypatch)
    message_utils = load_utils_via_file(monkeypatch).MessageUtils

    message = SessionMessage("m1", datetime.now(), platform="test")
    message.platform = "test"
    message.session_id = "s_test"
    sender = UserInfo(user_id="u1", user_nickname="Alice")
    message.message_info = MessageInfo(user_info=sender)
    message.raw_message = MessageSequence([TextComponent("Hello world")])

    rendered, name_mapping, _ = await message_utils.build_readable_message(
        [message],
        anonymize=False,
        show_lineno=True,
    )

    assert "[1] Alice说Hello world" in rendered
    # No anonymization was requested, so the mapping stays empty.
    assert name_mapping == {}
@pytest.mark.asyncio
async def test_build_readable_message_anonymize(monkeypatch):
    """Anonymization case: check the returned text and the name mapping."""
    load_message_via_file(monkeypatch)
    message_utils = load_utils_via_file(monkeypatch).MessageUtils

    message = SessionMessage("m2", datetime.now(), platform="test")
    message.session_id = "s_test"
    sender = UserInfo(user_id="u42", user_nickname="Bob")
    message.message_info = MessageInfo(user_info=sender)
    message.raw_message = MessageSequence([TextComponent("Secret text")])

    rendered, name_mapping, _ = await message_utils.build_readable_message(
        [message],
        anonymize=True,
        show_lineno=False,
    )

    # The text shows the stubbed short ID ("X" * 6) in place of the sender name.
    assert "XXXXXX说" in rendered
    # The mapping is keyed by user id and keeps the alias first, then the
    # original nickname.
    assert "u42" in name_mapping
    alias_entry = name_mapping["u42"]
    assert alias_entry[0] == "XXXXXX"
    assert alias_entry[1] == "Bob"
@pytest.mark.asyncio
async def test_build_readable_message_replace_bot(monkeypatch):
    """Bot-name replacement: a ``bot_self`` sender is shown as ``target_bot_name``."""
    load_message_via_file(monkeypatch)
    message_utils = load_utils_via_file(monkeypatch).MessageUtils

    message = SessionMessage("m3", datetime.now(), platform="test")
    message.session_id = "s_test"
    bot_sender = UserInfo(user_id="bot_self", user_nickname="SomeBot")
    message.message_info = MessageInfo(user_info=bot_sender)
    message.raw_message = MessageSequence([TextComponent("ping")])

    rendered, name_mapping, _ = await message_utils.build_readable_message(
        [message],
        replace_bot_name=True,
        target_bot_name="MAIBot",
    )

    assert "MAIBot说ping" in rendered
@pytest.mark.asyncio
async def test_build_readable_message_image_extraction(monkeypatch):
    """Picture extraction: with ``extract_pictures=True`` the text carries an image placeholder.

    NOTE(review): the third return value is discarded here, so whatever it
    holds is not verified by this test.
    """
    load_message_via_file(monkeypatch)
    message_utils = load_utils_via_file(monkeypatch).MessageUtils

    # Build a message whose only component is an image.
    image = ImageComponent(binary_hash="h", binary_data=b"\x01\x02", content="Img")
    message = SessionMessage("mi1", datetime.now(), platform="test")
    message.session_id = "s_img"
    message.raw_message = MessageSequence([image])
    sender = UserInfo(user_id="ui_img", user_nickname="ImgUser")
    message.message_info = MessageInfo(sender)

    rendered, name_mapping, _ = await message_utils.build_readable_message(
        [message],
        extract_pictures=True,
    )

    # The rendered text should contain the picture placeholder.
    assert "图片1" in rendered
    # Only the type of the mapping is checked, not its contents.
    assert isinstance(name_mapping, dict)
@pytest.mark.asyncio
async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(monkeypatch):
    """Combined case: two messages with anonymization, bot-name replacement and line numbers."""
    load_message_via_file(monkeypatch)
    message_utils = load_utils_via_file(monkeypatch).MessageUtils

    # One message from a regular user, one from the bot itself.
    user_message = SessionMessage("m4", datetime.now(), platform="test")
    user_message.session_id = "s_comb"
    bot_message = SessionMessage("m5", datetime.now(), platform="test")
    bot_message.session_id = "s_comb"

    user_message.message_info = MessageInfo(UserInfo(user_id="u_comb", user_nickname="Charlie"))
    bot_message.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot"))
    user_message.raw_message = MessageSequence([TextComponent("Hi")])
    bot_message.raw_message = MessageSequence([TextComponent("Hello")])

    options = {
        "anonymize": True,
        "replace_bot_name": True,
        "target_bot_name": "MAIBot",
        "show_lineno": True,
    }
    rendered, name_mapping, _ = await message_utils.build_readable_message(
        [user_message, bot_message],
        **options,
    )

    # Rendered text: the user is anonymized and the bot is renamed.
    assert "[1] XXXXXX说Hi" in rendered
    assert "[2] MAIBot说Hello" in rendered
    # Mapping: the regular user is present under the stubbed alias.
    assert "u_comb" in name_mapping
    assert name_mapping["u_comb"][0] == "XXXXXX"
@pytest.mark.asyncio
async def test_build_readable_message_with_at(monkeypatch):
    """Message with an @ component: the mentioned user is anonymized as well."""
    load_message_via_file(monkeypatch)
    message_utils = load_utils_via_file(monkeypatch).MessageUtils

    # Build a message whose only component is an @ mention.
    mention = AtComponent(target_user_id="u_at", target_user_nickname="AtUser")
    message = SessionMessage("m_at", datetime.now(), platform="test")
    message.session_id = "s_at"
    message.raw_message = MessageSequence([mention])
    sender = UserInfo(user_id="u_main", user_nickname="MainUser")
    message.message_info = MessageInfo(sender)

    rendered, name_mapping, _ = await message_utils.build_readable_message(
        [message],
        anonymize=True,
        replace_bot_name=True,
        target_bot_name="MAIBot",
    )

    # Both the sender and the mentioned user must appear anonymized.
    assert "XXXXXX说" in rendered  # the sender is anonymized
    assert "XXXXXX说@XXXXXX" in rendered  # the @ target is anonymized too

View File

@@ -0,0 +1,117 @@
"""统计模块数据库会话行为测试。"""
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime, timedelta
from types import ModuleType
from typing import Any, Callable, Iterator
import sys
import pytest
from src.chat.utils import statistic
class _DummyResult:
    """Fake query result object that never holds any rows."""

    def all(self) -> list[Any]:
        """Return an empty result set.

        Returns:
            list[Any]: A new empty list.
        """
        rows: list[Any] = []
        return rows
class _DummySession:
    """Fake database session whose queries always come back empty."""

    def exec(self, statement: Any) -> _DummyResult:
        """Pretend to run a query and hand back an empty result.

        Args:
            statement: The query that would be executed; it is ignored.

        Returns:
            _DummyResult: A new empty result object.
        """
        # The statement is never inspected.
        empty_result = _DummyResult()
        return empty_result
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
    """Build a fake session factory that records its ``auto_commit`` argument.

    Args:
        calls: List that receives the ``auto_commit`` value of every call.

    Returns:
        Callable[[bool], Iterator[_DummySession]]: A context-manager factory
        that can replace ``get_db_session``.
    """

    @contextmanager
    def _recording_session(auto_commit: bool = True) -> Iterator[_DummySession]:
        """Record the session argument, then yield a fake session.

        Args:
            auto_commit: Whether auto-commit is requested.

        Yields:
            Iterator[_DummySession]: A fake session object.
        """
        calls.append(auto_commit)
        fake_session = _DummySession()
        yield fake_session

    return _recording_session
def _build_statistic_task() -> statistic.StatisticOutputTask:
    """Create a bare statistic task for testing.

    Returns:
        statistic.StatisticOutputTask: An instance created without running
        ``__init__``, with ``name_mapping`` set to an empty dict.
    """
    task_class = statistic.StatisticOutputTask
    bare_task = task_class.__new__(task_class)
    bare_task.name_mapping = {}
    return bare_task
def _is_bot_self(platform: str, user_id: str) -> bool:
    """Answer "is this the bot?" with a constant ``False``.

    Args:
        platform: Platform name; ignored.
        user_id: User ID; ignored.

    Returns:
        bool: Always ``False``.
    """
    # Neither argument influences the answer.
    is_bot = False
    return is_bot
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that the collectors never invoke the patched ``get_db_session``.

    All four ``fetch_*_since`` helpers are replaced with stubs that return an
    empty list. After running the three collectors, the test asserts that the
    recording session factory was never called.

    NOTE(review): the test name and its original description speak of read
    queries disabling auto-commit so objects are not expired after the session
    exits, but the assertion only checks that no session was opened - confirm
    this is the intended check.
    """
    calls: list[bool] = []
    now = datetime.now()
    task = _build_statistic_task()

    monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))

    fake_utils = ModuleType("src.chat.utils.utils")
    fake_utils.is_bot_self = _is_bot_self
    monkeypatch.setitem(sys.modules, "src.chat.utils.utils", fake_utils)

    fetcher_names = (
        "fetch_online_time_since",
        "fetch_model_usage_since",
        "fetch_messages_since",
        "fetch_tool_records_since",
    )
    for fetcher_name in fetcher_names:
        monkeypatch.setattr(statistic, fetcher_name, lambda query_start_time: [])

    one_hour_ago = now - timedelta(hours=1)
    task._collect_message_count_for_period([("last_hour", one_hour_ago)])
    task._collect_interval_data(now, hours=1, interval_minutes=60)
    task._collect_metrics_interval_data(now, hours=1, interval_hours=1)

    assert calls == []

View File

@@ -0,0 +1,131 @@
from pathlib import Path
import json
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.model_client.base_client import ResponseRequest
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
from src.llm_models.request_snapshot import (
attach_request_snapshot,
deserialize_messages_snapshot,
format_request_snapshot_log_info,
save_failed_request_snapshot,
serialize_messages_snapshot,
serialize_response_request_snapshot,
)
from src.llm_models import request_snapshot
def _build_api_provider() -> APIProvider:
    """Return the provider fixture; its API key is the string the tests check is never written out."""
    provider_settings = {
        "api_key": "secret-token",
        "base_url": "https://example.com/v1",
        "name": "test-provider",
    }
    return APIProvider(**provider_settings)
def _build_model_info() -> ModelInfo:
    """Return the model fixture that refers to the ``test-provider`` provider."""
    model_settings = {
        "api_provider": "test-provider",
        "model_identifier": "demo-model",
        "name": "demo-model",
    }
    return ModelInfo(**model_settings)
def _build_response_request() -> ResponseRequest:
    """Build a request fixture with a user, an assistant and a tool message.

    The user message carries text plus a PNG image, the assistant message
    carries one tool call, and the tool message answers that call.
    """
    search_call = ToolCall(
        args={"query": "MaiBot"},
        call_id="call_1",
        func_name="search_web",
        extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
    )

    user_message = (
        MessageBuilder()
        .set_role(RoleType.User)
        .add_text_content("你好")
        .add_image_content("png", "ZmFrZQ==")
        .build()
    )
    assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([search_call]).build()
    tool_message = (
        MessageBuilder()
        .set_role(RoleType.Tool)
        .set_tool_call_id("call_1")
        .set_tool_name("search_web")
        .add_text_content('{"ok": true}')
        .build()
    )
    conversation = [user_message, assistant_message, tool_message]

    return ResponseRequest(
        extra_params={"trace_id": "trace-123"},
        max_tokens=256,
        message_list=conversation,
        model_info=_build_model_info(),
        response_format=RespFormat(RespFormatType.JSON_OBJ),
        temperature=0.2,
        tool_options=[ToolOption(name="search_web", description="搜索网页")],
    )
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
    """Serializing then deserializing messages keeps roles, content and tool data."""
    original_messages = _build_response_request().message_list
    restored = deserialize_messages_snapshot(serialize_messages_snapshot(original_messages))

    assert len(restored) == 3
    user_message, assistant_message, tool_message = restored[0], restored[1], restored[2]

    # User message: text plus the image part.
    assert user_message.role == RoleType.User
    assert user_message.get_text_content() == "你好"
    assert user_message.parts[1].image_format == "png"

    # Assistant message: the tool call survives with all its fields.
    assert assistant_message.role == RoleType.Assistant
    assert assistant_message.tool_calls is not None
    restored_call = assistant_message.tool_calls[0]
    assert restored_call.func_name == "search_web"
    assert restored_call.args == {"query": "MaiBot"}
    assert restored_call.extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}

    # Tool message: still linked to the call it answers.
    assert tool_message.role == RoleType.Tool
    assert tool_message.tool_call_id == "call_1"
    assert tool_message.tool_name == "search_web"
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
    """A saved failure snapshot carries replay info and never contains the API key."""
    monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
    request = _build_response_request()
    provider = _build_api_provider()

    snapshot_path = save_failed_request_snapshot(
        api_provider=provider,
        client_type="openai",
        error=RuntimeError("boom"),
        internal_request=serialize_response_request_snapshot(request),
        model_info=request.model_info,
        operation="chat.completions.create",
        provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
    )
    assert snapshot_path is not None

    snapshot_text = snapshot_path.read_text(encoding="utf-8")
    payload = json.loads(snapshot_text)
    replay_entry = payload["replay"]

    assert payload["internal_request"]["request_kind"] == "response"
    assert payload["api_provider"]["name"] == "test-provider"
    assert replay_entry["file_uri"] == snapshot_path.as_uri()
    assert str(snapshot_path) in replay_entry["command"]
    # The provider's API key must not appear anywhere in the file.
    assert "secret-token" not in snapshot_text
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
    """Log info for an exception with an attached snapshot names its path, URI and replay command."""
    monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
    request = _build_response_request()

    snapshot_path = save_failed_request_snapshot(
        api_provider=_build_api_provider(),
        client_type="openai",
        error=ValueError("invalid"),
        internal_request=serialize_response_request_snapshot(request),
        model_info=request.model_info,
        operation="chat.completions.create",
        provider_request={"request_kwargs": {"messages": []}},
    )
    assert snapshot_path is not None

    # Attach the snapshot to a different exception than the one that was saved.
    wrapped_error = RuntimeError("wrapped")
    attach_request_snapshot(wrapped_error, snapshot_path)
    log_info = format_request_snapshot_log_info(wrapped_error)

    expected_fragments = (
        str(snapshot_path),
        snapshot_path.as_uri(),
        "uv run python scripts/replay_llm_request.py",
    )
    for fragment in expected_fragments:
        assert fragment in log_info

View File

@@ -0,0 +1,42 @@
from types import SimpleNamespace
from src.chat.message_receive.chat_manager import ChatManager
from src.common.utils.utils_session import SessionUtils
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
    """Session IDs are stable for equal inputs and change when account or scope is added."""
    compute = SessionUtils.calculate_session_id

    plain_id = compute("qq", user_id="42")
    plain_id_again = compute("qq", user_id="42")
    with_account_id = compute("qq", user_id="42", account_id="123")
    with_account_and_scope_id = compute("qq", user_id="42", account_id="123", scope="main")

    # Same inputs give the same ID.
    assert plain_id == plain_id_again
    # Adding an account, and then a scope, each gives a different ID.
    assert with_account_id != plain_id
    assert with_account_and_scope_id != with_account_id
def test_chat_manager_register_message_uses_route_metadata() -> None:
    """``register_message`` derives the session ID from the account and scope in ``additional_config``."""
    route_metadata = {
        "platform_io_account_id": "123",
        "platform_io_scope": "main",
    }
    message_info = SimpleNamespace(
        user_info=SimpleNamespace(user_id="42"),
        group_info=SimpleNamespace(group_id="1000"),
        additional_config=route_metadata,
    )
    # The message starts with an empty session ID.
    message = SimpleNamespace(platform="qq", session_id="", message_info=message_info)

    chat_manager = ChatManager()
    chat_manager.register_message(message)

    expected_session_id = SessionUtils.calculate_session_id(
        "qq",
        user_id="42",
        group_id="1000",
        account_id="123",
        scope="main",
    )
    assert message.session_id == expected_session_id
    assert chat_manager.last_messages[message.session_id] is message