全新的process方法完成(Message其他部分仍未完成);对应测试;调整部分注释;数据库检索优化
This commit is contained in:
420
pytests/message_test/session_message_test.py
Normal file
420
pytests/message_test/session_message_test.py
Normal file
@@ -0,0 +1,420 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
from types import ModuleType
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.common.data_models.message_component_data_model import MessageSequence, ForwardComponent
|
||||
from src.chat.message_receive.message import (
|
||||
SessionMessage,
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
EmojiComponent,
|
||||
VoiceComponent,
|
||||
AtComponent,
|
||||
ReplyComponent,
|
||||
ForwardNodeComponent,
|
||||
StandardMessageComponents,
|
||||
)
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def __init__(self) -> None:
|
||||
self.logging_record = []
|
||||
|
||||
def debug(self, msg):
|
||||
print(f"DEBUG: {msg}")
|
||||
self.logging_record.append(f"DEBUG: {msg}")
|
||||
|
||||
def info(self, msg):
|
||||
print(f"INFO: {msg}")
|
||||
self.logging_record.append(f"INFO: {msg}")
|
||||
|
||||
def warning(self, msg):
|
||||
print(f"WARNING: {msg}")
|
||||
self.logging_record.append(f"WARNING: {msg}")
|
||||
|
||||
def error(self, msg):
|
||||
print(f"ERROR: {msg}")
|
||||
self.logging_record.append(f"ERROR: {msg}")
|
||||
|
||||
def critical(self, msg):
|
||||
print(f"CRITICAL: {msg}")
|
||||
self.logging_record.append(f"CRITICAL: {msg}")
|
||||
|
||||
|
||||
def get_logger(name):
|
||||
return DummyLogger()
|
||||
|
||||
|
||||
class DummyDBSession:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
def exec(self, statement):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
def commit(self):
|
||||
pass
|
||||
|
||||
def all(self):
|
||||
return []
|
||||
|
||||
|
||||
def get_db_session():
|
||||
return DummyDBSession()
|
||||
|
||||
|
||||
def get_manual_db_session():
|
||||
return DummyDBSession()
|
||||
|
||||
|
||||
class DummySelect:
|
||||
def __init__(self, model):
|
||||
self.model = model
|
||||
|
||||
def filter_by(self, **kwargs):
|
||||
return self
|
||||
|
||||
def where(self, condition):
|
||||
return self
|
||||
|
||||
def limit(self, n):
|
||||
return self
|
||||
|
||||
|
||||
def select(model):
|
||||
return DummySelect(model)
|
||||
|
||||
|
||||
async def dummy_get_voice_text(binary_data):
|
||||
return None # 可以根据需要返回模拟的文本结果
|
||||
|
||||
|
||||
class DummyPersonUtils:
|
||||
@staticmethod
|
||||
def get_person_info_by_user_id_and_platform(user_id, platform):
|
||||
return None # 可以根据需要返回模拟的用户信息
|
||||
|
||||
|
||||
def setup_mocks(monkeypatch):
|
||||
def _stub_module(name: str) -> ModuleType:
|
||||
module = ModuleType(name)
|
||||
monkeypatch.setitem(sys.modules, name, module)
|
||||
return module
|
||||
|
||||
# src.common.logger
|
||||
logger_mod = _stub_module("src.common.logger")
|
||||
# Mock the logger
|
||||
logger_mod.get_logger = get_logger
|
||||
|
||||
db_mod = _stub_module("src.common.database.database")
|
||||
db_mod.get_db_session = get_db_session
|
||||
db_mod.get_manual_db_session = get_manual_db_session
|
||||
|
||||
emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager")
|
||||
emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
image_manager_mod = _stub_module("src.chat.image_system.image_manager")
|
||||
image_manager_mod.image_manager = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
msg_utils_mod = _stub_module("src.common.utils.utils_message")
|
||||
msg_utils_mod.MessageUtils = None # 可以根据需要添加更多的属性或方法
|
||||
|
||||
voice_utils_mod = _stub_module("src.common.utils.utils_voice")
|
||||
voice_utils_mod.get_voice_text = dummy_get_voice_text
|
||||
|
||||
person_utils_mod = _stub_module("src.common.utils.utils_person")
|
||||
person_utils_mod.PersonUtils = DummyPersonUtils
|
||||
|
||||
|
||||
def load_message_via_file(monkeypatch):
|
||||
setup_mocks(monkeypatch)
|
||||
file_path = Path(__file__).parent.parent.parent / "src" / "chat" / "message_receive" / "message.py"
|
||||
spec = importlib.util.spec_from_file_location("message", file_path)
|
||||
message_module = importlib.util.module_from_spec(spec)
|
||||
monkeypatch.setitem(sys.modules, "message_module", message_module)
|
||||
spec.loader.exec_module(message_module)
|
||||
message_module.select = select
|
||||
SessionMessageClass = message_module.SessionMessage
|
||||
TextComponentClass = message_module.TextComponent
|
||||
ImageComponentClass = message_module.ImageComponent
|
||||
EmojiComponentClass = message_module.EmojiComponent
|
||||
VoiceComponentClass = message_module.VoiceComponent
|
||||
AtComponentClass = message_module.AtComponent
|
||||
ReplyComponentClass = message_module.ReplyComponent
|
||||
ForwardNodeComponentClass = message_module.ForwardNodeComponent
|
||||
MessageSequenceClass = sys.modules["src.common.data_models.message_component_data_model"].MessageSequence
|
||||
ForwardComponentClass = sys.modules["src.common.data_models.message_component_data_model"].ForwardComponent
|
||||
globals()["SessionMessage"] = SessionMessageClass
|
||||
globals()["TextComponent"] = TextComponentClass
|
||||
globals()["ImageComponent"] = ImageComponentClass
|
||||
globals()["EmojiComponent"] = EmojiComponentClass
|
||||
globals()["VoiceComponent"] = VoiceComponentClass
|
||||
globals()["AtComponent"] = AtComponentClass
|
||||
globals()["ReplyComponent"] = ReplyComponentClass
|
||||
globals()["ForwardNodeComponent"] = ForwardNodeComponentClass
|
||||
globals()["MessageSequence"] = MessageSequenceClass
|
||||
globals()["ForwardComponent"] = ForwardComponentClass
|
||||
return message_module
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_text(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [TextComponent("Hello,"), TextComponent("world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ImageComponent(binary_hash="image_hash"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[发了一张图片,网卡了加载不出来] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_emoji(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [EmojiComponent(binary_hash="emoji_hash"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[发了一个表情,网卡了加载不出来] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [VoiceComponent(binary_hash="voice_hash"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[语音消息,转录失败] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_at_component(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [AtComponent(target_user_id="114514"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "@114514 Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_component_fail_to_fetch(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_component_success(monkeypatch):
|
||||
module_msg = load_message_via_file(monkeypatch)
|
||||
|
||||
class DummyDBSessionWithReply(DummyDBSession):
|
||||
def exec(self, s):
|
||||
return self
|
||||
|
||||
def first(inner_self):
|
||||
class DummyRecord:
|
||||
processed_plain_text = "原消息内容"
|
||||
user_cardname = "cardname123"
|
||||
user_nickname = "nickname123"
|
||||
user_id = "userid123"
|
||||
|
||||
return DummyRecord()
|
||||
|
||||
module_msg.get_db_session = lambda: DummyDBSessionWithReply()
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[回复了cardname123的消息: 原消息内容] Hello, world!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_component_with_db_fail(monkeypatch):
|
||||
module_msg = load_message_via_file(monkeypatch)
|
||||
|
||||
class DummyDBSessionWithError(DummyDBSession):
|
||||
def exec(self, s):
|
||||
raise Exception("数据库查询失败")
|
||||
|
||||
module_msg.get_db_session = lambda: DummyDBSessionWithError()
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [ReplyComponent(target_message_id="1919810"), TextComponent("Hello, world!")]
|
||||
await msg.process()
|
||||
assert msg.processed_plain_text == "[回复了一条消息,但原消息已无法访问] Hello, world!"
|
||||
assert any("数据库查询失败" in log for log in module_msg.logger.logging_record)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_component(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
message_id="msg1",
|
||||
user_id="user1",
|
||||
user_nickname="nickname1",
|
||||
user_cardname="cardname1",
|
||||
content=[TextComponent("转发消息1")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg2",
|
||||
user_id="user2",
|
||||
user_nickname="nickname2",
|
||||
user_cardname="cardname2",
|
||||
content=[TextComponent("转发消息2")],
|
||||
),
|
||||
]
|
||||
),
|
||||
TextComponent("Hello, world!"),
|
||||
]
|
||||
await msg.process()
|
||||
print("Processed plain text:", msg.processed_plain_text)
|
||||
expected_forward_text = """【合并转发消息:
|
||||
-- 【cardname1】: 转发消息1
|
||||
-- 【cardname2】: 转发消息2
|
||||
】 Hello, world!"""
|
||||
assert msg.processed_plain_text == expected_forward_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_with_reply(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.raw_message.components = [
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
message_id="msg1",
|
||||
user_id="user1",
|
||||
user_nickname="nickname1",
|
||||
user_cardname="cardname1",
|
||||
content=[TextComponent("转发消息1")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg2",
|
||||
user_id="user2",
|
||||
user_nickname="nickname2",
|
||||
user_cardname="cardname2",
|
||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
|
||||
),
|
||||
]
|
||||
),
|
||||
TextComponent("Hello, world!"),
|
||||
]
|
||||
await msg.process()
|
||||
assert (
|
||||
msg.processed_plain_text
|
||||
== """【合并转发消息:
|
||||
-- 【cardname1】: 转发消息1
|
||||
-- 【cardname2】: [回复了cardname1的消息: 转发消息1] 转发消息2
|
||||
】 Hello, world!"""
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_reply_with_delay_in_forward(monkeypatch):
|
||||
load_message_via_file(monkeypatch)
|
||||
msg = SessionMessage("msg123", datetime.now())
|
||||
msg.session_id = "session123"
|
||||
msg.platform = "test_platform"
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
|
||||
async def delayed_get_voice_text(binary_data):
|
||||
await asyncio.sleep(0.5) # 模拟延迟
|
||||
return "这是语音转文本的结果"
|
||||
|
||||
sys.modules["src.common.utils.utils_voice"].get_voice_text = delayed_get_voice_text
|
||||
|
||||
msg.raw_message.components = [
|
||||
ForwardNodeComponent(
|
||||
forward_components=[
|
||||
ForwardComponent(
|
||||
message_id="msg1",
|
||||
user_id="user1",
|
||||
user_nickname="nickname1",
|
||||
user_cardname="cardname1",
|
||||
content=[VoiceComponent(binary_hash="voice_hash1"), TextComponent("转发消息1")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg2",
|
||||
user_id="user2",
|
||||
user_nickname="nickname2",
|
||||
user_cardname="cardname2",
|
||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息2")],
|
||||
),
|
||||
ForwardComponent(
|
||||
message_id="msg3",
|
||||
user_id="user3",
|
||||
user_nickname="nickname3",
|
||||
user_cardname="cardname3",
|
||||
content=[ReplyComponent(target_message_id="msg1"), TextComponent("转发消息3")],
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
await msg.process()
|
||||
expected_text = """【合并转发消息:
|
||||
-- 【cardname1】: [语音: 这是语音转文本的结果] 转发消息1
|
||||
-- 【cardname2】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息2
|
||||
-- 【cardname3】: [回复了cardname1的消息: [语音: 这是语音转文本的结果] 转发消息1] 转发消息3
|
||||
】"""
|
||||
assert msg.processed_plain_text == expected_text
|
||||
Reference in New Issue
Block a user