Files
mai-bot/agentlite/tests/unit/test_message.py
2026-04-03 22:15:53 +08:00

298 lines
10 KiB
Python

"""Unit tests for message types.
This module tests all message-related types including ContentPart,
Message, ToolCall, and their various subclasses.
"""
from __future__ import annotations
import pytest
from agentlite import (
ContentPart,
Message,
TextPart,
ImageURLPart,
AudioURLPart,
ToolCall,
ToolCallPart,
)
class TestContentPart:
    """Behaviour of the ContentPart base class, its subclass registry, and TextPart."""

    def test_content_part_registry_auto_registers_subclasses(self):
        """Every built-in ContentPart subclass is registered under its type tag."""
        # The registry is a name-mangled private class attribute, hence the
        # `_ContentPart__` prefix when reading it from outside the class.
        registry = ContentPart._ContentPart__content_part_registry
        for type_tag in ("text", "image_url", "audio_url"):
            assert type_tag in registry

    def test_text_part_creation(self):
        """A TextPart carries the "text" type tag and the text it was given."""
        greeting = "Hello, world!"
        part = TextPart(text=greeting)
        assert part.type == "text"
        assert part.text == greeting

    def test_text_part_model_dump(self):
        """model_dump() yields a plain dict holding the type tag and the text."""
        expected = {"type": "text", "text": "Hello"}
        dumped = TextPart(text="Hello").model_dump()
        assert dumped == expected

    def test_text_part_merge_success(self):
        """Merging two TextParts appends the second chunk onto the first (streaming)."""
        head, tail = TextPart(text="Hello "), TextPart(text="world!")
        merged = head.merge_in_place(tail)
        assert merged is True
        assert head.text == "Hello world!"

    def test_text_part_merge_failure(self):
        """Merging something that is not a TextPart is rejected without side effects."""
        part = TextPart(text="Hello")
        outcome = part.merge_in_place("not a part")
        assert outcome is False
        # A rejected merge must leave the existing text untouched.
        assert part.text == "Hello"
class TestImageURLPart:
    """Tests covering ImageURLPart construction and its optional detail level."""

    @staticmethod
    def _make_part(**image_kwargs):
        """Build an ImageURLPart pointing at the shared example image URL."""
        return ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="https://example.com/image.png", **image_kwargs)
        )

    def test_image_url_part_creation(self):
        """The part is tagged "image_url" and exposes the nested URL."""
        part = self._make_part()
        assert part.type == "image_url"
        assert part.image_url.url == "https://example.com/image.png"

    def test_image_url_part_with_detail(self):
        """An explicit detail level is stored on the nested ImageURL."""
        part = self._make_part(detail="high")
        assert part.image_url.detail == "high"

    def test_image_url_part_default_detail(self):
        """Omitting the detail level leaves it as None."""
        part = self._make_part()
        assert part.image_url.detail is None
class TestAudioURLPart:
    """Tests covering AudioURLPart construction."""

    def test_audio_url_part_creation(self):
        """The part is tagged "audio_url" and exposes the nested URL."""
        audio_link = "https://example.com/audio.mp3"
        payload = AudioURLPart.AudioURL(url=audio_link)
        part = AudioURLPart(audio_url=payload)
        assert part.type == "audio_url"
        assert part.audio_url.url == audio_link
class TestToolCall:
    """Tests for ToolCall construction and merging of streamed argument fragments."""

    def test_tool_call_creation(self):
        """A ToolCall defaults to the "function" type and keeps its function body."""
        args_json = '{"a": 1, "b": 2}'
        body = ToolCall.FunctionBody(name="add", arguments=args_json)
        call = ToolCall(id="call_123", function=body)
        assert call.type == "function"
        assert call.id == "call_123"
        assert call.function.name == "add"
        assert call.function.arguments == args_json

    def test_tool_call_merge_with_part(self):
        """A ToolCallPart's argument fragment is appended to the call's arguments."""
        body = ToolCall.FunctionBody(name="add", arguments='{"a": 1')
        call = ToolCall(id="call_123", function=body)
        fragment = ToolCallPart(arguments_part=', "b": 2}')
        accepted = call.merge_in_place(fragment)
        assert accepted is True
        assert call.function.arguments == '{"a": 1, "b": 2}'

    def test_tool_call_merge_failure(self):
        """Merging an object that is not a ToolCallPart is rejected."""
        body = ToolCall.FunctionBody(name="add", arguments="{}")
        call = ToolCall(id="call_123", function=body)
        accepted = call.merge_in_place("not a part")
        assert accepted is False
class TestToolCallPart:
    """Tests for the ToolCallPart streaming argument fragment."""

    def test_tool_call_part_creation(self):
        """The argument fragment is stored verbatim."""
        fragment = '{"a": 1}'
        part = ToolCallPart(arguments_part=fragment)
        assert part.arguments_part == fragment

    def test_tool_call_part_none(self):
        """None is accepted as an argument fragment and kept as None."""
        part = ToolCallPart(arguments_part=None)
        assert part.arguments_part is None

    def test_tool_call_part_merge(self):
        """Merging concatenates the incoming fragment onto the existing one."""
        first, second = ToolCallPart(arguments_part='{"a":'), ToolCallPart(arguments_part=" 1}")
        accepted = first.merge_in_place(second)
        assert accepted is True
        assert first.arguments_part == '{"a": 1}'

    def test_tool_call_part_merge_none(self):
        """Merging into a part whose fragment is None adopts the incoming fragment."""
        empty = ToolCallPart(arguments_part=None)
        incoming = ToolCallPart(arguments_part='{"a": 1}')
        accepted = empty.merge_in_place(incoming)
        assert accepted is True
        assert empty.arguments_part == '{"a": 1}'
class TestMessage:
    """Tests for Message content coercion, text extraction and tool-call helpers."""

    def test_message_string_content_coercion(self):
        """A bare string becomes a one-element content list holding a TextPart."""
        msg = Message(role="user", content="Hello!")
        assert len(msg.content) == 1
        first = msg.content[0]
        assert isinstance(first, TextPart)
        assert first.text == "Hello!"

    def test_message_part_content(self):
        """A single ContentPart is wrapped into a one-element content list."""
        msg = Message(role="user", content=TextPart(text="Hello!"))
        assert len(msg.content) == 1
        assert msg.content[0].text == "Hello!"

    def test_message_list_content(self):
        """A list of ContentParts is kept part-for-part."""
        pieces = [TextPart(text="Hello"), TextPart(text=" world!")]
        msg = Message(role="user", content=pieces)
        assert len(msg.content) == 2

    def test_message_extract_text(self):
        """extract_text() returns the text of a single-part message unchanged."""
        msg = Message(role="user", content="Hello world!")
        assert msg.extract_text() == "Hello world!"

    def test_message_extract_text_with_separator(self):
        """extract_text(sep=...) joins the text parts with the given separator."""
        pieces = [TextPart(text="Hello"), TextPart(text="world!")]
        msg = Message(role="user", content=pieces)
        for separator, expected in ((" ", "Hello world!"), ("-", "Hello-world!")):
            assert msg.extract_text(sep=separator) == expected

    def test_message_has_tool_calls_false(self):
        """A plain assistant reply reports no tool calls."""
        reply = Message(role="assistant", content="Hello!")
        assert reply.has_tool_calls() is False

    def test_message_has_tool_calls_true(self):
        """An assistant message carrying a ToolCall reports that it has tool calls."""
        body = ToolCall.FunctionBody(name="add", arguments="{}")
        call = ToolCall(id="call_123", function=body)
        reply = Message(role="assistant", content="Let me calculate that.", tool_calls=[call])
        assert reply.has_tool_calls() is True

    def test_message_has_tool_calls_empty_list(self):
        """An explicitly empty tool_calls list counts as having no tool calls."""
        reply = Message(role="assistant", content="Hello!", tool_calls=[])
        assert reply.has_tool_calls() is False

    def test_message_tool_response(self):
        """A tool-role message keeps the id of the call it answers."""
        response = Message(role="tool", content="Result: 42", tool_call_id="call_123")
        assert response.role == "tool"
        assert response.tool_call_id == "call_123"

    def test_message_serialization(self):
        """model_dump() includes the role and a content field."""
        dumped = Message(role="user", content="Hello!").model_dump()
        assert dumped["role"] == "user"
        assert "content" in dumped

    def test_message_all_roles(self):
        """Every supported role is accepted and stored exactly as given."""
        roles = ("system", "user", "assistant", "tool")
        built = [Message(role=r, content="Test") for r in roles]
        for expected_role, msg in zip(roles, built):
            assert msg.role == expected_role
class TestPolymorphicContentPart:
    """ContentPart.model_validate dispatches on the "type" key to the matching subclass."""

    def test_polymorphic_validation_text(self):
        """A "text" payload is validated into a TextPart."""
        payload = {"type": "text", "text": "Hello"}
        validated = ContentPart.model_validate(payload)
        assert isinstance(validated, TextPart)
        assert validated.text == "Hello"

    def test_polymorphic_validation_image(self):
        """An "image_url" payload is validated into an ImageURLPart."""
        url = "https://example.com/image.png"
        payload = {"type": "image_url", "image_url": {"url": url}}
        validated = ContentPart.model_validate(payload)
        assert isinstance(validated, ImageURLPart)
        assert validated.image_url.url == url

    def test_polymorphic_validation_unknown_type(self):
        """An unregistered type tag is rejected with a descriptive error."""
        payload = {"type": "unknown_type", "content": "test"}
        with pytest.raises(ValueError, match="Unknown content part type"):
            ContentPart.model_validate(payload)

    def test_polymorphic_validation_no_type(self):
        """A payload without any "type" key is rejected."""
        payload = {"content": "test"}
        with pytest.raises(ValueError):
            ContentPart.model_validate(payload)
class TestMessageEdgeCases:
    """Tests for edge cases in Message handling."""

    def test_empty_string_content(self):
        """An empty string is still coerced into a (blank) TextPart."""
        msg = Message(role="user", content="")
        assert msg.content[0].text == ""

    def test_message_with_name(self):
        """The optional name field is stored as given."""
        msg = Message(role="user", content="Hello", name="user1")
        assert msg.name == "user1"

    def test_message_history_isolation(self):
        """Modifying a deep copy of a message does not affect the original.

        ``Message.content`` is a plain mutable list, so an in-place append on a
        message is visible on that same message. Isolation (e.g. when a history
        is forked) therefore requires copying the message first.
        """
        original = Message(role="user", content="Hello")

        # Mutating a deep copy must leave the original untouched.
        copied = original.model_copy(deep=True)
        copied.content.append(TextPart(text="Extra"))
        assert len(copied.content) == 2
        assert len(original.content) == 1
        assert original.content[0].text == "Hello"

        # By contrast, mutating the original's own list in place is visible on it,
        # because no copy is made.
        original.content.append(TextPart(text="Extra"))
        assert len(original.content) == 2