feat: add a subagent frame

This commit is contained in:
tcmofashi
2026-04-03 22:15:53 +08:00
parent ce580d1f8b
commit 185361f2c3
72 changed files with 13062 additions and 0 deletions

View File

View File

@@ -0,0 +1,330 @@
"""Unit tests for configuration models.
This module tests all Pydantic configuration models including
ProviderConfig, ModelConfig, ToolConfig, and AgentConfig.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from agentlite import ProviderConfig, ModelConfig, AgentConfig
class TestProviderConfig:
"""Tests for ProviderConfig."""
def test_provider_config_valid(self):
"""Test valid ProviderConfig creation."""
config = ProviderConfig(
type="openai",
base_url="https://api.openai.com/v1",
api_key="sk-test123",
)
assert config.type == "openai"
assert config.base_url == "https://api.openai.com/v1"
assert config.api_key.get_secret_value() == "sk-test123"
def test_provider_config_default_type(self):
"""Test ProviderConfig with default type."""
config = ProviderConfig(
base_url="https://api.openai.com/v1",
api_key="sk-test",
)
assert config.type == "openai"
def test_provider_config_default_url(self):
"""Test ProviderConfig with default base_url."""
config = ProviderConfig(
api_key="sk-test",
)
assert config.base_url == "https://api.openai.com/v1"
def test_provider_config_invalid_url_http(self):
"""Test ProviderConfig with invalid URL scheme."""
with pytest.raises(ValidationError) as exc_info:
ProviderConfig(
type="openai",
base_url="ftp://invalid.com",
api_key="sk-test",
)
assert "base_url must start with http:// or https://" in str(exc_info.value)
def test_provider_config_invalid_url_no_scheme(self):
"""Test ProviderConfig with URL without scheme."""
with pytest.raises(ValidationError):
ProviderConfig(
base_url="api.openai.com/v1",
api_key="sk-test",
)
def test_provider_config_custom_headers(self):
"""Test ProviderConfig with custom headers."""
config = ProviderConfig(
api_key="sk-test",
headers={"X-Custom": "value"},
)
assert config.headers == {"X-Custom": "value"}
def test_provider_config_default_headers(self):
"""Test ProviderConfig default headers."""
config = ProviderConfig(api_key="sk-test")
assert config.headers == {}
def test_provider_config_timeout(self):
"""Test ProviderConfig timeout."""
config = ProviderConfig(
api_key="sk-test",
timeout=30.0,
)
assert config.timeout == 30.0
def test_provider_config_default_timeout(self):
"""Test ProviderConfig default timeout."""
config = ProviderConfig(api_key="sk-test")
assert config.timeout == 60.0
def test_provider_config_api_key_is_secret_str(self):
"""Test that api_key is stored as SecretStr."""
config = ProviderConfig(api_key="sk-secret")
# SecretStr should not expose value in repr/str
assert "sk-secret" not in str(config.api_key)
# But can get value explicitly
assert config.api_key.get_secret_value() == "sk-secret"
class TestModelConfig:
"""Tests for ModelConfig."""
def test_model_config_valid(self):
"""Test valid ModelConfig creation."""
config = ModelConfig(
provider="openai",
model="gpt-4",
)
assert config.provider == "openai"
assert config.model == "gpt-4"
def test_model_config_with_all_fields(self):
"""Test ModelConfig with all optional fields."""
config = ModelConfig(
provider="openai",
model="gpt-4",
max_tokens=1000,
temperature=0.7,
top_p=0.9,
capabilities={"streaming", "tool_calling"},
)
assert config.max_tokens == 1000
assert config.temperature == 0.7
assert config.top_p == 0.9
assert config.capabilities == {"streaming", "tool_calling"}
def test_model_config_empty_provider(self):
"""Test ModelConfig with empty provider."""
with pytest.raises(ValidationError) as exc_info:
ModelConfig(
provider="",
model="gpt-4",
)
assert "provider must not be empty" in str(exc_info.value)
def test_model_config_temperature_bounds(self):
"""Test ModelConfig temperature validation bounds."""
# Valid: 0.0
config = ModelConfig(provider="openai", model="gpt-4", temperature=0.0)
assert config.temperature == 0.0
# Valid: 2.0
config = ModelConfig(provider="openai", model="gpt-4", temperature=2.0)
assert config.temperature == 2.0
# Invalid: < 0
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", temperature=-0.1)
# Invalid: > 2
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", temperature=2.1)
def test_model_config_top_p_bounds(self):
"""Test ModelConfig top_p validation bounds."""
# Valid: 0.0
config = ModelConfig(provider="openai", model="gpt-4", top_p=0.0)
assert config.top_p == 0.0
# Valid: 1.0
config = ModelConfig(provider="openai", model="gpt-4", top_p=1.0)
assert config.top_p == 1.0
# Invalid: < 0
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", top_p=-0.1)
# Invalid: > 1
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", top_p=1.1)
def test_model_config_default_capabilities(self):
"""Test ModelConfig default capabilities."""
config = ModelConfig(provider="openai", model="gpt-4")
assert config.capabilities == set()
class TestAgentConfig:
"""Tests for AgentConfig."""
def test_agent_config_minimal(self):
"""Test AgentConfig with minimal required fields."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
)
assert config.name == "agent"
assert config.system_prompt == "You are a helpful assistant."
assert config.default_model == "default"
def test_agent_config_full(self):
"""Test AgentConfig with all fields."""
config = AgentConfig(
name="my_agent",
system_prompt="Custom system prompt",
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
max_history=50,
)
assert config.name == "my_agent"
assert config.system_prompt == "Custom system prompt"
assert config.default_model == "gpt4"
assert config.max_history == 50
def test_agent_config_missing_default_model(self):
"""Test AgentConfig with non-existent default_model."""
with pytest.raises(ValidationError) as exc_info:
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="nonexistent",
)
assert "not found in models" in str(exc_info.value)
def test_agent_config_unknown_provider(self):
"""Test AgentConfig with model referencing unknown provider."""
with pytest.raises(ValidationError) as exc_info:
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="unknown", model="gpt-4")},
)
assert "unknown provider" in str(exc_info.value)
def test_agent_config_get_provider_config(self):
"""Test get_provider_config method."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
provider_config = config.get_provider_config("gpt4")
assert provider_config.api_key.get_secret_value() == "sk-test"
def test_agent_config_get_provider_config_default(self):
"""Test get_provider_config with default model."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
provider_config = config.get_provider_config()
assert provider_config.api_key.get_secret_value() == "sk-test"
def test_agent_config_get_provider_config_not_found(self):
"""Test get_provider_config with non-existent model."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
)
with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
config.get_provider_config("nonexistent")
def test_agent_config_get_model_config(self):
"""Test get_model_config method."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
model_config = config.get_model_config("gpt4")
assert model_config.model == "gpt-4"
def test_agent_config_get_model_config_default(self):
"""Test get_model_config with default."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
model_config = config.get_model_config()
assert model_config.model == "gpt-4"
def test_agent_config_get_model_config_not_found(self):
"""Test get_model_config with non-existent model."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
)
with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
config.get_model_config("nonexistent")
def test_agent_config_multiple_providers(self):
"""Test AgentConfig with multiple providers."""
config = AgentConfig(
providers={
"openai": ProviderConfig(api_key="sk-openai"),
"anthropic": ProviderConfig(
type="anthropic",
base_url="https://api.anthropic.com/v1",
api_key="sk-anthropic",
),
},
models={
"default": ModelConfig(provider="openai", model="gpt-4"),
"claude": ModelConfig(provider="anthropic", model="claude-3"),
},
)
assert len(config.providers) == 2
assert len(config.models) == 2
def test_agent_config_max_history_validation(self):
"""Test max_history validation."""
# Valid: min=1
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
max_history=1,
)
assert config.max_history == 1
# Invalid: 0
with pytest.raises(ValidationError):
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
max_history=0,
)
# Invalid: negative
with pytest.raises(ValidationError):
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
max_history=-1,
)

View File

@@ -0,0 +1,297 @@
"""Unit tests for message types.
This module tests all message-related types including ContentPart,
Message, ToolCall, and their various subclasses.
"""
from __future__ import annotations
import pytest
from agentlite import (
ContentPart,
Message,
TextPart,
ImageURLPart,
AudioURLPart,
ToolCall,
ToolCallPart,
)
class TestContentPart:
"""Tests for ContentPart base class and registry."""
def test_content_part_registry_auto_registers_subclasses(self):
"""Test that ContentPart subclasses are auto-registered."""
# All defined subclasses should be in registry
assert "text" in ContentPart._ContentPart__content_part_registry
assert "image_url" in ContentPart._ContentPart__content_part_registry
assert "audio_url" in ContentPart._ContentPart__content_part_registry
def test_text_part_creation(self):
"""Test basic TextPart creation."""
part = TextPart(text="Hello, world!")
assert part.type == "text"
assert part.text == "Hello, world!"
def test_text_part_model_dump(self):
"""Test TextPart serialization."""
part = TextPart(text="Hello")
dumped = part.model_dump()
assert dumped == {"type": "text", "text": "Hello"}
def test_text_part_merge_success(self):
"""Test successful text merge during streaming."""
part1 = TextPart(text="Hello ")
part2 = TextPart(text="world!")
result = part1.merge_in_place(part2)
assert result is True
assert part1.text == "Hello world!"
def test_text_part_merge_failure(self):
"""Test merge failure with incompatible types."""
text_part = TextPart(text="Hello")
# Try to merge with non-TextPart
result = text_part.merge_in_place("not a part")
assert result is False
assert text_part.text == "Hello" # Unchanged
class TestImageURLPart:
"""Tests for ImageURLPart."""
def test_image_url_part_creation(self):
"""Test ImageURLPart creation."""
part = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
assert part.type == "image_url"
assert part.image_url.url == "https://example.com/image.png"
def test_image_url_part_with_detail(self):
"""Test ImageURLPart with detail parameter."""
part = ImageURLPart(
image_url=ImageURLPart.ImageURL(url="https://example.com/image.png", detail="high")
)
assert part.image_url.detail == "high"
def test_image_url_part_default_detail(self):
"""Test ImageURLPart default detail is None."""
part = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
assert part.image_url.detail is None
class TestAudioURLPart:
"""Tests for AudioURLPart."""
def test_audio_url_part_creation(self):
"""Test AudioURLPart creation."""
part = AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3"))
assert part.type == "audio_url"
assert part.audio_url.url == "https://example.com/audio.mp3"
class TestToolCall:
"""Tests for ToolCall."""
def test_tool_call_creation(self):
"""Test ToolCall creation."""
call = ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
)
assert call.type == "function"
assert call.id == "call_123"
assert call.function.name == "add"
assert call.function.arguments == '{"a": 1, "b": 2}'
def test_tool_call_merge_with_part(self):
"""Test ToolCall merging with ToolCallPart."""
call = ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1')
)
part = ToolCallPart(arguments_part=', "b": 2}')
result = call.merge_in_place(part)
assert result is True
assert call.function.arguments == '{"a": 1, "b": 2}'
def test_tool_call_merge_failure(self):
"""Test ToolCall merge failure with incompatible types."""
call = ToolCall(id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}"))
result = call.merge_in_place("not a part")
assert result is False
class TestToolCallPart:
"""Tests for ToolCallPart."""
def test_tool_call_part_creation(self):
"""Test ToolCallPart creation."""
part = ToolCallPart(arguments_part='{"a": 1}')
assert part.arguments_part == '{"a": 1}'
def test_tool_call_part_none(self):
"""Test ToolCallPart with None arguments."""
part = ToolCallPart(arguments_part=None)
assert part.arguments_part is None
def test_tool_call_part_merge(self):
"""Test ToolCallPart merging."""
part1 = ToolCallPart(arguments_part='{"a":')
part2 = ToolCallPart(arguments_part=" 1}")
result = part1.merge_in_place(part2)
assert result is True
assert part1.arguments_part == '{"a": 1}'
def test_tool_call_part_merge_none(self):
"""Test ToolCallPart merge when self is None."""
part1 = ToolCallPart(arguments_part=None)
part2 = ToolCallPart(arguments_part='{"a": 1}')
result = part1.merge_in_place(part2)
assert result is True
assert part1.arguments_part == '{"a": 1}'
class TestMessage:
"""Tests for Message."""
def test_message_string_content_coercion(self):
"""Test that string content is coerced to TextPart."""
msg = Message(role="user", content="Hello!")
assert len(msg.content) == 1
assert isinstance(msg.content[0], TextPart)
assert msg.content[0].text == "Hello!"
def test_message_part_content(self):
"""Test Message with ContentPart content."""
part = TextPart(text="Hello!")
msg = Message(role="user", content=part)
assert len(msg.content) == 1
assert msg.content[0].text == "Hello!"
def test_message_list_content(self):
"""Test Message with list of ContentParts."""
parts = [TextPart(text="Hello"), TextPart(text=" world!")]
msg = Message(role="user", content=parts)
assert len(msg.content) == 2
def test_message_extract_text(self):
"""Test text extraction from message."""
msg = Message(role="user", content="Hello world!")
assert msg.extract_text() == "Hello world!"
def test_message_extract_text_with_separator(self):
"""Test text extraction with custom separator."""
parts = [TextPart(text="Hello"), TextPart(text="world!")]
msg = Message(role="user", content=parts)
assert msg.extract_text(sep=" ") == "Hello world!"
assert msg.extract_text(sep="-") == "Hello-world!"
def test_message_has_tool_calls_false(self):
"""Test has_tool_calls returns False when no tool calls."""
msg = Message(role="assistant", content="Hello!")
assert msg.has_tool_calls() is False
def test_message_has_tool_calls_true(self):
"""Test has_tool_calls returns True when tool calls present."""
tool_call = ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}")
)
msg = Message(role="assistant", content="Let me calculate that.", tool_calls=[tool_call])
assert msg.has_tool_calls() is True
def test_message_has_tool_calls_empty_list(self):
"""Test has_tool_calls with empty tool_calls list."""
msg = Message(role="assistant", content="Hello!", tool_calls=[])
assert msg.has_tool_calls() is False
def test_message_tool_response(self):
"""Test message with tool response."""
msg = Message(role="tool", content="Result: 42", tool_call_id="call_123")
assert msg.role == "tool"
assert msg.tool_call_id == "call_123"
def test_message_serialization(self):
"""Test Message serialization with model_dump."""
msg = Message(role="user", content="Hello!")
dumped = msg.model_dump()
assert dumped["role"] == "user"
assert "content" in dumped
def test_message_all_roles(self):
"""Test Message creation with all valid roles."""
for role in ["system", "user", "assistant", "tool"]:
msg = Message(role=role, content="Test")
assert msg.role == role
class TestPolymorphicContentPart:
"""Tests for polymorphic ContentPart validation."""
def test_polymorphic_validation_text(self):
"""Test that text type validates to TextPart."""
data = {"type": "text", "text": "Hello"}
part = ContentPart.model_validate(data)
assert isinstance(part, TextPart)
assert part.text == "Hello"
def test_polymorphic_validation_image(self):
"""Test that image_url type validates to ImageURLPart."""
data = {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}
part = ContentPart.model_validate(data)
assert isinstance(part, ImageURLPart)
assert part.image_url.url == "https://example.com/image.png"
def test_polymorphic_validation_unknown_type(self):
"""Test validation with unknown type raises error."""
data = {"type": "unknown_type", "content": "test"}
with pytest.raises(ValueError, match="Unknown content part type"):
ContentPart.model_validate(data)
def test_polymorphic_validation_no_type(self):
"""Test validation without type raises error."""
data = {"content": "test"}
with pytest.raises(ValueError):
ContentPart.model_validate(data)
class TestMessageEdgeCases:
"""Tests for edge cases in Message handling."""
def test_empty_string_content(self):
"""Test Message with empty string content."""
msg = Message(role="user", content="")
assert msg.content[0].text == ""
def test_message_with_name(self):
"""Test Message with name field."""
msg = Message(role="user", content="Hello", name="user1")
assert msg.name == "user1"
def test_message_history_isolation(self):
"""Test that history modifications don't affect original."""
msg = Message(role="user", content="Hello")
# Modify the content list
msg.content.append(TextPart(text="Extra"))
# Original should be modified (it's the same object)
assert len(msg.content) == 2

View File

@@ -0,0 +1,168 @@
"""Unit tests for provider protocol and exceptions.
This module tests the ChatProvider protocol, StreamedMessage protocol,
and all exception types.
"""
from __future__ import annotations
import pytest
from agentlite.provider import (
TokenUsage,
ChatProviderError,
APIConnectionError,
APITimeoutError,
APIStatusError,
APIEmptyResponseError,
ChatProvider,
StreamedMessage,
)
class TestTokenUsage:
"""Tests for TokenUsage."""
def test_token_usage_creation(self):
"""Test TokenUsage creation."""
usage = TokenUsage(input_tokens=100, output_tokens=50)
assert usage.input_tokens == 100
assert usage.output_tokens == 50
assert usage.cached_tokens == 0 # Default
def test_token_usage_with_cached(self):
"""Test TokenUsage with cached tokens."""
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
assert usage.cached_tokens == 20
def test_token_usage_total(self):
"""Test total token calculation."""
usage = TokenUsage(input_tokens=100, output_tokens=50)
assert usage.total == 150
def test_token_usage_total_with_cached(self):
"""Test total with cached tokens (not included in total)."""
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
# Total is input + output, cached is tracked separately
assert usage.total == 150
class TestChatProviderError:
"""Tests for ChatProviderError hierarchy."""
def test_base_error_creation(self):
"""Test base ChatProviderError creation."""
error = ChatProviderError("Something went wrong")
assert error.message == "Something went wrong"
assert str(error) == "Something went wrong"
def test_api_connection_error(self):
"""Test APIConnectionError creation."""
error = APIConnectionError("Connection failed")
assert isinstance(error, ChatProviderError)
assert error.message == "Connection failed"
def test_api_timeout_error(self):
"""Test APITimeoutError creation."""
error = APITimeoutError("Request timed out")
assert isinstance(error, ChatProviderError)
assert error.message == "Request timed out"
def test_api_status_error(self):
"""Test APIStatusError creation."""
error = APIStatusError(429, "Rate limit exceeded")
assert isinstance(error, ChatProviderError)
assert error.status_code == 429
assert error.message == "Rate limit exceeded"
def test_api_status_error_different_codes(self):
"""Test APIStatusError with different status codes."""
codes = [400, 401, 403, 404, 429, 500, 502, 503]
for code in codes:
error = APIStatusError(code, f"Error {code}")
assert error.status_code == code
def test_api_empty_response_error(self):
"""Test APIEmptyResponseError creation."""
error = APIEmptyResponseError("Empty response from API")
assert isinstance(error, ChatProviderError)
assert error.message == "Empty response from API"
def test_exception_hierarchy(self):
"""Test that all exceptions inherit from ChatProviderError."""
errors = [
APIConnectionError("test"),
APITimeoutError("test"),
APIStatusError(500, "test"),
APIEmptyResponseError("test"),
]
for error in errors:
assert isinstance(error, ChatProviderError)
class TestChatProviderProtocol:
"""Tests for ChatProvider protocol."""
def test_protocol_is_runtime_checkable(self):
"""Test that ChatProvider is runtime checkable."""
# ChatProvider should have @runtime_checkable
from typing import runtime_checkable
assert hasattr(ChatProvider, "__protocol_attrs__")
def test_mock_provider_implements_protocol(self, mock_provider):
"""Test that MockProvider implements ChatProvider."""
assert isinstance(mock_provider, ChatProvider)
def test_mock_provider_has_model_name(self, mock_provider):
"""Test that mock provider has model_name property."""
assert hasattr(mock_provider, "model_name")
assert isinstance(mock_provider.model_name, str)
def test_mock_provider_has_generate_method(self, mock_provider):
"""Test that mock provider has generate method."""
assert hasattr(mock_provider, "generate")
assert callable(mock_provider.generate)
class TestStreamedMessageProtocol:
"""Tests for StreamedMessage protocol."""
def test_protocol_is_runtime_checkable(self):
"""Test that StreamedMessage is runtime checkable."""
assert hasattr(StreamedMessage, "__protocol_attrs__")
def test_mock_streamed_message_implements_protocol(self):
"""Test that MockStreamedMessage implements StreamedMessage."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert isinstance(stream, StreamedMessage)
def test_streamed_message_has_id_property(self):
"""Test that streamed message has id property."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert hasattr(stream, "id")
assert stream.id == "mock-msg-123"
def test_streamed_message_has_usage_property(self):
"""Test that streamed message has usage property."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert hasattr(stream, "usage")
assert stream.usage is not None
assert isinstance(stream.usage, TokenUsage)
def test_streamed_message_is_async_iterable(self):
"""Test that streamed message is async iterable."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert hasattr(stream, "__aiter__")

View File

@@ -0,0 +1,210 @@
"""Unit tests for tool decorator and CallableTool.
This module tests the @tool() decorator and related tool functionality.
"""
from __future__ import annotations
import pytest
from agentlite.tool import tool, CallableTool, ToolOk, ToolError
class TestToolDecorator:
"""Tests for the @tool() decorator."""
def test_tool_decorator_basic(self):
"""Test basic tool decorator functionality."""
@tool()
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
assert isinstance(add, CallableTool)
assert add.name == "add"
assert add.description == "Add two numbers."
assert add.parameters["type"] == "object"
assert "a" in add.parameters["properties"]
assert "b" in add.parameters["properties"]
assert add.parameters["properties"]["a"]["type"] == "number"
assert add.parameters["properties"]["b"]["type"] == "number"
assert add.parameters["required"] == ["a", "b"]
def test_tool_decorator_with_default_params(self):
"""Test tool decorator with default parameters."""
@tool()
async def greet(name: str, greeting: str = "Hello") -> str:
"""Greet someone."""
return f"{greeting}, {name}!"
assert greet.name == "greet"
assert "name" in greet.parameters["required"]
assert "greeting" not in greet.parameters["required"]
def test_tool_decorator_custom_name(self):
"""Test tool decorator with custom name."""
@tool(name="custom_add")
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
assert add.name == "custom_add"
def test_tool_decorator_custom_description(self):
"""Test tool decorator with custom description."""
@tool(description="Custom description")
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
assert add.description == "Custom description"
def test_tool_decorator_no_docstring(self):
"""Test tool decorator with no docstring."""
@tool()
async def no_doc(a: float) -> float:
return a
assert no_doc.description == "No description provided"
def test_tool_decorator_param_types(self):
"""Test tool decorator with various parameter types."""
@tool()
async def multi_types(
s: str,
i: int,
f: float,
b: bool,
) -> dict:
"""Multiple types."""
return {"s": s, "i": i, "f": f, "b": b}
props = multi_types.parameters["properties"]
assert props["s"]["type"] == "string"
assert props["i"]["type"] == "integer"
assert props["f"]["type"] == "number"
assert props["b"]["type"] == "boolean"
def test_tool_decorator_no_type_hints(self):
"""Test tool decorator with no type hints."""
@tool()
async def no_types(param) -> str:
"""No type hints."""
return str(param)
assert no_types.parameters["properties"]["param"]["type"] == "string"
class TestToolDecoratorExecution:
"""Tests for tool decorator execution."""
@pytest.mark.asyncio
async def test_tool_execution_success(self):
"""Test successful tool execution."""
@tool()
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
result = await add(1.0, 2.0)
assert isinstance(result, ToolOk)
assert result.output == "3.0"
@pytest.mark.asyncio
async def test_tool_execution_error(self):
"""Test tool execution with error."""
@tool()
async def divide(a: float, b: float) -> float:
"""Divide two numbers."""
return a / b
result = await divide(1.0, 0.0)
assert isinstance(result, ToolError)
assert "division by zero" in result.message
@pytest.mark.asyncio
async def test_tool_execution_with_kwargs(self):
"""Test tool execution with keyword arguments."""
@tool()
async def greet(name: str, greeting: str = "Hello") -> str:
"""Greet someone."""
return f"{greeting}, {name}!"
result = await greet(name="World", greeting="Hi")
assert isinstance(result, ToolOk)
assert result.output == "Hi, World!"
class TestToolDecoratorMemorixBug:
"""Tests for the specific bug reported by Memorix project."""
def test_tool_decorator_memorix_case(self):
"""Test the exact case from Memorix bug report.
This test verifies that the @tool() decorator works correctly
with async functions that have string and float parameters.
"""
@tool()
async def add_memory(content: str, importance: float = 0.5) -> dict:
"""存储记忆"""
return {"status": "ok"}
assert isinstance(add_memory, CallableTool)
assert add_memory.name == "add_memory"
assert add_memory.description == "存储记忆"
# Check parameters schema
params = add_memory.parameters
assert params["type"] == "object"
assert "content" in params["properties"]
assert "importance" in params["properties"]
assert params["properties"]["content"]["type"] == "string"
assert params["properties"]["importance"]["type"] == "number"
# content is required (no default), importance is optional
assert "content" in params["required"]
assert "importance" not in params["required"]
@pytest.mark.asyncio
async def test_tool_decorator_memorix_execution(self):
"""Test execution of the Memorix case."""
@tool()
async def add_memory(content: str, importance: float = 0.5) -> dict:
"""存储记忆"""
return {"status": "ok", "content": content, "importance": importance}
result = await add_memory("test content", 0.8)
assert isinstance(result, ToolOk)
assert "ok" in result.output
def test_tool_decorator_can_be_used_in_agent(self):
"""Test that decorated tools can be used with Agent.
This is an integration-style test to ensure the decorated tool
has all required attributes for Agent usage.
"""
from agentlite import Agent, OpenAIProvider
@tool()
async def add_memory(content: str, importance: float = 0.5) -> dict:
"""存储记忆"""
return {"status": "ok"}
# Verify the tool has the base property required by Agent
assert hasattr(add_memory, "base")
base_tool = add_memory.base
assert base_tool.name == "add_memory"
assert base_tool.description == "存储记忆"
assert base_tool.parameters == add_memory.parameters