feat: add a subagent frame
This commit is contained in:
0
agentlite/tests/unit/__init__.py
Normal file
0
agentlite/tests/unit/__init__.py
Normal file
330
agentlite/tests/unit/test_config.py
Normal file
330
agentlite/tests/unit/test_config.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Unit tests for configuration models.
|
||||
|
||||
This module tests all Pydantic configuration models including
|
||||
ProviderConfig, ModelConfig, ToolConfig, and AgentConfig.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from agentlite import ProviderConfig, ModelConfig, AgentConfig
|
||||
|
||||
|
||||
class TestProviderConfig:
|
||||
"""Tests for ProviderConfig."""
|
||||
|
||||
def test_provider_config_valid(self):
|
||||
"""Test valid ProviderConfig creation."""
|
||||
config = ProviderConfig(
|
||||
type="openai",
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key="sk-test123",
|
||||
)
|
||||
assert config.type == "openai"
|
||||
assert config.base_url == "https://api.openai.com/v1"
|
||||
assert config.api_key.get_secret_value() == "sk-test123"
|
||||
|
||||
def test_provider_config_default_type(self):
|
||||
"""Test ProviderConfig with default type."""
|
||||
config = ProviderConfig(
|
||||
base_url="https://api.openai.com/v1",
|
||||
api_key="sk-test",
|
||||
)
|
||||
assert config.type == "openai"
|
||||
|
||||
def test_provider_config_default_url(self):
|
||||
"""Test ProviderConfig with default base_url."""
|
||||
config = ProviderConfig(
|
||||
api_key="sk-test",
|
||||
)
|
||||
assert config.base_url == "https://api.openai.com/v1"
|
||||
|
||||
def test_provider_config_invalid_url_http(self):
|
||||
"""Test ProviderConfig with invalid URL scheme."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProviderConfig(
|
||||
type="openai",
|
||||
base_url="ftp://invalid.com",
|
||||
api_key="sk-test",
|
||||
)
|
||||
assert "base_url must start with http:// or https://" in str(exc_info.value)
|
||||
|
||||
def test_provider_config_invalid_url_no_scheme(self):
|
||||
"""Test ProviderConfig with URL without scheme."""
|
||||
with pytest.raises(ValidationError):
|
||||
ProviderConfig(
|
||||
base_url="api.openai.com/v1",
|
||||
api_key="sk-test",
|
||||
)
|
||||
|
||||
def test_provider_config_custom_headers(self):
|
||||
"""Test ProviderConfig with custom headers."""
|
||||
config = ProviderConfig(
|
||||
api_key="sk-test",
|
||||
headers={"X-Custom": "value"},
|
||||
)
|
||||
assert config.headers == {"X-Custom": "value"}
|
||||
|
||||
def test_provider_config_default_headers(self):
|
||||
"""Test ProviderConfig default headers."""
|
||||
config = ProviderConfig(api_key="sk-test")
|
||||
assert config.headers == {}
|
||||
|
||||
def test_provider_config_timeout(self):
|
||||
"""Test ProviderConfig timeout."""
|
||||
config = ProviderConfig(
|
||||
api_key="sk-test",
|
||||
timeout=30.0,
|
||||
)
|
||||
assert config.timeout == 30.0
|
||||
|
||||
def test_provider_config_default_timeout(self):
|
||||
"""Test ProviderConfig default timeout."""
|
||||
config = ProviderConfig(api_key="sk-test")
|
||||
assert config.timeout == 60.0
|
||||
|
||||
def test_provider_config_api_key_is_secret_str(self):
|
||||
"""Test that api_key is stored as SecretStr."""
|
||||
config = ProviderConfig(api_key="sk-secret")
|
||||
# SecretStr should not expose value in repr/str
|
||||
assert "sk-secret" not in str(config.api_key)
|
||||
# But can get value explicitly
|
||||
assert config.api_key.get_secret_value() == "sk-secret"
|
||||
|
||||
|
||||
class TestModelConfig:
|
||||
"""Tests for ModelConfig."""
|
||||
|
||||
def test_model_config_valid(self):
|
||||
"""Test valid ModelConfig creation."""
|
||||
config = ModelConfig(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
)
|
||||
assert config.provider == "openai"
|
||||
assert config.model == "gpt-4"
|
||||
|
||||
def test_model_config_with_all_fields(self):
|
||||
"""Test ModelConfig with all optional fields."""
|
||||
config = ModelConfig(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
max_tokens=1000,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
capabilities={"streaming", "tool_calling"},
|
||||
)
|
||||
assert config.max_tokens == 1000
|
||||
assert config.temperature == 0.7
|
||||
assert config.top_p == 0.9
|
||||
assert config.capabilities == {"streaming", "tool_calling"}
|
||||
|
||||
def test_model_config_empty_provider(self):
|
||||
"""Test ModelConfig with empty provider."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ModelConfig(
|
||||
provider="",
|
||||
model="gpt-4",
|
||||
)
|
||||
assert "provider must not be empty" in str(exc_info.value)
|
||||
|
||||
def test_model_config_temperature_bounds(self):
|
||||
"""Test ModelConfig temperature validation bounds."""
|
||||
# Valid: 0.0
|
||||
config = ModelConfig(provider="openai", model="gpt-4", temperature=0.0)
|
||||
assert config.temperature == 0.0
|
||||
|
||||
# Valid: 2.0
|
||||
config = ModelConfig(provider="openai", model="gpt-4", temperature=2.0)
|
||||
assert config.temperature == 2.0
|
||||
|
||||
# Invalid: < 0
|
||||
with pytest.raises(ValidationError):
|
||||
ModelConfig(provider="openai", model="gpt-4", temperature=-0.1)
|
||||
|
||||
# Invalid: > 2
|
||||
with pytest.raises(ValidationError):
|
||||
ModelConfig(provider="openai", model="gpt-4", temperature=2.1)
|
||||
|
||||
def test_model_config_top_p_bounds(self):
|
||||
"""Test ModelConfig top_p validation bounds."""
|
||||
# Valid: 0.0
|
||||
config = ModelConfig(provider="openai", model="gpt-4", top_p=0.0)
|
||||
assert config.top_p == 0.0
|
||||
|
||||
# Valid: 1.0
|
||||
config = ModelConfig(provider="openai", model="gpt-4", top_p=1.0)
|
||||
assert config.top_p == 1.0
|
||||
|
||||
# Invalid: < 0
|
||||
with pytest.raises(ValidationError):
|
||||
ModelConfig(provider="openai", model="gpt-4", top_p=-0.1)
|
||||
|
||||
# Invalid: > 1
|
||||
with pytest.raises(ValidationError):
|
||||
ModelConfig(provider="openai", model="gpt-4", top_p=1.1)
|
||||
|
||||
def test_model_config_default_capabilities(self):
|
||||
"""Test ModelConfig default capabilities."""
|
||||
config = ModelConfig(provider="openai", model="gpt-4")
|
||||
assert config.capabilities == set()
|
||||
|
||||
|
||||
class TestAgentConfig:
|
||||
"""Tests for AgentConfig."""
|
||||
|
||||
def test_agent_config_minimal(self):
|
||||
"""Test AgentConfig with minimal required fields."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="openai", model="gpt-4")},
|
||||
)
|
||||
assert config.name == "agent"
|
||||
assert config.system_prompt == "You are a helpful assistant."
|
||||
assert config.default_model == "default"
|
||||
|
||||
def test_agent_config_full(self):
|
||||
"""Test AgentConfig with all fields."""
|
||||
config = AgentConfig(
|
||||
name="my_agent",
|
||||
system_prompt="Custom system prompt",
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
|
||||
default_model="gpt4",
|
||||
max_history=50,
|
||||
)
|
||||
assert config.name == "my_agent"
|
||||
assert config.system_prompt == "Custom system prompt"
|
||||
assert config.default_model == "gpt4"
|
||||
assert config.max_history == 50
|
||||
|
||||
def test_agent_config_missing_default_model(self):
|
||||
"""Test AgentConfig with non-existent default_model."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
|
||||
default_model="nonexistent",
|
||||
)
|
||||
assert "not found in models" in str(exc_info.value)
|
||||
|
||||
def test_agent_config_unknown_provider(self):
|
||||
"""Test AgentConfig with model referencing unknown provider."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="unknown", model="gpt-4")},
|
||||
)
|
||||
assert "unknown provider" in str(exc_info.value)
|
||||
|
||||
def test_agent_config_get_provider_config(self):
|
||||
"""Test get_provider_config method."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
|
||||
default_model="gpt4",
|
||||
)
|
||||
|
||||
provider_config = config.get_provider_config("gpt4")
|
||||
assert provider_config.api_key.get_secret_value() == "sk-test"
|
||||
|
||||
def test_agent_config_get_provider_config_default(self):
|
||||
"""Test get_provider_config with default model."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
|
||||
default_model="gpt4",
|
||||
)
|
||||
|
||||
provider_config = config.get_provider_config()
|
||||
assert provider_config.api_key.get_secret_value() == "sk-test"
|
||||
|
||||
def test_agent_config_get_provider_config_not_found(self):
|
||||
"""Test get_provider_config with non-existent model."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="openai", model="gpt-4")},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
|
||||
config.get_provider_config("nonexistent")
|
||||
|
||||
def test_agent_config_get_model_config(self):
|
||||
"""Test get_model_config method."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
|
||||
default_model="gpt4",
|
||||
)
|
||||
|
||||
model_config = config.get_model_config("gpt4")
|
||||
assert model_config.model == "gpt-4"
|
||||
|
||||
def test_agent_config_get_model_config_default(self):
|
||||
"""Test get_model_config with default."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
|
||||
default_model="gpt4",
|
||||
)
|
||||
|
||||
model_config = config.get_model_config()
|
||||
assert model_config.model == "gpt-4"
|
||||
|
||||
def test_agent_config_get_model_config_not_found(self):
|
||||
"""Test get_model_config with non-existent model."""
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="openai", model="gpt-4")},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
|
||||
config.get_model_config("nonexistent")
|
||||
|
||||
def test_agent_config_multiple_providers(self):
|
||||
"""Test AgentConfig with multiple providers."""
|
||||
config = AgentConfig(
|
||||
providers={
|
||||
"openai": ProviderConfig(api_key="sk-openai"),
|
||||
"anthropic": ProviderConfig(
|
||||
type="anthropic",
|
||||
base_url="https://api.anthropic.com/v1",
|
||||
api_key="sk-anthropic",
|
||||
),
|
||||
},
|
||||
models={
|
||||
"default": ModelConfig(provider="openai", model="gpt-4"),
|
||||
"claude": ModelConfig(provider="anthropic", model="claude-3"),
|
||||
},
|
||||
)
|
||||
|
||||
assert len(config.providers) == 2
|
||||
assert len(config.models) == 2
|
||||
|
||||
def test_agent_config_max_history_validation(self):
|
||||
"""Test max_history validation."""
|
||||
# Valid: min=1
|
||||
config = AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="openai", model="gpt-4")},
|
||||
max_history=1,
|
||||
)
|
||||
assert config.max_history == 1
|
||||
|
||||
# Invalid: 0
|
||||
with pytest.raises(ValidationError):
|
||||
AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="openai", model="gpt-4")},
|
||||
max_history=0,
|
||||
)
|
||||
|
||||
# Invalid: negative
|
||||
with pytest.raises(ValidationError):
|
||||
AgentConfig(
|
||||
providers={"openai": ProviderConfig(api_key="sk-test")},
|
||||
models={"default": ModelConfig(provider="openai", model="gpt-4")},
|
||||
max_history=-1,
|
||||
)
|
||||
297
agentlite/tests/unit/test_message.py
Normal file
297
agentlite/tests/unit/test_message.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""Unit tests for message types.
|
||||
|
||||
This module tests all message-related types including ContentPart,
|
||||
Message, ToolCall, and their various subclasses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import (
|
||||
ContentPart,
|
||||
Message,
|
||||
TextPart,
|
||||
ImageURLPart,
|
||||
AudioURLPart,
|
||||
ToolCall,
|
||||
ToolCallPart,
|
||||
)
|
||||
|
||||
|
||||
class TestContentPart:
|
||||
"""Tests for ContentPart base class and registry."""
|
||||
|
||||
def test_content_part_registry_auto_registers_subclasses(self):
|
||||
"""Test that ContentPart subclasses are auto-registered."""
|
||||
# All defined subclasses should be in registry
|
||||
assert "text" in ContentPart._ContentPart__content_part_registry
|
||||
assert "image_url" in ContentPart._ContentPart__content_part_registry
|
||||
assert "audio_url" in ContentPart._ContentPart__content_part_registry
|
||||
|
||||
def test_text_part_creation(self):
|
||||
"""Test basic TextPart creation."""
|
||||
part = TextPart(text="Hello, world!")
|
||||
assert part.type == "text"
|
||||
assert part.text == "Hello, world!"
|
||||
|
||||
def test_text_part_model_dump(self):
|
||||
"""Test TextPart serialization."""
|
||||
part = TextPart(text="Hello")
|
||||
dumped = part.model_dump()
|
||||
assert dumped == {"type": "text", "text": "Hello"}
|
||||
|
||||
def test_text_part_merge_success(self):
|
||||
"""Test successful text merge during streaming."""
|
||||
part1 = TextPart(text="Hello ")
|
||||
part2 = TextPart(text="world!")
|
||||
|
||||
result = part1.merge_in_place(part2)
|
||||
|
||||
assert result is True
|
||||
assert part1.text == "Hello world!"
|
||||
|
||||
def test_text_part_merge_failure(self):
|
||||
"""Test merge failure with incompatible types."""
|
||||
text_part = TextPart(text="Hello")
|
||||
|
||||
# Try to merge with non-TextPart
|
||||
result = text_part.merge_in_place("not a part")
|
||||
assert result is False
|
||||
assert text_part.text == "Hello" # Unchanged
|
||||
|
||||
|
||||
class TestImageURLPart:
|
||||
"""Tests for ImageURLPart."""
|
||||
|
||||
def test_image_url_part_creation(self):
|
||||
"""Test ImageURLPart creation."""
|
||||
part = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
|
||||
assert part.type == "image_url"
|
||||
assert part.image_url.url == "https://example.com/image.png"
|
||||
|
||||
def test_image_url_part_with_detail(self):
|
||||
"""Test ImageURLPart with detail parameter."""
|
||||
part = ImageURLPart(
|
||||
image_url=ImageURLPart.ImageURL(url="https://example.com/image.png", detail="high")
|
||||
)
|
||||
assert part.image_url.detail == "high"
|
||||
|
||||
def test_image_url_part_default_detail(self):
|
||||
"""Test ImageURLPart default detail is None."""
|
||||
part = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
|
||||
assert part.image_url.detail is None
|
||||
|
||||
|
||||
class TestAudioURLPart:
|
||||
"""Tests for AudioURLPart."""
|
||||
|
||||
def test_audio_url_part_creation(self):
|
||||
"""Test AudioURLPart creation."""
|
||||
part = AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3"))
|
||||
assert part.type == "audio_url"
|
||||
assert part.audio_url.url == "https://example.com/audio.mp3"
|
||||
|
||||
|
||||
class TestToolCall:
|
||||
"""Tests for ToolCall."""
|
||||
|
||||
def test_tool_call_creation(self):
|
||||
"""Test ToolCall creation."""
|
||||
call = ToolCall(
|
||||
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
|
||||
)
|
||||
assert call.type == "function"
|
||||
assert call.id == "call_123"
|
||||
assert call.function.name == "add"
|
||||
assert call.function.arguments == '{"a": 1, "b": 2}'
|
||||
|
||||
def test_tool_call_merge_with_part(self):
|
||||
"""Test ToolCall merging with ToolCallPart."""
|
||||
call = ToolCall(
|
||||
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1')
|
||||
)
|
||||
part = ToolCallPart(arguments_part=', "b": 2}')
|
||||
|
||||
result = call.merge_in_place(part)
|
||||
|
||||
assert result is True
|
||||
assert call.function.arguments == '{"a": 1, "b": 2}'
|
||||
|
||||
def test_tool_call_merge_failure(self):
|
||||
"""Test ToolCall merge failure with incompatible types."""
|
||||
call = ToolCall(id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}"))
|
||||
|
||||
result = call.merge_in_place("not a part")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestToolCallPart:
|
||||
"""Tests for ToolCallPart."""
|
||||
|
||||
def test_tool_call_part_creation(self):
|
||||
"""Test ToolCallPart creation."""
|
||||
part = ToolCallPart(arguments_part='{"a": 1}')
|
||||
assert part.arguments_part == '{"a": 1}'
|
||||
|
||||
def test_tool_call_part_none(self):
|
||||
"""Test ToolCallPart with None arguments."""
|
||||
part = ToolCallPart(arguments_part=None)
|
||||
assert part.arguments_part is None
|
||||
|
||||
def test_tool_call_part_merge(self):
|
||||
"""Test ToolCallPart merging."""
|
||||
part1 = ToolCallPart(arguments_part='{"a":')
|
||||
part2 = ToolCallPart(arguments_part=" 1}")
|
||||
|
||||
result = part1.merge_in_place(part2)
|
||||
|
||||
assert result is True
|
||||
assert part1.arguments_part == '{"a": 1}'
|
||||
|
||||
def test_tool_call_part_merge_none(self):
|
||||
"""Test ToolCallPart merge when self is None."""
|
||||
part1 = ToolCallPart(arguments_part=None)
|
||||
part2 = ToolCallPart(arguments_part='{"a": 1}')
|
||||
|
||||
result = part1.merge_in_place(part2)
|
||||
|
||||
assert result is True
|
||||
assert part1.arguments_part == '{"a": 1}'
|
||||
|
||||
|
||||
class TestMessage:
|
||||
"""Tests for Message."""
|
||||
|
||||
def test_message_string_content_coercion(self):
|
||||
"""Test that string content is coerced to TextPart."""
|
||||
msg = Message(role="user", content="Hello!")
|
||||
|
||||
assert len(msg.content) == 1
|
||||
assert isinstance(msg.content[0], TextPart)
|
||||
assert msg.content[0].text == "Hello!"
|
||||
|
||||
def test_message_part_content(self):
|
||||
"""Test Message with ContentPart content."""
|
||||
part = TextPart(text="Hello!")
|
||||
msg = Message(role="user", content=part)
|
||||
|
||||
assert len(msg.content) == 1
|
||||
assert msg.content[0].text == "Hello!"
|
||||
|
||||
def test_message_list_content(self):
|
||||
"""Test Message with list of ContentParts."""
|
||||
parts = [TextPart(text="Hello"), TextPart(text=" world!")]
|
||||
msg = Message(role="user", content=parts)
|
||||
|
||||
assert len(msg.content) == 2
|
||||
|
||||
def test_message_extract_text(self):
|
||||
"""Test text extraction from message."""
|
||||
msg = Message(role="user", content="Hello world!")
|
||||
|
||||
assert msg.extract_text() == "Hello world!"
|
||||
|
||||
def test_message_extract_text_with_separator(self):
|
||||
"""Test text extraction with custom separator."""
|
||||
parts = [TextPart(text="Hello"), TextPart(text="world!")]
|
||||
msg = Message(role="user", content=parts)
|
||||
|
||||
assert msg.extract_text(sep=" ") == "Hello world!"
|
||||
assert msg.extract_text(sep="-") == "Hello-world!"
|
||||
|
||||
def test_message_has_tool_calls_false(self):
|
||||
"""Test has_tool_calls returns False when no tool calls."""
|
||||
msg = Message(role="assistant", content="Hello!")
|
||||
assert msg.has_tool_calls() is False
|
||||
|
||||
def test_message_has_tool_calls_true(self):
|
||||
"""Test has_tool_calls returns True when tool calls present."""
|
||||
tool_call = ToolCall(
|
||||
id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}")
|
||||
)
|
||||
msg = Message(role="assistant", content="Let me calculate that.", tool_calls=[tool_call])
|
||||
assert msg.has_tool_calls() is True
|
||||
|
||||
def test_message_has_tool_calls_empty_list(self):
|
||||
"""Test has_tool_calls with empty tool_calls list."""
|
||||
msg = Message(role="assistant", content="Hello!", tool_calls=[])
|
||||
assert msg.has_tool_calls() is False
|
||||
|
||||
def test_message_tool_response(self):
|
||||
"""Test message with tool response."""
|
||||
msg = Message(role="tool", content="Result: 42", tool_call_id="call_123")
|
||||
assert msg.role == "tool"
|
||||
assert msg.tool_call_id == "call_123"
|
||||
|
||||
def test_message_serialization(self):
|
||||
"""Test Message serialization with model_dump."""
|
||||
msg = Message(role="user", content="Hello!")
|
||||
dumped = msg.model_dump()
|
||||
|
||||
assert dumped["role"] == "user"
|
||||
assert "content" in dumped
|
||||
|
||||
def test_message_all_roles(self):
|
||||
"""Test Message creation with all valid roles."""
|
||||
for role in ["system", "user", "assistant", "tool"]:
|
||||
msg = Message(role=role, content="Test")
|
||||
assert msg.role == role
|
||||
|
||||
|
||||
class TestPolymorphicContentPart:
|
||||
"""Tests for polymorphic ContentPart validation."""
|
||||
|
||||
def test_polymorphic_validation_text(self):
|
||||
"""Test that text type validates to TextPart."""
|
||||
data = {"type": "text", "text": "Hello"}
|
||||
part = ContentPart.model_validate(data)
|
||||
|
||||
assert isinstance(part, TextPart)
|
||||
assert part.text == "Hello"
|
||||
|
||||
def test_polymorphic_validation_image(self):
|
||||
"""Test that image_url type validates to ImageURLPart."""
|
||||
data = {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}
|
||||
part = ContentPart.model_validate(data)
|
||||
|
||||
assert isinstance(part, ImageURLPart)
|
||||
assert part.image_url.url == "https://example.com/image.png"
|
||||
|
||||
def test_polymorphic_validation_unknown_type(self):
|
||||
"""Test validation with unknown type raises error."""
|
||||
data = {"type": "unknown_type", "content": "test"}
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown content part type"):
|
||||
ContentPart.model_validate(data)
|
||||
|
||||
def test_polymorphic_validation_no_type(self):
|
||||
"""Test validation without type raises error."""
|
||||
data = {"content": "test"}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ContentPart.model_validate(data)
|
||||
|
||||
|
||||
class TestMessageEdgeCases:
|
||||
"""Tests for edge cases in Message handling."""
|
||||
|
||||
def test_empty_string_content(self):
|
||||
"""Test Message with empty string content."""
|
||||
msg = Message(role="user", content="")
|
||||
assert msg.content[0].text == ""
|
||||
|
||||
def test_message_with_name(self):
|
||||
"""Test Message with name field."""
|
||||
msg = Message(role="user", content="Hello", name="user1")
|
||||
assert msg.name == "user1"
|
||||
|
||||
def test_message_history_isolation(self):
|
||||
"""Test that history modifications don't affect original."""
|
||||
msg = Message(role="user", content="Hello")
|
||||
|
||||
# Modify the content list
|
||||
msg.content.append(TextPart(text="Extra"))
|
||||
|
||||
# Original should be modified (it's the same object)
|
||||
assert len(msg.content) == 2
|
||||
168
agentlite/tests/unit/test_provider.py
Normal file
168
agentlite/tests/unit/test_provider.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Unit tests for provider protocol and exceptions.
|
||||
|
||||
This module tests the ChatProvider protocol, StreamedMessage protocol,
|
||||
and all exception types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite.provider import (
|
||||
TokenUsage,
|
||||
ChatProviderError,
|
||||
APIConnectionError,
|
||||
APITimeoutError,
|
||||
APIStatusError,
|
||||
APIEmptyResponseError,
|
||||
ChatProvider,
|
||||
StreamedMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenUsage:
|
||||
"""Tests for TokenUsage."""
|
||||
|
||||
def test_token_usage_creation(self):
|
||||
"""Test TokenUsage creation."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50)
|
||||
assert usage.input_tokens == 100
|
||||
assert usage.output_tokens == 50
|
||||
assert usage.cached_tokens == 0 # Default
|
||||
|
||||
def test_token_usage_with_cached(self):
|
||||
"""Test TokenUsage with cached tokens."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
|
||||
assert usage.cached_tokens == 20
|
||||
|
||||
def test_token_usage_total(self):
|
||||
"""Test total token calculation."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50)
|
||||
assert usage.total == 150
|
||||
|
||||
def test_token_usage_total_with_cached(self):
|
||||
"""Test total with cached tokens (not included in total)."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
|
||||
# Total is input + output, cached is tracked separately
|
||||
assert usage.total == 150
|
||||
|
||||
|
||||
class TestChatProviderError:
|
||||
"""Tests for ChatProviderError hierarchy."""
|
||||
|
||||
def test_base_error_creation(self):
|
||||
"""Test base ChatProviderError creation."""
|
||||
error = ChatProviderError("Something went wrong")
|
||||
assert error.message == "Something went wrong"
|
||||
assert str(error) == "Something went wrong"
|
||||
|
||||
def test_api_connection_error(self):
|
||||
"""Test APIConnectionError creation."""
|
||||
error = APIConnectionError("Connection failed")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.message == "Connection failed"
|
||||
|
||||
def test_api_timeout_error(self):
|
||||
"""Test APITimeoutError creation."""
|
||||
error = APITimeoutError("Request timed out")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.message == "Request timed out"
|
||||
|
||||
def test_api_status_error(self):
|
||||
"""Test APIStatusError creation."""
|
||||
error = APIStatusError(429, "Rate limit exceeded")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.status_code == 429
|
||||
assert error.message == "Rate limit exceeded"
|
||||
|
||||
def test_api_status_error_different_codes(self):
|
||||
"""Test APIStatusError with different status codes."""
|
||||
codes = [400, 401, 403, 404, 429, 500, 502, 503]
|
||||
for code in codes:
|
||||
error = APIStatusError(code, f"Error {code}")
|
||||
assert error.status_code == code
|
||||
|
||||
def test_api_empty_response_error(self):
|
||||
"""Test APIEmptyResponseError creation."""
|
||||
error = APIEmptyResponseError("Empty response from API")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.message == "Empty response from API"
|
||||
|
||||
def test_exception_hierarchy(self):
|
||||
"""Test that all exceptions inherit from ChatProviderError."""
|
||||
errors = [
|
||||
APIConnectionError("test"),
|
||||
APITimeoutError("test"),
|
||||
APIStatusError(500, "test"),
|
||||
APIEmptyResponseError("test"),
|
||||
]
|
||||
for error in errors:
|
||||
assert isinstance(error, ChatProviderError)
|
||||
|
||||
|
||||
class TestChatProviderProtocol:
|
||||
"""Tests for ChatProvider protocol."""
|
||||
|
||||
def test_protocol_is_runtime_checkable(self):
|
||||
"""Test that ChatProvider is runtime checkable."""
|
||||
# ChatProvider should have @runtime_checkable
|
||||
from typing import runtime_checkable
|
||||
|
||||
assert hasattr(ChatProvider, "__protocol_attrs__")
|
||||
|
||||
def test_mock_provider_implements_protocol(self, mock_provider):
|
||||
"""Test that MockProvider implements ChatProvider."""
|
||||
assert isinstance(mock_provider, ChatProvider)
|
||||
|
||||
def test_mock_provider_has_model_name(self, mock_provider):
|
||||
"""Test that mock provider has model_name property."""
|
||||
assert hasattr(mock_provider, "model_name")
|
||||
assert isinstance(mock_provider.model_name, str)
|
||||
|
||||
def test_mock_provider_has_generate_method(self, mock_provider):
|
||||
"""Test that mock provider has generate method."""
|
||||
assert hasattr(mock_provider, "generate")
|
||||
assert callable(mock_provider.generate)
|
||||
|
||||
|
||||
class TestStreamedMessageProtocol:
|
||||
"""Tests for StreamedMessage protocol."""
|
||||
|
||||
def test_protocol_is_runtime_checkable(self):
|
||||
"""Test that StreamedMessage is runtime checkable."""
|
||||
assert hasattr(StreamedMessage, "__protocol_attrs__")
|
||||
|
||||
def test_mock_streamed_message_implements_protocol(self):
|
||||
"""Test that MockStreamedMessage implements StreamedMessage."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert isinstance(stream, StreamedMessage)
|
||||
|
||||
def test_streamed_message_has_id_property(self):
|
||||
"""Test that streamed message has id property."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert hasattr(stream, "id")
|
||||
assert stream.id == "mock-msg-123"
|
||||
|
||||
def test_streamed_message_has_usage_property(self):
|
||||
"""Test that streamed message has usage property."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert hasattr(stream, "usage")
|
||||
assert stream.usage is not None
|
||||
assert isinstance(stream.usage, TokenUsage)
|
||||
|
||||
def test_streamed_message_is_async_iterable(self):
|
||||
"""Test that streamed message is async iterable."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert hasattr(stream, "__aiter__")
|
||||
210
agentlite/tests/unit/test_tool.py
Normal file
210
agentlite/tests/unit/test_tool.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Unit tests for tool decorator and CallableTool.
|
||||
|
||||
This module tests the @tool() decorator and related tool functionality.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite.tool import tool, CallableTool, ToolOk, ToolError
|
||||
|
||||
|
||||
class TestToolDecorator:
|
||||
"""Tests for the @tool() decorator."""
|
||||
|
||||
def test_tool_decorator_basic(self):
|
||||
"""Test basic tool decorator functionality."""
|
||||
|
||||
@tool()
|
||||
async def add(a: float, b: float) -> float:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert isinstance(add, CallableTool)
|
||||
assert add.name == "add"
|
||||
assert add.description == "Add two numbers."
|
||||
assert add.parameters["type"] == "object"
|
||||
assert "a" in add.parameters["properties"]
|
||||
assert "b" in add.parameters["properties"]
|
||||
assert add.parameters["properties"]["a"]["type"] == "number"
|
||||
assert add.parameters["properties"]["b"]["type"] == "number"
|
||||
assert add.parameters["required"] == ["a", "b"]
|
||||
|
||||
def test_tool_decorator_with_default_params(self):
|
||||
"""Test tool decorator with default parameters."""
|
||||
|
||||
@tool()
|
||||
async def greet(name: str, greeting: str = "Hello") -> str:
|
||||
"""Greet someone."""
|
||||
return f"{greeting}, {name}!"
|
||||
|
||||
assert greet.name == "greet"
|
||||
assert "name" in greet.parameters["required"]
|
||||
assert "greeting" not in greet.parameters["required"]
|
||||
|
||||
def test_tool_decorator_custom_name(self):
|
||||
"""Test tool decorator with custom name."""
|
||||
|
||||
@tool(name="custom_add")
|
||||
async def add(a: float, b: float) -> float:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert add.name == "custom_add"
|
||||
|
||||
def test_tool_decorator_custom_description(self):
|
||||
"""Test tool decorator with custom description."""
|
||||
|
||||
@tool(description="Custom description")
|
||||
async def add(a: float, b: float) -> float:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
assert add.description == "Custom description"
|
||||
|
||||
def test_tool_decorator_no_docstring(self):
|
||||
"""Test tool decorator with no docstring."""
|
||||
|
||||
@tool()
|
||||
async def no_doc(a: float) -> float:
|
||||
return a
|
||||
|
||||
assert no_doc.description == "No description provided"
|
||||
|
||||
def test_tool_decorator_param_types(self):
|
||||
"""Test tool decorator with various parameter types."""
|
||||
|
||||
@tool()
|
||||
async def multi_types(
|
||||
s: str,
|
||||
i: int,
|
||||
f: float,
|
||||
b: bool,
|
||||
) -> dict:
|
||||
"""Multiple types."""
|
||||
return {"s": s, "i": i, "f": f, "b": b}
|
||||
|
||||
props = multi_types.parameters["properties"]
|
||||
assert props["s"]["type"] == "string"
|
||||
assert props["i"]["type"] == "integer"
|
||||
assert props["f"]["type"] == "number"
|
||||
assert props["b"]["type"] == "boolean"
|
||||
|
||||
def test_tool_decorator_no_type_hints(self):
|
||||
"""Test tool decorator with no type hints."""
|
||||
|
||||
@tool()
|
||||
async def no_types(param) -> str:
|
||||
"""No type hints."""
|
||||
return str(param)
|
||||
|
||||
assert no_types.parameters["properties"]["param"]["type"] == "string"
|
||||
|
||||
|
||||
class TestToolDecoratorExecution:
|
||||
"""Tests for tool decorator execution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_success(self):
|
||||
"""Test successful tool execution."""
|
||||
|
||||
@tool()
|
||||
async def add(a: float, b: float) -> float:
|
||||
"""Add two numbers."""
|
||||
return a + b
|
||||
|
||||
result = await add(1.0, 2.0)
|
||||
assert isinstance(result, ToolOk)
|
||||
assert result.output == "3.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_error(self):
|
||||
"""Test tool execution with error."""
|
||||
|
||||
@tool()
|
||||
async def divide(a: float, b: float) -> float:
|
||||
"""Divide two numbers."""
|
||||
return a / b
|
||||
|
||||
result = await divide(1.0, 0.0)
|
||||
assert isinstance(result, ToolError)
|
||||
assert "division by zero" in result.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_with_kwargs(self):
|
||||
"""Test tool execution with keyword arguments."""
|
||||
|
||||
@tool()
|
||||
async def greet(name: str, greeting: str = "Hello") -> str:
|
||||
"""Greet someone."""
|
||||
return f"{greeting}, {name}!"
|
||||
|
||||
result = await greet(name="World", greeting="Hi")
|
||||
assert isinstance(result, ToolOk)
|
||||
assert result.output == "Hi, World!"
|
||||
|
||||
|
||||
class TestToolDecoratorMemorixBug:
|
||||
"""Tests for the specific bug reported by Memorix project."""
|
||||
|
||||
def test_tool_decorator_memorix_case(self):
|
||||
"""Test the exact case from Memorix bug report.
|
||||
|
||||
This test verifies that the @tool() decorator works correctly
|
||||
with async functions that have string and float parameters.
|
||||
"""
|
||||
|
||||
@tool()
|
||||
async def add_memory(content: str, importance: float = 0.5) -> dict:
|
||||
"""存储记忆"""
|
||||
return {"status": "ok"}
|
||||
|
||||
assert isinstance(add_memory, CallableTool)
|
||||
assert add_memory.name == "add_memory"
|
||||
assert add_memory.description == "存储记忆"
|
||||
|
||||
# Check parameters schema
|
||||
params = add_memory.parameters
|
||||
assert params["type"] == "object"
|
||||
assert "content" in params["properties"]
|
||||
assert "importance" in params["properties"]
|
||||
assert params["properties"]["content"]["type"] == "string"
|
||||
assert params["properties"]["importance"]["type"] == "number"
|
||||
|
||||
# content is required (no default), importance is optional
|
||||
assert "content" in params["required"]
|
||||
assert "importance" not in params["required"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_decorator_memorix_execution(self):
|
||||
"""Test execution of the Memorix case."""
|
||||
|
||||
@tool()
|
||||
async def add_memory(content: str, importance: float = 0.5) -> dict:
|
||||
"""存储记忆"""
|
||||
return {"status": "ok", "content": content, "importance": importance}
|
||||
|
||||
result = await add_memory("test content", 0.8)
|
||||
assert isinstance(result, ToolOk)
|
||||
assert "ok" in result.output
|
||||
|
||||
def test_tool_decorator_can_be_used_in_agent(self):
|
||||
"""Test that decorated tools can be used with Agent.
|
||||
|
||||
This is an integration-style test to ensure the decorated tool
|
||||
has all required attributes for Agent usage.
|
||||
"""
|
||||
from agentlite import Agent, OpenAIProvider
|
||||
|
||||
@tool()
|
||||
async def add_memory(content: str, importance: float = 0.5) -> dict:
|
||||
"""存储记忆"""
|
||||
return {"status": "ok"}
|
||||
|
||||
# Verify the tool has the base property required by Agent
|
||||
assert hasattr(add_memory, "base")
|
||||
base_tool = add_memory.base
|
||||
assert base_tool.name == "add_memory"
|
||||
assert base_tool.description == "存储记忆"
|
||||
assert base_tool.parameters == add_memory.parameters
|
||||
Reference in New Issue
Block a user