chore: import deployable mai-bot source tree
This commit is contained in:
367
pytests/utils_test/message_utils_test.py
Normal file
367
pytests/utils_test/message_utils_test.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
from src.chat.message_receive.message import (
|
||||
SessionMessage,
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
AtComponent,
|
||||
)
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def __init__(self) -> None:
|
||||
self.logging_record = []
|
||||
|
||||
def debug(self, msg):
|
||||
print(f"DEBUG: {msg}")
|
||||
self.logging_record.append(f"DEBUG: {msg}")
|
||||
|
||||
def info(self, msg):
|
||||
print(f"INFO: {msg}")
|
||||
self.logging_record.append(f"INFO: {msg}")
|
||||
|
||||
def warning(self, msg):
|
||||
print(f"WARNING: {msg}")
|
||||
self.logging_record.append(f"WARNING: {msg}")
|
||||
|
||||
def error(self, msg):
|
||||
print(f"ERROR: {msg}")
|
||||
self.logging_record.append(f"ERROR: {msg}")
|
||||
|
||||
def critical(self, msg):
|
||||
print(f"CRITICAL: {msg}")
|
||||
self.logging_record.append(f"CRITICAL: {msg}")
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
return DummyLogger()
|
||||
|
||||
|
||||
class DummyDBSession:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def exec(self, statement):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
def commit(self):
|
||||
pass
|
||||
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
|
||||
def get_db_session():
|
||||
return DummyDBSession()
|
||||
|
||||
|
||||
def get_manual_db_session():
|
||||
return DummyDBSession()
|
||||
|
||||
|
||||
class DummySelect:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def filter_by(self, **kwargs):
|
||||
return self
|
||||
|
||||
def where(self, condition):
|
||||
return self
|
||||
|
||||
def limit(self, n):
|
||||
return self
|
||||
|
||||
|
||||
def select(model):
|
||||
return DummySelect(model)
|
||||
|
||||
|
||||
async def dummy_get_voice_text(binary_data):
|
||||
return None # 可以根据需要返回模拟的文本结果
|
||||
|
||||
|
||||
class DummyPersonUtils:
|
||||
@staticmethod
|
||||
def get_person_info_by_user_id_and_platform(user_id, platform):
|
||||
return None # 可以根据需要返回模拟的用户信息
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
class MessageReceiveConfig:
|
||||
ban_words = set()
|
||||
ban_msgs_regex = set()
|
||||
|
||||
message_receive = MessageReceiveConfig()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
user_cardname: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
group_id: str
|
||||
group_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageInfo:
|
||||
user_info: UserInfo
|
||||
group_info: Optional[GroupInfo] = None
|
||||
additional_config: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
def setup_mocks(monkeypatch):
|
||||
def _stub_module(name: str) -> ModuleType:
|
||||
module = ModuleType(name)
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
return module
|
||||
|
||||
# src.common.logger
|
||||
logger_mod = _stub_module("src.common.logger")
|
||||
# Mock the logger
|
||||
logger_mod.get_logger = get_logger
|
||||
|
||||
db_mod = _stub_module("src.common.database.database")
|
||||
db_mod.get_db_session = get_db_session
|
||||
db_mod.get_manual_db_session = get_manual_db_session
|
||||
|
||||
db_model_mod = _stub_module("src.common.database.database_model")
|
||||
db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
emoji_manager_mod = _stub_module("src.emoji_system.emoji_manager")
|
||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
||||
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
|
||||
voice_utils_mod.get_voice_text = dummy_get_voice_text
|
||||
|
||||
person_utils_mod = _stub_module("src.common.utils.utils_person")
|
||||
person_utils_mod.PersonUtils = DummyPersonUtils
|
||||
|
||||
config_mod = _stub_module("src.config.config")
|
||||
config_mod.global_config = DummyConfig()
|
||||
|
||||
|
||||
def load_message_via_file(monkeypatch):
|
||||
setup_mocks(monkeypatch)
|
||||
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
|
||||
spec = importlib.util.spec_from_file_location("message", file_path)
|
||||
message_module = importlib.util.module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, "message_module", message_module)
|
||||
spec.loader.exec_module(message_module)
|
||||
message_module.select = select
|
||||
SessionMessageClass = message_module.SessionMessage
|
||||
TextComponentClass = message_module.TextComponent
|
||||
ImageComponentClass = message_module.ImageComponent
|
||||
EmojiComponentClass = message_module.EmojiComponent
|
||||
VoiceComponentClass = message_module.VoiceComponent
|
||||
AtComponentClass = message_module.AtComponent
|
||||
ReplyComponentClass = message_module.ReplyComponent
|
||||
ForwardNodeComponentClass = message_module.ForwardNodeComponent
|
||||
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
|
||||
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
|
||||
globals()["SessionMessage"] = SessionMessageClass
|
||||
globals()["TextComponent"] = TextComponentClass
|
||||
globals()["ImageComponent"] = ImageComponentClass
|
||||
globals()["EmojiComponent"] = EmojiComponentClass
|
||||
globals()["VoiceComponent"] = VoiceComponentClass
|
||||
globals()["AtComponent"] = AtComponentClass
|
||||
globals()["ReplyComponent"] = ReplyComponentClass
|
||||
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
|
||||
globals()["MessageSequence"] = MessageSequenceClass
|
||||
globals()["ForwardComponent"] = ForwardComponentClass
|
||||
return message_module
|
||||
|
||||
|
||||
def dummy_number_to_short_id(original_id: int, salt: str, length: int = 6) -> str:
|
||||
return "X" * length # 返回固定的字符串,长度由参数决定,模拟生成短ID的行为
|
||||
|
||||
|
||||
def dummy_is_bot_self(platform, user_id: str) -> bool:
|
||||
return user_id == "bot_self"
|
||||
|
||||
|
||||
def load_utils_via_file(monkeypatch):
|
||||
setup_mocks(monkeypatch)
|
||||
|
||||
# Mock math_utils 模块,供 from .math_utils import number_to_short_id 使用
|
||||
math_utils_mod = ModuleType("src.common.utils.math_utils")
|
||||
math_utils_mod.number_to_short_id = dummy_number_to_short_id
|
||||
math_utils_mod.TimestampMode = type(
|
||||
"TimestampMode", (), {"NORMAL": "%Y-%m-%d %H:%M:%S", "NORMAL_NO_YMD": "%H:%M:%S", "RELATIVE": "relative"}
|
||||
)
|
||||
math_utils_mod.translate_timestamp_to_human_readable = lambda timestamp, mode: (
|
||||
"2024-01-01 12:00:00"
|
||||
) # 返回固定的时间字符串
|
||||
monkeypatch.setitem(sys.modules, "src.common.utils.math_utils", math_utils_mod)
|
||||
|
||||
# 确保包层级模块存在于 sys.modules 中,使相对导入能正确解析
|
||||
for pkg_name in ["src", "src.common", "src.common.utils"]:
|
||||
if pkg_name not in sys.modules:
|
||||
pkg_mod = ModuleType(pkg_name)
|
||||
pkg_mod.__path__ = []
|
||||
monkeypatch.setitem(sys.modules, pkg_name, pkg_mod)
|
||||
|
||||
file_path = Path(__file__).parent.parent.parent / "src" / "common" / "utils" / "utils_message.py"
|
||||
spec = importlib.util.spec_from_file_location("src.common.utils.utils_message", file_path)
|
||||
utils_module = importlib.util.module_from_spec(spec)
|
||||
utils_module.__package__ = "src.common.utils" # 设置包,使相对导入生效
|
||||
monkeypatch.setitem(sys.modules, "src.common.utils.utils_message", utils_module)
|
||||
monkeypatch.setitem(sys.modules, "message_utils_module", utils_module)
|
||||
spec.loader.exec_module(utils_module)
|
||||
utils_module.is_bot_self = dummy_is_bot_self
|
||||
return utils_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_utils(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
load_utils_via_file(monkeypatch)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_readable_message_basic(monkeypatch):
|
||||
"""基础用例:单条消息,显示行号"""
|
||||
load_message_via_file(monkeypatch)
|
||||
utils_module = load_utils_via_file(monkeypatch)
|
||||
MessageUtils = utils_module.MessageUtils
|
||||
|
||||
msg = SessionMessage("m1", datetime.now(), platform="test")
|
||||
msg.platform = "test"
|
||||
msg.session_id = "s_test"
|
||||
user_info = UserInfo(user_id="u1", user_nickname="Alice")
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.raw_message = MessageSequence([TextComponent("Hello world")])
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=False, show_lineno=True)
|
||||
assert "[1] Alice说:Hello world" in text
|
||||
assert mapping == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_readable_message_anonymize(monkeypatch):
|
||||
"""匿名化用例:验证 mapping 和返回文本"""
|
||||
load_message_via_file(monkeypatch)
|
||||
utils_module = load_utils_via_file(monkeypatch)
|
||||
MessageUtils = utils_module.MessageUtils
|
||||
|
||||
msg = SessionMessage("m2", datetime.now(), platform="test")
|
||||
msg.session_id = "s_test"
|
||||
user_info = UserInfo(user_id="u42", user_nickname="Bob")
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.raw_message = MessageSequence([TextComponent("Secret text")])
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], anonymize=True, show_lineno=False)
|
||||
# 根据实现,original_name 为 user_nickname,因此文本中应包含原始名称
|
||||
assert "XXXXXX说:" in text
|
||||
assert "u42" in mapping
|
||||
assert mapping["u42"][0] == "XXXXXX"
|
||||
assert mapping["u42"][1] == "Bob"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_readable_message_replace_bot(monkeypatch):
|
||||
"""替换机器人名用例:当 user_id 为 bot_self 时应被替换为 target_bot_name"""
|
||||
load_message_via_file(monkeypatch)
|
||||
utils_module = load_utils_via_file(monkeypatch)
|
||||
MessageUtils = utils_module.MessageUtils
|
||||
|
||||
msg = SessionMessage("m3", datetime.now(), platform="test")
|
||||
msg.session_id = "s_test"
|
||||
user_info = UserInfo(user_id="bot_self", user_nickname="SomeBot")
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.raw_message = MessageSequence([TextComponent("ping")])
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], replace_bot_name=True, target_bot_name="MAIBot")
|
||||
assert "MAIBot说:ping" in text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_readable_message_image_extraction(monkeypatch):
|
||||
"""图片提取:验证 extract_pictures 为 True 时,文本中包含图片占位及 img_map 内容被返回"""
|
||||
load_message_via_file(monkeypatch)
|
||||
utils_module = load_utils_via_file(monkeypatch)
|
||||
MessageUtils = utils_module.MessageUtils
|
||||
|
||||
# 构建包含图片组件的消息
|
||||
img = ImageComponent(binary_hash="h", binary_data=b"\x01\x02", content="Img")
|
||||
msg = SessionMessage("mi1", datetime.now(), platform="test")
|
||||
msg.session_id = "s_img"
|
||||
msg.raw_message = MessageSequence([img])
|
||||
msg.message_info = MessageInfo(UserInfo(user_id="ui_img", user_nickname="ImgUser"))
|
||||
text, mapping, _ = await MessageUtils.build_readable_message([msg], extract_pictures=True)
|
||||
# 应包含图片描述占位
|
||||
assert "图片1" in text
|
||||
# mapping 不为空(匿名化未开启则为空)
|
||||
assert isinstance(mapping, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_readable_message_anonymize_and_replace_bot_name_and_lineno(monkeypatch):
|
||||
"""组合用例:多个消息同时包含匿名化、机器人名称替换"""
|
||||
load_message_via_file(monkeypatch)
|
||||
utils_module = load_utils_via_file(monkeypatch)
|
||||
MessageUtils = utils_module.MessageUtils
|
||||
# 构建多个消息
|
||||
msg1 = SessionMessage("m4", datetime.now(), platform="test")
|
||||
msg1.session_id = "s_comb"
|
||||
msg2 = SessionMessage("m5", datetime.now(), platform="test")
|
||||
msg2.session_id = "s_comb"
|
||||
msg1.message_info = MessageInfo(UserInfo(user_id="u_comb", user_nickname="Charlie"))
|
||||
msg2.message_info = MessageInfo(UserInfo(user_id="bot_self", user_nickname="SomeBot"))
|
||||
msg1.raw_message = MessageSequence([TextComponent("Hi")])
|
||||
msg2.raw_message = MessageSequence([TextComponent("Hello")])
|
||||
text, mapping, _ = await MessageUtils.build_readable_message(
|
||||
[msg1, msg2],
|
||||
anonymize=True,
|
||||
replace_bot_name=True,
|
||||
target_bot_name="MAIBot",
|
||||
show_lineno=True,
|
||||
)
|
||||
# 验证文本内容
|
||||
assert "[1] XXXXXX说:Hi" in text
|
||||
assert "[2] MAIBot说:Hello" in text
|
||||
# 验证 mapping 内容
|
||||
assert "u_comb" in mapping
|
||||
assert mapping["u_comb"][0] == "XXXXXX"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_readable_message_with_at(monkeypatch):
|
||||
"""包含@组件的消息:验证@组件中的用户信息也被匿名化和替换"""
|
||||
load_message_via_file(monkeypatch)
|
||||
utils_module = load_utils_via_file(monkeypatch)
|
||||
MessageUtils = utils_module.MessageUtils
|
||||
|
||||
# 构建包含回复组件的消息
|
||||
at_comp = AtComponent(target_user_id="u_at", target_user_nickname="AtUser")
|
||||
msg = SessionMessage("m_at", datetime.now(), platform="test")
|
||||
msg.session_id = "s_at"
|
||||
msg.raw_message = MessageSequence([at_comp])
|
||||
msg.message_info = MessageInfo(UserInfo(user_id="u_main", user_nickname="MainUser"))
|
||||
text, mapping, _ = await MessageUtils.build_readable_message(
|
||||
[msg], anonymize=True, replace_bot_name=True, target_bot_name="MAIBot"
|
||||
)
|
||||
# 验证主消息和@组件中的用户信息都被处理
|
||||
assert "XXXXXX说:" in text # 主消息用户被匿名化
|
||||
assert "XXXXXX说:@XXXXXX" in text # @组件用户被匿名化
|
||||
117
pytests/utils_test/statistic_test.py
Normal file
117
pytests/utils_test/statistic_test.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""统计模块数据库会话行为测试。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from src.chat.utils import statistic
|
||||
|
||||
|
||||
class _DummyResult:
|
||||
"""模拟 SQLModel 查询结果对象。"""
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
"""返回空结果集。
|
||||
|
||||
Returns:
|
||||
list[Any]: 空列表。
|
||||
"""
|
||||
return []
|
||||
|
||||
|
||||
class _DummySession:
|
||||
"""模拟数据库 Session。"""
|
||||
|
||||
def exec(self, statement: Any) -> _DummyResult:
|
||||
"""执行查询语句并返回空结果。
|
||||
|
||||
Args:
|
||||
statement: 待执行的查询语句。
|
||||
|
||||
Returns:
|
||||
_DummyResult: 空结果对象。
|
||||
"""
|
||||
del statement
|
||||
return _DummyResult()
|
||||
|
||||
|
||||
def _build_fake_get_db_session(calls: list[bool]) -> Callable[[bool], Iterator[_DummySession]]:
|
||||
"""构造一个记录 auto_commit 参数的假会话工厂。
|
||||
|
||||
Args:
|
||||
calls: 用于记录每次调用 auto_commit 参数的列表。
|
||||
|
||||
Returns:
|
||||
Callable[[bool], Iterator[_DummySession]]: 可替换 `get_db_session` 的上下文管理器工厂。
|
||||
"""
|
||||
|
||||
@contextmanager
|
||||
def _fake_get_db_session(auto_commit: bool = True) -> Iterator[_DummySession]:
|
||||
"""记录会话参数并返回假 Session。
|
||||
|
||||
Args:
|
||||
auto_commit: 是否启用自动提交。
|
||||
|
||||
Yields:
|
||||
Iterator[_DummySession]: 假 Session 对象。
|
||||
"""
|
||||
calls.append(auto_commit)
|
||||
yield _DummySession()
|
||||
|
||||
return _fake_get_db_session
|
||||
|
||||
|
||||
def _build_statistic_task() -> statistic.StatisticOutputTask:
|
||||
"""构造一个最小可用的统计任务实例。
|
||||
|
||||
Returns:
|
||||
statistic.StatisticOutputTask: 跳过 `__init__` 的测试实例。
|
||||
"""
|
||||
task = statistic.StatisticOutputTask.__new__(statistic.StatisticOutputTask)
|
||||
task.name_mapping = {}
|
||||
return task
|
||||
|
||||
|
||||
def _is_bot_self(platform: str, user_id: str) -> bool:
|
||||
"""返回固定的非机器人身份判断结果。
|
||||
|
||||
Args:
|
||||
platform: 平台名称。
|
||||
user_id: 用户 ID。
|
||||
|
||||
Returns:
|
||||
bool: 始终返回 ``False``。
|
||||
"""
|
||||
del platform
|
||||
del user_id
|
||||
return False
|
||||
|
||||
|
||||
def test_statistic_read_queries_disable_auto_commit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""统计模块的纯读查询应关闭自动提交,避免 Session 退出后对象被 expire。"""
|
||||
calls: list[bool] = []
|
||||
now = datetime.now()
|
||||
task = _build_statistic_task()
|
||||
|
||||
monkeypatch.setattr(statistic, "get_db_session", _build_fake_get_db_session(calls))
|
||||
|
||||
utils_module = ModuleType("src.chat.utils.utils")
|
||||
utils_module.is_bot_self = _is_bot_self
|
||||
monkeypatch.setitem(sys.modules, "src.chat.utils.utils", utils_module)
|
||||
monkeypatch.setattr(statistic, "fetch_online_time_since", lambda query_start_time: [])
|
||||
monkeypatch.setattr(statistic, "fetch_model_usage_since", lambda query_start_time: [])
|
||||
monkeypatch.setattr(statistic, "fetch_messages_since", lambda query_start_time: [])
|
||||
monkeypatch.setattr(statistic, "fetch_tool_records_since", lambda query_start_time: [])
|
||||
|
||||
task._collect_message_count_for_period([("last_hour", now - timedelta(hours=1))])
|
||||
task._collect_interval_data(now, hours=1, interval_minutes=60)
|
||||
task._collect_metrics_interval_data(now, hours=1, interval_hours=1)
|
||||
|
||||
assert calls == []
|
||||
131
pytests/utils_test/test_request_snapshot.py
Normal file
131
pytests/utils_test/test_request_snapshot.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from pathlib import Path
|
||||
|
||||
import json
|
||||
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.model_client.base_client import ResponseRequest
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||||
from src.llm_models.request_snapshot import (
|
||||
attach_request_snapshot,
|
||||
deserialize_messages_snapshot,
|
||||
format_request_snapshot_log_info,
|
||||
save_failed_request_snapshot,
|
||||
serialize_messages_snapshot,
|
||||
serialize_response_request_snapshot,
|
||||
)
|
||||
from src.llm_models import request_snapshot
|
||||
|
||||
|
||||
def _build_api_provider() -> APIProvider:
|
||||
return APIProvider(
|
||||
api_key="secret-token",
|
||||
base_url="https://example.com/v1",
|
||||
name="test-provider",
|
||||
)
|
||||
|
||||
|
||||
def _build_model_info() -> ModelInfo:
|
||||
return ModelInfo(
|
||||
api_provider="test-provider",
|
||||
model_identifier="demo-model",
|
||||
name="demo-model",
|
||||
)
|
||||
|
||||
|
||||
def _build_response_request() -> ResponseRequest:
|
||||
tool_call = ToolCall(
|
||||
args={"query": "MaiBot"},
|
||||
call_id="call_1",
|
||||
func_name="search_web",
|
||||
extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
|
||||
)
|
||||
message_list = [
|
||||
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
|
||||
MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build(),
|
||||
MessageBuilder()
|
||||
.set_role(RoleType.Tool)
|
||||
.set_tool_call_id("call_1")
|
||||
.set_tool_name("search_web")
|
||||
.add_text_content('{"ok": true}')
|
||||
.build(),
|
||||
]
|
||||
return ResponseRequest(
|
||||
extra_params={"trace_id": "trace-123"},
|
||||
max_tokens=256,
|
||||
message_list=message_list,
|
||||
model_info=_build_model_info(),
|
||||
response_format=RespFormat(RespFormatType.JSON_OBJ),
|
||||
temperature=0.2,
|
||||
tool_options=[ToolOption(name="search_web", description="搜索网页")],
|
||||
)
|
||||
|
||||
|
||||
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
|
||||
request = _build_response_request()
|
||||
|
||||
snapshot_messages = serialize_messages_snapshot(request.message_list)
|
||||
restored_messages = deserialize_messages_snapshot(snapshot_messages)
|
||||
|
||||
assert len(restored_messages) == 3
|
||||
assert restored_messages[0].role == RoleType.User
|
||||
assert restored_messages[0].get_text_content() == "你好"
|
||||
assert restored_messages[0].parts[1].image_format == "png"
|
||||
assert restored_messages[1].role == RoleType.Assistant
|
||||
assert restored_messages[1].tool_calls is not None
|
||||
assert restored_messages[1].tool_calls[0].func_name == "search_web"
|
||||
assert restored_messages[1].tool_calls[0].args == {"query": "MaiBot"}
|
||||
assert restored_messages[1].tool_calls[0].extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}
|
||||
assert restored_messages[2].role == RoleType.Tool
|
||||
assert restored_messages[2].tool_call_id == "call_1"
|
||||
assert restored_messages[2].tool_name == "search_web"
|
||||
|
||||
|
||||
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
|
||||
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
||||
|
||||
request = _build_response_request()
|
||||
provider = _build_api_provider()
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=provider,
|
||||
client_type="openai",
|
||||
error=RuntimeError("boom"),
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=request.model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
|
||||
)
|
||||
|
||||
assert snapshot_path is not None
|
||||
payload = json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert payload["internal_request"]["request_kind"] == "response"
|
||||
assert payload["api_provider"]["name"] == "test-provider"
|
||||
assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
|
||||
assert str(snapshot_path) in payload["replay"]["command"]
|
||||
assert "secret-token" not in snapshot_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
|
||||
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
||||
|
||||
request = _build_response_request()
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=_build_api_provider(),
|
||||
client_type="openai",
|
||||
error=ValueError("invalid"),
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=request.model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request={"request_kwargs": {"messages": []}},
|
||||
)
|
||||
|
||||
assert snapshot_path is not None
|
||||
exc = RuntimeError("wrapped")
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
|
||||
log_info = format_request_snapshot_log_info(exc)
|
||||
assert str(snapshot_path) in log_info
|
||||
assert snapshot_path.as_uri() in log_info
|
||||
assert "uv run python scripts/replay_llm_request.py" in log_info
|
||||
42
pytests/utils_test/test_session_utils.py
Normal file
42
pytests/utils_test/test_session_utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from src.chat.message_receive.chat_manager import ChatManager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
|
||||
def test_calculate_session_id_distinguishes_account_and_scope() -> None:
|
||||
base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||
same_base_session_id = SessionUtils.calculate_session_id("qq", user_id="42")
|
||||
account_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123")
|
||||
route_scoped_session_id = SessionUtils.calculate_session_id("qq", user_id="42", account_id="123", scope="main")
|
||||
|
||||
assert base_session_id == same_base_session_id
|
||||
assert account_scoped_session_id != base_session_id
|
||||
assert route_scoped_session_id != account_scoped_session_id
|
||||
|
||||
|
||||
def test_chat_manager_register_message_uses_route_metadata() -> None:
|
||||
chat_manager = ChatManager()
|
||||
message = SimpleNamespace(
|
||||
platform="qq",
|
||||
session_id="",
|
||||
message_info=SimpleNamespace(
|
||||
user_info=SimpleNamespace(user_id="42"),
|
||||
group_info=SimpleNamespace(group_id="1000"),
|
||||
additional_config={
|
||||
"platform_io_account_id": "123",
|
||||
"platform_io_scope": "main",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
chat_manager.register_message(message)
|
||||
|
||||
assert message.session_id == SessionUtils.calculate_session_id(
|
||||
"qq",
|
||||
user_id="42",
|
||||
group_id="1000",
|
||||
account_id="123",
|
||||
scope="main",
|
||||
)
|
||||
assert chat_manager.last_messages[message.session_id] is message
|
||||
Reference in New Issue
Block a user