feat: add a subagent frame
This commit is contained in:
168
agentlite/tests/unit/test_provider.py
Normal file
168
agentlite/tests/unit/test_provider.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Unit tests for provider protocol and exceptions.
|
||||
|
||||
This module tests the ChatProvider protocol, StreamedMessage protocol,
|
||||
and all exception types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite.provider import (
|
||||
TokenUsage,
|
||||
ChatProviderError,
|
||||
APIConnectionError,
|
||||
APITimeoutError,
|
||||
APIStatusError,
|
||||
APIEmptyResponseError,
|
||||
ChatProvider,
|
||||
StreamedMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenUsage:
|
||||
"""Tests for TokenUsage."""
|
||||
|
||||
def test_token_usage_creation(self):
|
||||
"""Test TokenUsage creation."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50)
|
||||
assert usage.input_tokens == 100
|
||||
assert usage.output_tokens == 50
|
||||
assert usage.cached_tokens == 0 # Default
|
||||
|
||||
def test_token_usage_with_cached(self):
|
||||
"""Test TokenUsage with cached tokens."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
|
||||
assert usage.cached_tokens == 20
|
||||
|
||||
def test_token_usage_total(self):
|
||||
"""Test total token calculation."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50)
|
||||
assert usage.total == 150
|
||||
|
||||
def test_token_usage_total_with_cached(self):
|
||||
"""Test total with cached tokens (not included in total)."""
|
||||
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
|
||||
# Total is input + output, cached is tracked separately
|
||||
assert usage.total == 150
|
||||
|
||||
|
||||
class TestChatProviderError:
|
||||
"""Tests for ChatProviderError hierarchy."""
|
||||
|
||||
def test_base_error_creation(self):
|
||||
"""Test base ChatProviderError creation."""
|
||||
error = ChatProviderError("Something went wrong")
|
||||
assert error.message == "Something went wrong"
|
||||
assert str(error) == "Something went wrong"
|
||||
|
||||
def test_api_connection_error(self):
|
||||
"""Test APIConnectionError creation."""
|
||||
error = APIConnectionError("Connection failed")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.message == "Connection failed"
|
||||
|
||||
def test_api_timeout_error(self):
|
||||
"""Test APITimeoutError creation."""
|
||||
error = APITimeoutError("Request timed out")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.message == "Request timed out"
|
||||
|
||||
def test_api_status_error(self):
|
||||
"""Test APIStatusError creation."""
|
||||
error = APIStatusError(429, "Rate limit exceeded")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.status_code == 429
|
||||
assert error.message == "Rate limit exceeded"
|
||||
|
||||
def test_api_status_error_different_codes(self):
|
||||
"""Test APIStatusError with different status codes."""
|
||||
codes = [400, 401, 403, 404, 429, 500, 502, 503]
|
||||
for code in codes:
|
||||
error = APIStatusError(code, f"Error {code}")
|
||||
assert error.status_code == code
|
||||
|
||||
def test_api_empty_response_error(self):
|
||||
"""Test APIEmptyResponseError creation."""
|
||||
error = APIEmptyResponseError("Empty response from API")
|
||||
assert isinstance(error, ChatProviderError)
|
||||
assert error.message == "Empty response from API"
|
||||
|
||||
def test_exception_hierarchy(self):
|
||||
"""Test that all exceptions inherit from ChatProviderError."""
|
||||
errors = [
|
||||
APIConnectionError("test"),
|
||||
APITimeoutError("test"),
|
||||
APIStatusError(500, "test"),
|
||||
APIEmptyResponseError("test"),
|
||||
]
|
||||
for error in errors:
|
||||
assert isinstance(error, ChatProviderError)
|
||||
|
||||
|
||||
class TestChatProviderProtocol:
|
||||
"""Tests for ChatProvider protocol."""
|
||||
|
||||
def test_protocol_is_runtime_checkable(self):
|
||||
"""Test that ChatProvider is runtime checkable."""
|
||||
# ChatProvider should have @runtime_checkable
|
||||
from typing import runtime_checkable
|
||||
|
||||
assert hasattr(ChatProvider, "__protocol_attrs__")
|
||||
|
||||
def test_mock_provider_implements_protocol(self, mock_provider):
|
||||
"""Test that MockProvider implements ChatProvider."""
|
||||
assert isinstance(mock_provider, ChatProvider)
|
||||
|
||||
def test_mock_provider_has_model_name(self, mock_provider):
|
||||
"""Test that mock provider has model_name property."""
|
||||
assert hasattr(mock_provider, "model_name")
|
||||
assert isinstance(mock_provider.model_name, str)
|
||||
|
||||
def test_mock_provider_has_generate_method(self, mock_provider):
|
||||
"""Test that mock provider has generate method."""
|
||||
assert hasattr(mock_provider, "generate")
|
||||
assert callable(mock_provider.generate)
|
||||
|
||||
|
||||
class TestStreamedMessageProtocol:
|
||||
"""Tests for StreamedMessage protocol."""
|
||||
|
||||
def test_protocol_is_runtime_checkable(self):
|
||||
"""Test that StreamedMessage is runtime checkable."""
|
||||
assert hasattr(StreamedMessage, "__protocol_attrs__")
|
||||
|
||||
def test_mock_streamed_message_implements_protocol(self):
|
||||
"""Test that MockStreamedMessage implements StreamedMessage."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert isinstance(stream, StreamedMessage)
|
||||
|
||||
def test_streamed_message_has_id_property(self):
|
||||
"""Test that streamed message has id property."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert hasattr(stream, "id")
|
||||
assert stream.id == "mock-msg-123"
|
||||
|
||||
def test_streamed_message_has_usage_property(self):
|
||||
"""Test that streamed message has usage property."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert hasattr(stream, "usage")
|
||||
assert stream.usage is not None
|
||||
assert isinstance(stream.usage, TokenUsage)
|
||||
|
||||
def test_streamed_message_is_async_iterable(self):
|
||||
"""Test that streamed message is async iterable."""
|
||||
from tests.conftest import MockStreamedMessage
|
||||
from agentlite import TextPart
|
||||
|
||||
stream = MockStreamedMessage([TextPart(text="Hello")])
|
||||
assert hasattr(stream, "__aiter__")
|
||||
Reference in New Issue
Block a user