Files
mai-bot/agentlite/tests/unit/test_provider.py
2026-04-03 22:15:53 +08:00

169 lines
6.0 KiB
Python

"""Unit tests for provider protocol and exceptions.
This module tests the ChatProvider protocol, StreamedMessage protocol,
and all exception types.
"""
from __future__ import annotations
import pytest
from agentlite.provider import (
TokenUsage,
ChatProviderError,
APIConnectionError,
APITimeoutError,
APIStatusError,
APIEmptyResponseError,
ChatProvider,
StreamedMessage,
)
class TestTokenUsage:
    """Behavioral checks for the TokenUsage token-count record."""

    def test_token_usage_creation(self):
        """Input/output counts are stored as given; cached_tokens defaults to 0."""
        tokens = TokenUsage(input_tokens=100, output_tokens=50)
        observed = (tokens.input_tokens, tokens.output_tokens, tokens.cached_tokens)
        assert observed == (100, 50, 0)

    def test_token_usage_with_cached(self):
        """An explicit cached_tokens value is kept as given."""
        tokens = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
        assert tokens.cached_tokens == 20

    def test_token_usage_total(self):
        """``total`` equals input tokens plus output tokens."""
        counts = dict(input_tokens=100, output_tokens=50)
        tokens = TokenUsage(**counts)
        assert tokens.total == counts["input_tokens"] + counts["output_tokens"]

    def test_token_usage_total_with_cached(self):
        """Cached tokens are tracked separately and do not add to ``total``."""
        counts = dict(input_tokens=100, output_tokens=50)
        tokens = TokenUsage(**counts, cached_tokens=20)
        # Only input + output contribute; the 20 cached tokens are excluded.
        assert tokens.total == counts["input_tokens"] + counts["output_tokens"]
class TestChatProviderError:
    """Tests for the ChatProviderError exception hierarchy."""

    def test_base_error_creation(self):
        """Test base ChatProviderError creation."""
        error = ChatProviderError("Something went wrong")
        assert error.message == "Something went wrong"
        assert str(error) == "Something went wrong"

    def test_api_connection_error(self):
        """Test APIConnectionError creation."""
        error = APIConnectionError("Connection failed")
        assert isinstance(error, ChatProviderError)
        assert error.message == "Connection failed"

    def test_api_timeout_error(self):
        """Test APITimeoutError creation."""
        error = APITimeoutError("Request timed out")
        assert isinstance(error, ChatProviderError)
        assert error.message == "Request timed out"

    def test_api_status_error(self):
        """Test APIStatusError creation (status code first, then message)."""
        error = APIStatusError(429, "Rate limit exceeded")
        assert isinstance(error, ChatProviderError)
        assert error.status_code == 429
        assert error.message == "Rate limit exceeded"

    # Parametrized instead of looping inside one test so that every status
    # code is reported as its own case and one failure cannot mask the rest.
    @pytest.mark.parametrize("code", [400, 401, 403, 404, 429, 500, 502, 503])
    def test_api_status_error_different_codes(self, code):
        """Test APIStatusError keeps each status code and its message."""
        error = APIStatusError(code, f"Error {code}")
        assert error.status_code == code
        assert error.message == f"Error {code}"

    def test_api_empty_response_error(self):
        """Test APIEmptyResponseError creation."""
        error = APIEmptyResponseError("Empty response from API")
        assert isinstance(error, ChatProviderError)
        assert error.message == "Empty response from API"

    # Exception classes plus constructor args are parametrized (rather than
    # pre-built instances) so a broken constructor fails its own test case
    # instead of erroring out test collection.
    @pytest.mark.parametrize(
        ("error_cls", "args"),
        [
            (APIConnectionError, ("test",)),
            (APITimeoutError, ("test",)),
            (APIStatusError, (500, "test")),
            (APIEmptyResponseError, ("test",)),
        ],
    )
    def test_exception_hierarchy(self, error_cls, args):
        """Test that every provider exception inherits from ChatProviderError."""
        error = error_cls(*args)
        assert isinstance(error, ChatProviderError)
class TestChatProviderProtocol:
    """Tests for the ChatProvider protocol.

    The ``mock_provider`` fixture is not defined here; presumably it comes
    from ``tests/conftest.py`` — verify there.
    """

    def test_protocol_is_runtime_checkable(self):
        """Test that ChatProvider is decorated with @runtime_checkable."""
        # isinstance() against a typing.Protocol raises TypeError unless the
        # protocol is @runtime_checkable, so completing this call is the
        # actual proof. Checking for ``__protocol_attrs__`` is not: on
        # Python 3.12+ stdlib typing sets it on every Protocol (decorated or
        # not), and earlier stdlib versions do not set it at all.
        try:
            isinstance(object(), ChatProvider)
        except TypeError as exc:
            pytest.fail(f"ChatProvider is not @runtime_checkable: {exc}")

    def test_mock_provider_implements_protocol(self, mock_provider):
        """Test that MockProvider structurally satisfies ChatProvider."""
        assert isinstance(mock_provider, ChatProvider)

    def test_mock_provider_has_model_name(self, mock_provider):
        """Test that the mock provider exposes a string ``model_name``."""
        assert hasattr(mock_provider, "model_name")
        assert isinstance(mock_provider.model_name, str)

    def test_mock_provider_has_generate_method(self, mock_provider):
        """Test that the mock provider exposes a callable ``generate``."""
        assert hasattr(mock_provider, "generate")
        assert callable(mock_provider.generate)
class TestStreamedMessageProtocol:
    """Tests for the StreamedMessage protocol."""

    @pytest.fixture
    def stream(self):
        """Return a MockStreamedMessage wrapping a single "Hello" text part."""
        # Shared setup for the tests below, so each builds an identical stream
        # without repeating the imports and construction.
        from tests.conftest import MockStreamedMessage
        from agentlite import TextPart

        return MockStreamedMessage([TextPart(text="Hello")])

    def test_protocol_is_runtime_checkable(self):
        """Test that StreamedMessage is decorated with @runtime_checkable."""
        # isinstance() against a typing.Protocol raises TypeError unless the
        # protocol is @runtime_checkable, so completing this call is the
        # actual proof. Checking for ``__protocol_attrs__`` is not: on
        # Python 3.12+ stdlib typing sets it on every Protocol (decorated or
        # not), and earlier stdlib versions do not set it at all.
        try:
            isinstance(object(), StreamedMessage)
        except TypeError as exc:
            pytest.fail(f"StreamedMessage is not @runtime_checkable: {exc}")

    def test_mock_streamed_message_implements_protocol(self, stream):
        """Test that MockStreamedMessage structurally satisfies StreamedMessage."""
        assert isinstance(stream, StreamedMessage)

    def test_streamed_message_has_id_property(self, stream):
        """Test that the streamed message exposes the mock ``id``."""
        assert hasattr(stream, "id")
        assert stream.id == "mock-msg-123"

    def test_streamed_message_has_usage_property(self, stream):
        """Test that the streamed message exposes a TokenUsage via ``usage``."""
        assert hasattr(stream, "usage")
        # isinstance() also rules out None.
        assert isinstance(stream.usage, TokenUsage)

    def test_streamed_message_is_async_iterable(self, stream):
        """Test that the streamed message can be consumed with ``async for``."""
        from collections.abc import AsyncIterable

        # AsyncIterable looks for __aiter__ on the *type*, which is where
        # ``async for`` resolves it; hasattr() would also accept an instance
        # attribute that ``async for`` ignores.
        assert isinstance(stream, AsyncIterable)