167 lines
6.0 KiB
Python
167 lines
6.0 KiB
Python
"""Unit tests for provider protocol and exceptions.
|
|
|
|
This module tests the ChatProvider protocol, StreamedMessage protocol,
|
|
and all exception types.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
from agentlite.provider import (
|
|
TokenUsage,
|
|
ChatProviderError,
|
|
APIConnectionError,
|
|
APITimeoutError,
|
|
APIStatusError,
|
|
APIEmptyResponseError,
|
|
ChatProvider,
|
|
StreamedMessage,
|
|
)
|
|
|
|
|
|
class TestTokenUsage:
|
|
"""Tests for TokenUsage."""
|
|
|
|
def test_token_usage_creation(self):
|
|
"""Test TokenUsage creation."""
|
|
usage = TokenUsage(input_tokens=100, output_tokens=50)
|
|
assert usage.input_tokens == 100
|
|
assert usage.output_tokens == 50
|
|
assert usage.cached_tokens == 0 # Default
|
|
|
|
def test_token_usage_with_cached(self):
|
|
"""Test TokenUsage with cached tokens."""
|
|
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
|
|
assert usage.cached_tokens == 20
|
|
|
|
def test_token_usage_total(self):
|
|
"""Test total token calculation."""
|
|
usage = TokenUsage(input_tokens=100, output_tokens=50)
|
|
assert usage.total == 150
|
|
|
|
def test_token_usage_total_with_cached(self):
|
|
"""Test total with cached tokens (not included in total)."""
|
|
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
|
|
# Total is input + output, cached is tracked separately
|
|
assert usage.total == 150
|
|
|
|
|
|
class TestChatProviderError:
|
|
"""Tests for ChatProviderError hierarchy."""
|
|
|
|
def test_base_error_creation(self):
|
|
"""Test base ChatProviderError creation."""
|
|
error = ChatProviderError("Something went wrong")
|
|
assert error.message == "Something went wrong"
|
|
assert str(error) == "Something went wrong"
|
|
|
|
def test_api_connection_error(self):
|
|
"""Test APIConnectionError creation."""
|
|
error = APIConnectionError("Connection failed")
|
|
assert isinstance(error, ChatProviderError)
|
|
assert error.message == "Connection failed"
|
|
|
|
def test_api_timeout_error(self):
|
|
"""Test APITimeoutError creation."""
|
|
error = APITimeoutError("Request timed out")
|
|
assert isinstance(error, ChatProviderError)
|
|
assert error.message == "Request timed out"
|
|
|
|
def test_api_status_error(self):
|
|
"""Test APIStatusError creation."""
|
|
error = APIStatusError(429, "Rate limit exceeded")
|
|
assert isinstance(error, ChatProviderError)
|
|
assert error.status_code == 429
|
|
assert error.message == "Rate limit exceeded"
|
|
|
|
def test_api_status_error_different_codes(self):
|
|
"""Test APIStatusError with different status codes."""
|
|
codes = [400, 401, 403, 404, 429, 500, 502, 503]
|
|
for code in codes:
|
|
error = APIStatusError(code, f"Error {code}")
|
|
assert error.status_code == code
|
|
|
|
def test_api_empty_response_error(self):
|
|
"""Test APIEmptyResponseError creation."""
|
|
error = APIEmptyResponseError("Empty response from API")
|
|
assert isinstance(error, ChatProviderError)
|
|
assert error.message == "Empty response from API"
|
|
|
|
def test_exception_hierarchy(self):
|
|
"""Test that all exceptions inherit from ChatProviderError."""
|
|
errors = [
|
|
APIConnectionError("test"),
|
|
APITimeoutError("test"),
|
|
APIStatusError(500, "test"),
|
|
APIEmptyResponseError("test"),
|
|
]
|
|
for error in errors:
|
|
assert isinstance(error, ChatProviderError)
|
|
|
|
|
|
class TestChatProviderProtocol:
|
|
"""Tests for ChatProvider protocol."""
|
|
|
|
def test_protocol_is_runtime_checkable(self):
|
|
"""Test that ChatProvider is runtime checkable."""
|
|
# ChatProvider should have @runtime_checkable
|
|
|
|
assert hasattr(ChatProvider, "__protocol_attrs__")
|
|
|
|
def test_mock_provider_implements_protocol(self, mock_provider):
|
|
"""Test that MockProvider implements ChatProvider."""
|
|
assert isinstance(mock_provider, ChatProvider)
|
|
|
|
def test_mock_provider_has_model_name(self, mock_provider):
|
|
"""Test that mock provider has model_name property."""
|
|
assert hasattr(mock_provider, "model_name")
|
|
assert isinstance(mock_provider.model_name, str)
|
|
|
|
def test_mock_provider_has_generate_method(self, mock_provider):
|
|
"""Test that mock provider has generate method."""
|
|
assert hasattr(mock_provider, "generate")
|
|
assert callable(mock_provider.generate)
|
|
|
|
|
|
class TestStreamedMessageProtocol:
|
|
"""Tests for StreamedMessage protocol."""
|
|
|
|
def test_protocol_is_runtime_checkable(self):
|
|
"""Test that StreamedMessage is runtime checkable."""
|
|
assert hasattr(StreamedMessage, "__protocol_attrs__")
|
|
|
|
def test_mock_streamed_message_implements_protocol(self):
|
|
"""Test that MockStreamedMessage implements StreamedMessage."""
|
|
from tests.conftest import MockStreamedMessage
|
|
from agentlite import TextPart
|
|
|
|
stream = MockStreamedMessage([TextPart(text="Hello")])
|
|
assert isinstance(stream, StreamedMessage)
|
|
|
|
def test_streamed_message_has_id_property(self):
|
|
"""Test that streamed message has id property."""
|
|
from tests.conftest import MockStreamedMessage
|
|
from agentlite import TextPart
|
|
|
|
stream = MockStreamedMessage([TextPart(text="Hello")])
|
|
assert hasattr(stream, "id")
|
|
assert stream.id == "mock-msg-123"
|
|
|
|
def test_streamed_message_has_usage_property(self):
|
|
"""Test that streamed message has usage property."""
|
|
from tests.conftest import MockStreamedMessage
|
|
from agentlite import TextPart
|
|
|
|
stream = MockStreamedMessage([TextPart(text="Hello")])
|
|
assert hasattr(stream, "usage")
|
|
assert stream.usage is not None
|
|
assert isinstance(stream.usage, TokenUsage)
|
|
|
|
def test_streamed_message_is_async_iterable(self):
|
|
"""Test that streamed message is async iterable."""
|
|
from tests.conftest import MockStreamedMessage
|
|
from agentlite import TextPart
|
|
|
|
stream = MockStreamedMessage([TextPart(text="Hello")])
|
|
assert hasattr(stream, "__aiter__")
|