feat: add a subagent frame
This commit is contained in:
0
agentlite/tests/__init__.py
Normal file
0
agentlite/tests/__init__.py
Normal file
331
agentlite/tests/conftest.py
Normal file
331
agentlite/tests/conftest.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Test configuration and shared fixtures for AgentLite tests.
|
||||
|
||||
This module provides pytest configuration and fixtures that are shared
|
||||
across all test modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import (
|
||||
Agent,
|
||||
ContentPart,
|
||||
Message,
|
||||
TextPart,
|
||||
ToolCall,
|
||||
ToolOk,
|
||||
ToolError,
|
||||
tool,
|
||||
)
|
||||
from agentlite.provider import ChatProvider, StreamedMessage, TokenUsage
|
||||
from agentlite.tool import Tool, ToolResult
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# pytest Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Register the custom markers used across the AgentLite test suite."""
    marker_definitions = (
        "unit: Unit tests",
        "integration: Integration tests",
        "scenario: Real-world scenario tests",
        "slow: Slow tests that may take time",
    )
    for definition in marker_definitions:
        config.addinivalue_line("markers", definition)
|
||||
|
||||
# =============================================================================
|
||||
# Mock Provider Implementation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MockStreamedMessage:
    """In-memory StreamedMessage stand-in that replays a fixed list of parts.

    The id and usage values are hard-coded so tests can assert against
    stable, known values.
    """

    def __init__(self, parts: list[ContentPart]):
        self._parts = parts
        # Fixed fake identifier/usage for deterministic assertions.
        self._id = "mock-msg-123"
        self._usage = TokenUsage(input_tokens=10, output_tokens=5)

    def __aiter__(self) -> AsyncIterator[ContentPart]:
        """Expose the stored parts as an async iterator."""
        return self._iter_parts()

    async def _iter_parts(self) -> AsyncIterator[ContentPart]:
        """Yield each stored content part in original order."""
        for item in self._parts:
            yield item

    @property
    def id(self) -> Optional[str]:
        """Stable fake message identifier."""
        return self._id

    @property
    def usage(self) -> Optional[TokenUsage]:
        """Stable fake token usage."""
        return self._usage
|
||||
class MockProvider:
    """Scripted ChatProvider replacement for exercising AgentLite offline.

    Responses are queued ahead of time and consumed FIFO by generate();
    every generate() call is also recorded so tests can inspect exactly
    what the agent sent (system prompt, tools, history).

    Example:
        provider = MockProvider()
        provider.add_text_response("Hello!")
        provider.add_tool_call("add", {"a": 1, "b": 2}, "3")

        agent = Agent(provider=provider)
        response = await agent.run("Hi")

        # Inspect what the agent sent to the provider
        assert len(provider.calls) == 1
    """

    def __init__(self):
        # Pending scripted replies, consumed front-to-back.
        self.responses: list[dict[str, Any]] = []
        # One record per generate() invocation.
        self.calls: list[dict[str, Any]] = []
        self.model = "mock-model"

    def add_text_response(self, text: str) -> None:
        """Queue a plain-text reply."""
        self.responses.append({"type": "text", "content": text})

    def add_text_responses(self, *texts: str) -> None:
        """Queue several plain-text replies at once, in order."""
        for entry in texts:
            self.add_text_response(entry)

    def add_tool_call(self, name: str, arguments: dict[str, Any], result: str) -> None:
        """Queue a tool-call reply; `result` is stored for bookkeeping."""
        queued = {"type": "tool_call", "name": name, "arguments": arguments, "result": result}
        self.responses.append(queued)

    def add_tool_calls(self, calls: list[dict[str, Any]]) -> None:
        """Queue several tool-call replies described as dicts."""
        for spec in calls:
            self.add_tool_call(spec["name"], spec["arguments"], spec.get("result", ""))

    def add_error(self, error: Exception) -> None:
        """Queue an exception to be raised by a later generate() call."""
        self.responses.append({"type": "error", "error": error})

    def clear_responses(self) -> None:
        """Drop all queued replies."""
        self.responses.clear()

    @property
    def model_name(self) -> str:
        """Name reported for the fake model."""
        return self.model

    async def generate(
        self,
        system_prompt: str,
        tools: Sequence[Tool],
        history: Sequence[Message],
    ) -> StreamedMessage:
        """Record the call, then replay the next queued response.

        Raises the queued exception for "error" entries; falls back to a
        generic canned reply when the queue is empty.
        """
        call_record = {
            "system_prompt": system_prompt,
            "tools": list(tools),
            "history": list(history),
        }
        self.calls.append(call_record)

        if not self.responses:
            return MockStreamedMessage([TextPart(text="Mock response")])

        queued = self.responses.pop(0)
        kind = queued["type"]

        if kind == "error":
            raise queued["error"]
        if kind == "text":
            return MockStreamedMessage([TextPart(text=queued["content"])])
        if kind == "tool_call":
            call_part = ToolCall(
                id="call_123",
                function=ToolCall.FunctionBody(
                    name=queued["name"], arguments=json.dumps(queued["arguments"])
                ),
            )
            return MockStreamedMessage([call_part])
        return MockStreamedMessage([TextPart(text="Unknown response type")])
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def mock_provider():
    """Fresh MockProvider with an empty response queue."""
    return MockProvider()


@pytest.fixture
def mock_provider_with_response():
    """MockProvider pre-loaded with a single text reply."""
    stub = MockProvider()
    stub.add_text_response("Hello!")
    return stub


@pytest.fixture
def mock_provider_with_sequence():
    """MockProvider pre-loaded with three sequential text replies."""
    stub = MockProvider()
    stub.add_text_responses("Response 1", "Response 2", "Response 3")
    return stub
|
||||
|
||||
# =============================================================================
|
||||
# Message Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def sample_text_message():
    """A plain user-role message."""
    return Message(role="user", content="Hello!")


@pytest.fixture
def sample_assistant_message():
    """A plain assistant-role message."""
    return Message(role="assistant", content="Hi there!")


@pytest.fixture
def sample_tool_call():
    """A ToolCall invoking `add` with fixed JSON arguments."""
    body = ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
    return ToolCall(id="call_123", function=body)


@pytest.fixture
def sample_tool_message():
    """A tool-role message answering the sample tool call."""
    return Message(role="tool", content="3", tool_call_id="call_123")
||||
|
||||
# =============================================================================
|
||||
# Tool Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def add_tool():
    """Fixture producing an async addition tool."""
    # Inner docstrings are kept verbatim — presumably @tool() surfaces
    # them as the tool description sent to the model (confirm in agentlite.tool).

    @tool()
    async def add(a: float, b: float) -> float:
        """Add two numbers."""
        return a + b

    return add


@pytest.fixture
def multiply_tool():
    """Fixture producing an async multiplication tool."""

    @tool()
    async def multiply(a: float, b: float) -> float:
        """Multiply two numbers."""
        return a * b

    return multiply


@pytest.fixture
def error_tool():
    """Fixture producing a tool that always raises ValueError."""

    @tool()
    async def error() -> str:
        """Always raises an error."""
        raise ValueError("Test error")

    return error


@pytest.fixture
def slow_tool():
    """Fixture producing a tool that sleeps before returning."""

    @tool()
    async def slow_operation(duration: float = 0.1) -> str:
        """Simulate a slow operation."""
        await asyncio.sleep(duration)
        return f"Completed after {duration}s"

    return slow_operation
||||
|
||||
# =============================================================================
|
||||
# Agent Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
async def simple_agent(mock_provider):
    """Agent backed by the bare mock provider, no tools."""
    return Agent(provider=mock_provider)


@pytest.fixture
async def agent_with_tools(mock_provider, add_tool):
    """Agent carrying the single `add` tool."""
    return Agent(provider=mock_provider, tools=[add_tool])


@pytest.fixture
async def agent_with_multiple_tools(mock_provider, add_tool, multiply_tool):
    """Agent carrying both arithmetic tools."""
    return Agent(provider=mock_provider, tools=[add_tool, multiply_tool])
|
||||
|
||||
# =============================================================================
|
||||
# Utility Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def sample_conversation():
    """A short four-message user/assistant exchange."""
    exchange = [
        ("user", "Hello!"),
        ("assistant", "Hi there! How can I help?"),
        ("user", "What is 2+2?"),
        ("assistant", "2+2=4"),
    ]
    return [Message(role=r, content=c) for r, c in exchange]


@pytest.fixture
def event_loop():
    """Provide a dedicated event loop per test and close it afterwards."""
    # NOTE(review): overriding the `event_loop` fixture is deprecated in
    # newer pytest-asyncio releases — confirm against the pinned version.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
||||
0
agentlite/tests/integration/__init__.py
Normal file
0
agentlite/tests/integration/__init__.py
Normal file
286
agentlite/tests/integration/test_agent.py
Normal file
286
agentlite/tests/integration/test_agent.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Integration tests for Agent class.
|
||||
|
||||
This module tests the Agent class with mocked providers to verify
|
||||
core functionality without making real API calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import Agent, TextPart
|
||||
|
||||
|
||||
@pytest.mark.integration
class TestAgentInitialization:
    """Construction-time behavior of Agent."""

    def test_agent_initialization(self, mock_provider):
        """Defaults: system prompt, iteration cap, empty history."""
        subject = Agent(provider=mock_provider)

        assert subject.provider is mock_provider
        assert subject.system_prompt == "You are a helpful assistant."
        assert subject.max_iterations == 80
        assert subject.history == []

    def test_agent_with_custom_system_prompt(self, mock_provider):
        """A caller-supplied system prompt is stored verbatim."""
        subject = Agent(provider=mock_provider, system_prompt="You are a specialized assistant.")

        assert subject.system_prompt == "You are a specialized assistant."

    def test_agent_with_tools(self, mock_provider, add_tool):
        """Tools passed at construction are registered by name."""
        subject = Agent(provider=mock_provider, tools=[add_tool])

        assert len(subject.tools.tools) == 1
        assert subject.tools.tools[0].name == "add"

    def test_agent_with_custom_max_iterations(self, mock_provider):
        """max_iterations is configurable at construction."""
        subject = Agent(provider=mock_provider, max_iterations=5)

        assert subject.max_iterations == 5
||||
|
||||
@pytest.mark.integration
class TestAgentRun:
    """Behavior of Agent.run()."""

    @pytest.mark.asyncio
    async def test_agent_run_simple(self, mock_provider):
        """A single queued text reply is returned as-is."""
        mock_provider.add_text_response("Hello there!")
        subject = Agent(provider=mock_provider)

        reply = await subject.run("Hi")

        assert reply == "Hello there!"

    @pytest.mark.asyncio
    async def test_agent_run_adds_to_history(self, mock_provider):
        """run() records both the user turn and the assistant turn."""
        mock_provider.add_text_response("Response!")
        subject = Agent(provider=mock_provider)

        await subject.run("Hello")

        turns = subject.history
        assert len(turns) == 2
        assert turns[0].role == "user"
        assert turns[0].extract_text() == "Hello"
        assert turns[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_agent_run_multiple_messages(self, mock_provider):
        """History accumulates across successive run() calls."""
        mock_provider.add_text_responses("Response 1", "Response 2")
        subject = Agent(provider=mock_provider)

        await subject.run("Message 1")
        await subject.run("Message 2")

        # Four turns total, strictly alternating roles.
        roles = [turn.role for turn in subject.history]
        assert roles == ["user", "assistant", "user", "assistant"]

    @pytest.mark.asyncio
    async def test_agent_run_tracks_calls(self, mock_provider):
        """run() forwards the system prompt and history to the provider."""
        mock_provider.add_text_response("Response!")
        subject = Agent(provider=mock_provider)

        await subject.run("Hello")

        assert len(mock_provider.calls) == 1
        recorded = mock_provider.calls[0]
        assert recorded["system_prompt"] == "You are a helpful assistant."
        assert len(recorded["history"]) == 1  # just the user message
|
||||
|
||||
@pytest.mark.integration
class TestAgentGenerate:
    """Behavior of Agent.generate() (single provider round-trip)."""

    @pytest.mark.asyncio
    async def test_agent_generate_returns_message(self, mock_provider):
        """generate() yields an assistant Message with the queued text."""
        mock_provider.add_text_response("Generated response")
        subject = Agent(provider=mock_provider)

        produced = await subject.generate("Hello")

        assert produced.role == "assistant"
        assert produced.extract_text() == "Generated response"

    @pytest.mark.asyncio
    async def test_agent_generate_without_tool_loop(self, mock_provider):
        """generate() returns raw tool calls without executing them."""
        mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
        subject = Agent(provider=mock_provider, tools=[])

        produced = await subject.generate("Calculate 1+2")

        assert produced.has_tool_calls()
        assert len(produced.tool_calls) == 1
        assert produced.tool_calls[0].function.name == "add"

    @pytest.mark.asyncio
    async def test_agent_generate_adds_to_history(self, mock_provider):
        """generate() records the assistant reply in history."""
        mock_provider.add_text_response("Response!")
        subject = Agent(provider=mock_provider)

        await subject.generate("Hello")

        assert len(subject.history) == 2
        assert subject.history[1].role == "assistant"
|
||||
|
||||
@pytest.mark.integration
class TestAgentHistory:
    """History bookkeeping on Agent."""

    @pytest.mark.asyncio
    async def test_agent_history_property_returns_copy(self, mock_provider):
        """Mutating the returned history list leaves the agent untouched."""
        mock_provider.add_text_response("Response!")
        subject = Agent(provider=mock_provider)

        await subject.run("Hello")

        snapshot = subject.history
        snapshot.clear()  # only the copy changes

        assert len(subject.history) == 2

    @pytest.mark.asyncio
    async def test_agent_clear_history(self, mock_provider):
        """clear_history() empties the conversation."""
        mock_provider.add_text_response("Response!")
        subject = Agent(provider=mock_provider)

        await subject.run("Hello")
        subject.clear_history()

        assert subject.history == []

    @pytest.mark.asyncio
    async def test_agent_add_message(self, mock_provider):
        """add_message() appends a caller-built message."""
        from agentlite import Message

        subject = Agent(provider=mock_provider)
        subject.add_message(Message(role="user", content="Manual message"))

        assert len(subject.history) == 1
        assert subject.history[0].extract_text() == "Manual message"
|
||||
|
||||
@pytest.mark.integration
class TestAgentWithTools:
    """Agent behavior when tools are attached."""

    @pytest.mark.asyncio
    async def test_agent_with_tools_initialization(self, mock_provider, add_tool):
        """Attached tools are registered and forwarded to the provider."""
        subject = Agent(
            provider=mock_provider, tools=[add_tool], system_prompt="You have access to tools."
        )

        assert len(subject.tools.tools) == 1

        # A run should hand the tool list through to the provider.
        mock_provider.add_text_response("I have tools available")
        await subject.run("Hello")

        assert len(mock_provider.calls) == 1
        assert len(mock_provider.calls[0]["tools"]) == 1

    @pytest.mark.asyncio
    async def test_agent_tool_call_execution(self, mock_provider, add_tool):
        """A queued tool call is executed, then the follow-up text returned."""
        mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
        mock_provider.add_text_response("The sum is 3")

        subject = Agent(provider=mock_provider, tools=[add_tool])
        reply = await subject.run("What is 1+2?")

        assert "3" in reply
        # One provider call for the tool turn, one for the final text.
        assert len(mock_provider.calls) == 2
|
||||
|
||||
@pytest.mark.integration
class TestAgentMaxIterations:
    """Enforcement of the max_iterations cap."""

    @pytest.mark.asyncio
    async def test_agent_respects_max_iterations(self, mock_provider, add_tool):
        """Endless tool-call replies are cut off at max_iterations."""
        for _ in range(10):
            mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")

        subject = Agent(provider=mock_provider, tools=[add_tool], max_iterations=3)
        reply = await subject.run("Calculate")

        assert len(mock_provider.calls) <= 3
        assert "Maximum tool call iterations reached" in reply

    @pytest.mark.asyncio
    async def test_agent_no_iterations_for_simple_response(self, mock_provider):
        """A plain text answer is unaffected by a tiny iteration budget."""
        mock_provider.add_text_response("Simple response")
        subject = Agent(provider=mock_provider, max_iterations=1)

        reply = await subject.run("Hello")

        assert reply == "Simple response"
|
||||
|
||||
@pytest.mark.integration
class TestAgentStreaming:
    """Streaming-mode behavior of Agent.run()."""

    @pytest.mark.asyncio
    async def test_agent_run_streaming(self, mock_provider):
        """Streamed chunks reassemble into the queued reply."""
        mock_provider.add_text_response("Streamed response")
        subject = Agent(provider=mock_provider)

        stream = await subject.run("Hello", stream=True)
        pieces = [piece async for piece in stream]

        assert len(pieces) > 0
        assert "".join(pieces) == "Streamed response"

    @pytest.mark.asyncio
    async def test_agent_streaming_adds_to_history(self, mock_provider):
        """Draining the stream still records both conversation turns."""
        mock_provider.add_text_response("Response")
        subject = Agent(provider=mock_provider)

        stream = await subject.run("Hello", stream=True)
        async for _ in stream:
            pass

        assert len(subject.history) == 2
||||
348
agentlite/tests/integration/test_with_api.py
Normal file
348
agentlite/tests/integration/test_with_api.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""Integration tests for AgentLite with real API.
|
||||
|
||||
This script runs comprehensive tests against the real OpenAI API.
|
||||
Requires OPENAI_API_KEY environment variable to be set.
|
||||
|
||||
Usage:
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
python tests/integration/test_with_api.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
|
||||
|
||||
from agentlite import Agent, OpenAIProvider, LLMClient, llm_complete
|
||||
from agentlite.skills import discover_skills, SkillTool, index_skills_by_name
|
||||
from agentlite.tools import ConfigurableToolset
|
||||
|
||||
|
||||
# Test configuration
|
||||
TEST_MODEL = "gpt-4o-mini" # Use mini for cost efficiency
|
||||
HAS_OPENAI_API_KEY = bool(os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
not HAS_OPENAI_API_KEY, reason="OPENAI_API_KEY is required to run integration tests"
|
||||
)
|
||||
|
||||
|
||||
def get_provider():
    """Build an OpenAIProvider from OPENAI_API_KEY, or abort the script.

    Exits the process with status 1 (after printing instructions) when the
    key is missing.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return OpenAIProvider(api_key=key, model=TEST_MODEL)

    print("❌ OPENAI_API_KEY not set!")
    print("Please set your OpenAI API key:")
    print(" export OPENAI_API_KEY='sk-...'")
    sys.exit(1)
|
||||
|
||||
async def test_basic_agent():
    """Test 1: Basic Agent functionality (simple arithmetic question)."""
    print("\n" + "=" * 60)
    print("Test 1: Basic Agent Functionality")
    print("=" * 60)

    try:
        agent = Agent(
            provider=get_provider(),
            system_prompt="You are a helpful assistant. Be concise.",
        )

        reply = await agent.run("What is 2+2?")
        print(f"✅ Agent responded: {reply[:100]}...")

        assert "4" in reply, "Expected '4' in response"
        print("✅ Basic Agent test PASSED")
        return True

    except Exception as e:
        print(f"❌ Basic Agent test FAILED: {e}")
        return False
||||
|
||||
async def test_agent_with_tools():
    """Test 2: Agent with tool suite (file tools attached)."""
    print("\n" + "=" * 60)
    print("Test 2: Agent with Tool Suite")
    print("=" * 60)

    try:
        # NOTE(review): ReadFile/Glob are imported but never used — confirm intent.
        from agentlite.tools import ToolSuiteConfig, ReadFile, Glob

        provider = get_provider()

        # Toolset rooted at the current working directory.
        toolset = ConfigurableToolset(ToolSuiteConfig(), work_dir=Path.cwd())

        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant with file access.",
            tools=toolset.tools,
        )
        print(f"✅ Agent created with {len(agent.tools.tools)} tools")

        # A query that can be answered with the file tools available.
        reply = await agent.run("List the Python files in the current directory")
        print(f"✅ Agent with tools responded: {reply[:100]}...")

        print("✅ Agent with Tools test PASSED")
        return True

    except Exception as e:
        print(f"❌ Agent with Tools test FAILED: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
async def test_llm_client():
    """Test 3: LLMClient functionality (single completion)."""
    print("\n" + "=" * 60)
    print("Test 3: LLMClient Functionality")
    print("=" * 60)

    try:
        client = LLMClient(provider=get_provider())

        answer = await client.complete(
            user_prompt="What is the capital of France?",
            system_prompt="You are a helpful assistant. Be concise.",
        )

        print(f"✅ LLMClient responded: {answer.content[:100]}...")
        assert "Paris" in answer.content, "Expected 'Paris' in response"

        print("✅ LLMClient test PASSED")
        return True

    except Exception as e:
        print(f"❌ LLMClient test FAILED: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
async def test_llm_streaming():
    """Test 4: LLM streaming (chunk-by-chunk completion)."""
    print("\n" + "=" * 60)
    print("Test 4: LLM Streaming")
    print("=" * 60)

    try:
        client = LLMClient(provider=get_provider())

        pieces = []
        async for piece in client.stream(
            user_prompt="Count from 1 to 3",
            system_prompt="You are a helpful assistant.",
        ):
            pieces.append(piece)
            print(f" Chunk: {piece[:20]}...")

        assembled = "".join(pieces)
        print(f"✅ Streamed response: {assembled[:100]}...")

        print("✅ LLM Streaming test PASSED")
        return True

    except Exception as e:
        print(f"❌ LLM Streaming test FAILED: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
async def test_subagents():
    """Test 5: Subagent functionality (parent/subagent wiring + Task tool)."""
    print("\n" + "=" * 60)
    print("Test 5: Subagent Functionality")
    print("=" * 60)

    try:
        # NOTE(review): LaborMarket is imported but unused here — confirm intent.
        from agentlite.labor_market import LaborMarket
        from agentlite.tools.multiagent.task import Task

        provider = get_provider()

        # Coordinator (parent) and its coding subagent.
        parent = Agent(
            provider=provider,
            system_prompt="You are a coordinator agent.",
            name="coordinator",
        )
        coder = Agent(
            provider=provider,
            system_prompt="You are a coding specialist. Write clean, simple code.",
            name="coder",
        )
        parent.add_subagent("coder", coder, "Writes code")

        # Expose delegation through the Task tool.
        parent.tools.add(Task(labor_market=parent.labor_market))

        print(f"✅ Created parent with {len(parent.labor_market)} subagent(s)")
        print(f" Subagents: {parent.labor_market.list_subagents()}")

        print("✅ Subagent test PASSED")
        return True

    except Exception as e:
        print(f"❌ Subagent test FAILED: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
async def test_skills():
    """Test 6: Skills functionality.

    Discovers example skills under examples/skills and attaches a SkillTool
    to a live agent. Skips gracefully (returning True) when the directory
    or skills are absent.

    Returns:
        True when the test passed or was skipped, False on failure.
    """
    print("\n" + "=" * 60)
    print("Test 6: Skills Functionality")
    print("=" * 60)

    try:
        # Discover example skills shipped alongside the package.
        skills_dir = Path(__file__).parent.parent.parent / "examples" / "skills"
        if not skills_dir.exists():
            print("⚠️ Skills directory not found, skipping")
            return True

        skills = discover_skills(skills_dir)
        print(f"✅ Discovered {len(skills)} skill(s)")

        for skill in skills:
            print(f" - {skill.name} ({skill.type})")

        # Idiomatic emptiness check (was: len(skills) == 0).
        if not skills:
            print("⚠️ No skills found, skipping further tests")
            return True

        # Wire the discovered skills into a live agent.
        provider = get_provider()
        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant.",
        )

        skill_index = index_skills_by_name(skills)
        skill_tool = SkillTool(skill_index, parent_agent=agent)
        agent.tools.add(skill_tool)

        # Fix: was an f-string with no placeholders (ruff F541).
        print("✅ Added SkillTool to agent")
        print("✅ Skills test PASSED")
        return True

    except Exception as e:
        print(f"❌ Skills test FAILED: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
async def test_conversation_history():
    """Test 7: Conversation history (fact recall across turns)."""
    print("\n" + "=" * 60)
    print("Test 7: Conversation History")
    print("=" * 60)

    try:
        agent = Agent(
            provider=get_provider(),
            system_prompt="You are a helpful assistant.",
        )

        # Introduce a fact, then ask the agent to recall it.
        first_reply = await agent.run("My name is Alice")
        print(f"✅ Response 1: {first_reply[:50]}...")

        second_reply = await agent.run("What is my name?")
        print(f"✅ Response 2: {second_reply[:50]}...")

        assert "Alice" in second_reply, "Expected agent to remember name"

        print("✅ Conversation History test PASSED")
        return True

    except Exception as e:
        print(f"❌ Conversation History test FAILED: {e}")
        import traceback

        traceback.print_exc()
        return False
|
||||
|
||||
async def run_all_tests():
    """Run every integration test sequentially and print a summary.

    Returns a list of (name, passed) pairs; empty when no API key is set.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("AgentLite Integration Tests with Real API")
    print(banner)
    print(f"Model: {TEST_MODEL}")

    # Bail out early (with instructions) when no API key is configured.
    if not os.environ.get("OPENAI_API_KEY"):
        print("\n❌ OPENAI_API_KEY not set!")
        print("\nTo run these tests, set your OpenAI API key:")
        print(" export OPENAI_API_KEY='sk-...'")
        print("\nGet your API key from: https://platform.openai.com/api-keys")
        return []

    # Table-driven dispatch keeps name and coroutine together.
    suites = [
        ("Basic Agent", test_basic_agent),
        ("Agent with Tools", test_agent_with_tools),
        ("LLMClient", test_llm_client),
        ("LLM Streaming", test_llm_streaming),
        ("Subagents", test_subagents),
        ("Skills", test_skills),
        ("Conversation History", test_conversation_history),
    ]
    results = []
    for label, runner in suites:
        results.append((label, await runner()))

    print("\n" + banner)
    print("Test Summary")
    print(banner)

    passed = sum(1 for _, ok in results if ok)
    total = len(results)

    for label, ok in results:
        status = "✅ PASSED" if ok else "❌ FAILED"
        print(f"{status}: {label}")

    print(f"\n{passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed!")
    else:
        print(f"\n⚠️ {total - passed} test(s) failed")

    return results
|
||||
|
||||
if __name__ == "__main__":
    outcomes = asyncio.run(run_all_tests())

    # Exit non-zero when any test failed so CI can detect it.
    if outcomes and not all(ok for _, ok in outcomes):
        sys.exit(1)
||||
0
agentlite/tests/mocks/__init__.py
Normal file
0
agentlite/tests/mocks/__init__.py
Normal file
0
agentlite/tests/scenarios/__init__.py
Normal file
0
agentlite/tests/scenarios/__init__.py
Normal file
141
agentlite/tests/scenarios/test_cli_debug.py
Normal file
141
agentlite/tests/scenarios/test_cli_debug.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Debug script to find CLI test hang cause."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import signal
|
||||
|
||||
# Make the in-repo package importable without hardcoding a machine-specific
# checkout path: resolve agentlite/src relative to this test file
# (tests/scenarios/ -> tests/ -> agentlite/ -> agentlite/src).
sys.path.insert(
    0,
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "src",
    ),
)
|
||||
|
||||
from agentlite import Agent, OpenAIProvider
|
||||
from agentlite.tools.shell.shell import Shell, Params
|
||||
|
||||
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
|
||||
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
|
||||
|
||||
|
||||
async def test_shell_directly():
    """Sanity-check the Shell tool on its own, with no agent in the loop."""
    print("\n=== Test 1: Shell tool directly ===")
    shell = Shell(timeout=10)

    # Build the argument dataclass explicitly, then invoke the tool.
    params = Params(command="echo 'Hello'", timeout=5)
    result = await shell(params)

    print(f"Result: {result}")
    print(f"Output: {result.output if hasattr(result, 'output') else result}")
    return True
||||
|
||||
|
||||
async def test_agent_no_tools():
    """Smoke-test the plain agent loop, with no tools attached.

    Returns False (instead of raising) when the API key is missing or the
    run times out, so the caller can aggregate results.
    """
    print("\n=== Test 2: Agent without tools ===")

    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if not api_key:
        print("SILICONFLOW_API_KEY not set")
        return False

    agent = Agent(
        provider=OpenAIProvider(
            api_key=api_key,
            base_url=SILICONFLOW_BASE_URL,
            model=SILICONFLOW_MODEL,
            timeout=30.0,
        ),
        system_prompt="Reply briefly in one word.",
        max_iterations=3,
    )

    print("Sending message to LLM...")
    try:
        # Hard outer deadline so a hang surfaces as a failure, not a stall.
        response = await asyncio.wait_for(agent.run("Say hello."), timeout=60.0)
    except asyncio.TimeoutError:
        print("TIMEOUT in agent without tools!")
        return False

    print(f"Response: {response[:100]}...")
    return True
|
||||
|
||||
|
||||
async def test_agent_with_shell():
    """Reproduce the problematic case: agent plus Shell tool.

    On timeout, dumps the tail of the conversation history to show where
    the loop stalled.
    """
    print("\n=== Test 3: Agent WITH shell tool ===")

    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if not api_key:
        print("SILICONFLOW_API_KEY not set")
        return False

    provider = OpenAIProvider(
        api_key=api_key,
        base_url=SILICONFLOW_BASE_URL,
        model=SILICONFLOW_MODEL,
        timeout=60.0,
    )
    agent = Agent(
        provider=provider,
        system_prompt="You are a shell assistant. Execute commands when asked. Keep responses brief.",
        tools=[Shell(timeout=10)],
        max_iterations=5,  # Limit iterations
    )

    print("Sending message with tool request...")
    print("This is where it might hang...")

    try:
        response = await asyncio.wait_for(
            agent.run("Run 'echo test' and tell me the result."),
            timeout=120.0,
        )
    except asyncio.TimeoutError:
        print("TIMEOUT! Agent hung for 120 seconds")

        # Inspect the last few history entries to see what happened.
        print(f"\nHistory length: {len(agent.history)}")
        for i, msg in enumerate(agent.history[-5:]):
            content_preview = str(msg.content)[:100] if msg.content else "None"
            print(f"  [{i}] {msg.role}: {content_preview}...")
        return False

    print(f"Response: {response}")
    return True
|
||||
|
||||
|
||||
async def main():
    """Run the three debug scenarios in order and print a summary table."""
    print("=" * 60)
    print("CLI Debug Test - Finding the hang cause")
    print("=" * 60)

    results = []

    # Run each scenario, recording and echoing its outcome immediately.
    for label, scenario in (
        ("Shell directly", test_shell_directly),
        ("Agent no tools", test_agent_no_tools),
        ("Agent with shell", test_agent_with_shell),
    ):
        ok = await scenario()
        results.append((label, ok))
        print(f"Result: {'PASS' if ok else 'FAIL'}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"  {name}: {status}")
    print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: run all debug scenarios under a fresh asyncio event loop.
    asyncio.run(main())
|
||||
221
agentlite/tests/scenarios/test_cli_debug_verbose.py
Normal file
221
agentlite/tests/scenarios/test_cli_debug_verbose.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""Debug script with detailed logging to find CLI test hang cause."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
# Make the in-repo package importable without hardcoding a machine-specific
# checkout path: resolve agentlite/src relative to this test file
# (tests/scenarios/ -> tests/ -> agentlite/ -> agentlite/src).
sys.path.insert(
    0,
    os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "src",
    ),
)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("debug")
|
||||
|
||||
# SiliconFlow DeepSeek-V3 (known good function calling support)
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Pro/deepseek-ai/DeepSeek-V3.2"
# SECURITY: a live API key was previously hardcoded here. Credentials must
# never be committed to source control; supply the key via the
# SILICONFLOW_API_KEY environment variable instead. The empty fallback keeps
# main()'s `os.environ.get(...) or SILICONFLOW_API_KEY` lookup working and
# makes the missing-key error path trigger when no env var is set.
SILICONFLOW_API_KEY = ""
|
||||
|
||||
|
||||
async def main():
    """Drive one agent run step-by-step with verbose timing and log output.

    Re-implements the agent's run loop inline (provider call, stream
    collection, tool execution, history bookkeeping), wrapping every phase
    in its own timeout and timer so the exact phase of a hang can be
    located from the log.

    NOTE(review): this deliberately reaches into the private
    ``agent._history`` attribute to mirror the internal loop — presumably it
    must stay in sync with Agent.run(); confirm against the Agent source.
    """
    from agentlite import Agent, OpenAIProvider
    from agentlite.tools.shell.shell import Shell
    from agentlite.message import Message

    logger.info("=" * 60)
    logger.info("CLI Debug Test with DeepSeek-V3 (SiliconFlow)")
    logger.info("=" * 60)

    # Environment variable takes precedence over the module-level fallback.
    api_key = os.environ.get("SILICONFLOW_API_KEY") or SILICONFLOW_API_KEY
    if not api_key:
        logger.error("SILICONFLOW_API_KEY not set")
        return

    logger.info(f"Using model: {SILICONFLOW_MODEL}")

    provider = OpenAIProvider(
        api_key=api_key,
        base_url=SILICONFLOW_BASE_URL,
        model=SILICONFLOW_MODEL,
        timeout=30.0,
    )

    agent = Agent(
        provider=provider,
        system_prompt="You are a shell assistant. Execute commands when asked. Reply briefly.",
        tools=[Shell(timeout=10)],
        max_iterations=5,
    )

    start_time = time.time()
    message = "Run 'echo test' and tell me the result."

    logger.info(f"\n=== Starting Agent Run ===")
    logger.info(f"Message: {message}")
    logger.info(f"Max iterations: {agent.max_iterations}")
    logger.info(f"Tools: {[t.name for t in agent.tools.tools]}")

    # Seed the conversation with the user turn (private-attribute access).
    agent._history.append(Message(role="user", content=message))

    iterations = 0
    final_response = None

    while iterations < agent.max_iterations:
        iterations += 1
        elapsed = time.time() - start_time

        logger.info(f"\n{'=' * 50}")
        logger.info(f"ITERATION {iterations}/{agent.max_iterations} (elapsed: {elapsed:.1f}s)")
        logger.info(f"{'=' * 50}")

        # Step 1: Call Provider
        logger.info(">>> Step 1: Calling provider.generate()...")
        step_start = time.time()

        try:
            stream = await asyncio.wait_for(
                provider.generate(
                    system_prompt=agent.system_prompt,
                    tools=agent.tools.tools,
                    history=agent._history,
                ),
                timeout=60.0,
            )
            logger.info(f"<<< Provider returned stream in {time.time() - step_start:.2f}s")
        except asyncio.TimeoutError:
            logger.error("!!! Provider call TIMEOUT after 60s")
            final_response = "ERROR: Provider timeout"
            break

        # Step 2: Collect stream parts
        logger.info(">>> Step 2: Collecting stream parts...")
        step_start = time.time()

        from agentlite.message import TextPart, ToolCall, ContentPart

        response_parts = []
        tool_calls = []
        chunk_count = 0

        try:
            async for part in stream:
                chunk_count += 1
                # Log only every 10th chunk to keep the output readable.
                if chunk_count % 10 == 0:
                    logger.debug(f"  Received chunk #{chunk_count}")

                if isinstance(part, ToolCall):
                    tool_calls.append(part)
                    logger.info(
                        f"  ToolCall received: {part.function.name if hasattr(part, 'function') else part}"
                    )
                elif isinstance(part, ContentPart):
                    response_parts.append(part)
                    if isinstance(part, TextPart):
                        logger.debug(f"  Text: {part.text[:50]}...")

            logger.info(
                f"<<< Stream finished in {time.time() - step_start:.2f}s, {chunk_count} chunks"
            )
        except asyncio.TimeoutError:
            logger.error("!!! Stream reading TIMEOUT")
            final_response = "ERROR: Stream timeout"
            break
        except Exception as e:
            logger.error(f"!!! Stream error: {type(e).__name__}: {e}")
            final_response = f"ERROR: Stream error - {e}"
            break

        # Extract text
        response_text = ""
        for part in response_parts:
            if isinstance(part, TextPart):
                response_text += part.text
        logger.info(f"Response text ({len(response_text)} chars): {response_text[:100]}...")
        logger.info(f"Tool calls: {len(tool_calls)}")

        # Add to history
        agent._history.append(
            Message(
                role="assistant",
                content=response_parts,
                tool_calls=tool_calls if tool_calls else None,
            )
        )

        # Step 3: Check if done — no tool calls means the model is finished.
        if not tool_calls:
            elapsed = time.time() - start_time
            logger.info(f"\n=== Agent completed in {elapsed:.2f}s, {iterations} iterations ===")
            final_response = response_text
            break

        # Step 4: Execute tool calls
        logger.info(f"\n>>> Step 3: Executing {len(tool_calls)} tool calls...")
        step_start = time.time()

        for i, tc in enumerate(tool_calls):
            func_name = tc.function.name if hasattr(tc, "function") else str(tc)
            func_args = tc.function.arguments if hasattr(tc, "function") else ""
            logger.info(f"  Tool #{i + 1}: {func_name}")
            logger.info(f"  Args: {func_args[:200]}...")

            try:
                # Per-tool timeout so one stuck command cannot hang the run.
                result = await asyncio.wait_for(
                    agent.tools.handle(tc),
                    timeout=30.0,
                )
                output = result.output if hasattr(result, "output") else str(result)
                is_error = result.is_error if hasattr(result, "is_error") else False
                logger.info(
                    f"  Result: is_error={is_error}, output_len={len(output) if output else 0}"
                )
                output_preview = output[:100] if output else "None"
                logger.info(f"  Output preview: {output_preview}...")
            except asyncio.TimeoutError:
                logger.error(f"  !!! Tool execution TIMEOUT")
                output = "Tool execution timed out"
                is_error = True
            except Exception as e:
                logger.error(f"  !!! Tool error: {type(e).__name__}: {e}")
                output = str(e)
                is_error = True

            # Add tool result to history
            agent._history.append(
                Message(
                    role="tool",
                    content=output,
                    tool_call_id=tc.id if hasattr(tc, "id") else f"tc_{i}",
                )
            )

        logger.info(f"<<< Tool execution finished in {time.time() - step_start:.2f}s")

        # Check overall timeout
        elapsed = time.time() - start_time
        if elapsed > 90:
            logger.warning(f"!!! Overall timeout approaching ({elapsed:.1f}s)")
            final_response = f"Timeout after {iterations} iterations"
            break

    if iterations >= agent.max_iterations:
        logger.warning(f"!!! Max iterations reached ({agent.max_iterations})")
        final_response = f"Max iterations ({agent.max_iterations}) reached"

    logger.info(f"\n{'=' * 60}")
    logger.info(f"FINAL RESULT:")
    logger.info(f"{'=' * 60}")
    logger.info(f"{final_response}")
    logger.info(f"Total iterations: {iterations}")
    logger.info(f"Total time: {time.time() - start_time:.2f}s")
    logger.info(f"History length: {len(agent._history)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: run the verbose debug session under an asyncio event loop.
    asyncio.run(main())
|
||||
349
agentlite/tests/scenarios/test_cli_operations_real_api.py
Normal file
349
agentlite/tests/scenarios/test_cli_operations_real_api.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""End-to-end test for complex CLI operations with real API.
|
||||
|
||||
This test simulates a realistic complex CLI task where an agent:
|
||||
1. Explores project structure using shell commands
|
||||
2. Searches for specific patterns using grep/glob
|
||||
3. Reads relevant files
|
||||
4. Creates analysis reports
|
||||
|
||||
Uses real SiliconFlow qwen3.5-397B API (requires SILICONFLOW_API_KEY env var).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import Agent, OpenAIProvider
|
||||
from agentlite.tools import (
|
||||
ConfigurableToolset,
|
||||
ToolSuiteConfig,
|
||||
Shell,
|
||||
ReadFile,
|
||||
WriteFile,
|
||||
Glob,
|
||||
Grep,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Configuration from model_config.toml
|
||||
# =============================================================================
|
||||
|
||||
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
|
||||
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
|
||||
|
||||
|
||||
def get_siliconflow_provider() -> OpenAIProvider | None:
    """Build an OpenAIProvider aimed at SiliconFlow.

    Returns:
        A configured provider, or ``None`` when the SILICONFLOW_API_KEY
        environment variable is not set.
    """
    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if api_key:
        return OpenAIProvider(
            api_key=api_key,
            base_url=SILICONFLOW_BASE_URL,
            model=SILICONFLOW_MODEL,
        )
    return None
|
||||
|
||||
|
||||
@pytest.fixture
def real_provider():
    """Provide a real SiliconFlow provider for scenario tests.

    Skips the requesting test when SILICONFLOW_API_KEY is not configured.
    """
    candidate = get_siliconflow_provider()
    if candidate is not None:
        return candidate
    pytest.skip("SILICONFLOW_API_KEY not set")
|
||||
|
||||
|
||||
@pytest.fixture
def test_project():
    """Create a mock project structure for testing.

    Yields:
        Path to a temporary ``test_project`` directory containing ``src``,
        ``tests`` and ``docs`` subtrees with small Python modules and
        markdown docs. The whole tree is removed when the fixture scope ends.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        project_dir = Path(tmpdir) / "test_project"
        project_dir.mkdir()

        # Create project structure
        (project_dir / "src").mkdir()
        (project_dir / "src" / "utils").mkdir()
        (project_dir / "tests").mkdir()
        (project_dir / "docs").mkdir()

        # Create source files
        (project_dir / "src" / "main.py").write_text('''"""Main module."""
from src.utils.helper import process_data
from src.utils.logger import setup_logger

def main():
    """Main entry point."""
    logger = setup_logger()
    data = [1, 2, 3, 4, 5]
    result = process_data(data)
    logger.info(f"Result: {result}")
    return result

if __name__ == "__main__":
    main()
''')

        (project_dir / "src" / "__init__.py").write_text('"""Source package."""')

        (project_dir / "src" / "utils" / "helper.py").write_text('''"""Helper utilities."""
def process_data(data: list) -> list:
    """Process input data."""
    return [x * 2 for x in data]

def validate_data(data: list) -> bool:
    """Validate data format."""
    return all(isinstance(x, (int, float)) for x in data)
''')

        (project_dir / "src" / "utils" / "logger.py").write_text('''"""Logging utilities."""
import logging

def setup_logger(name: str = "app") -> logging.Logger:
    """Setup application logger."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    return logger
''')

        (project_dir / "src" / "utils" / "__init__.py").write_text('"""Utils package."""')

        # Create test files
        (project_dir / "tests" / "test_helper.py").write_text('''"""Tests for helper module."""
from src.utils.helper import process_data, validate_data

def test_process_data():
    assert process_data([1, 2, 3]) == [2, 4, 6]

def test_validate_data():
    assert validate_data([1, 2, 3]) == True
    assert validate_data(["a", "b"]) == False
''')

        # Create documentation
        (project_dir / "docs" / "README.md").write_text("""# Test Project

A sample project for testing CLI operations.

## Structure

- `src/` - Source code
- `tests/` - Unit tests
- `docs/` - Documentation
""")

        (project_dir / "README.md").write_text("""# Test Project

Simple data processing project.

## Usage

```bash
python -m src.main
```
""")

        yield project_dir
|
||||
|
||||
|
||||
@pytest.mark.scenario
@pytest.mark.slow
class TestComplexCLITasks:
    """End-to-end tests with complex CLI operations.

    All tests talk to the real SiliconFlow API via the ``real_provider``
    fixture and operate on the throwaway tree from ``test_project``.
    """

    @pytest.mark.asyncio
    async def test_explore_project_structure(self, real_provider, test_project):
        """Test exploring project structure using CLI tools.

        Task: Use shell commands to explore the project structure,
        then summarize what files exist.
        """
        # Create toolset with Shell tool
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )

        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个项目分析助手。使用 Shell 工具执行命令来探索项目结构。"
                "请使用 find、ls、tree 等命令来了解项目。"
            ),
            max_iterations=5,  # Limit iterations to prevent hanging
        )

        # Add overall timeout to prevent infinite hanging
        try:
            response = await asyncio.wait_for(
                agent.run(
                    f"探索项目目录 {test_project} 的结构,列出所有文件和目录,并总结项目的组织方式。"
                ),
                timeout=120.0,  # 2 minute overall timeout
            )
        except asyncio.TimeoutError:
            pytest.fail("Agent timed out after 120 seconds - possible infinite loop")

        assert response, "Agent should return a response"
        print(f"\n[项目结构探索结果]:\n{response}\n")

        # Verify response mentions key files
        response_lower = response.lower()
        assert any(
            word in response_lower for word in ["src", "tests", "main.py", "helper", "logger"]
        ), "Response should mention project files"

    @pytest.mark.asyncio
    async def test_search_and_analyze_code(self, real_provider, test_project):
        """Test searching for patterns and analyzing code.

        Task: Use grep/glob to find specific patterns,
        read the files, and create an analysis report.
        """
        # Create toolset with all file tools
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                file_tools=ToolSuiteConfig().file_tools,
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )

        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个代码分析助手。使用 Glob、Grep、ReadFile 等工具来搜索和分析代码。"
                "请使用 Shell 工具执行 grep、find 等命令。"
            ),
        )

        response = await agent.run(
            f"在项目 {test_project} 中搜索所有包含 'def ' 的 Python 文件,"
            f"列出找到的函数定义,并创建一个函数清单文件保存到 {test_project}/functions.txt。"
        )

        assert response, "Agent should return a response"
        print(f"\n[代码搜索分析结果]:\n{response}\n")

        # Check if analysis file was created.
        # NOTE(review): creation is treated as optional — the model may take
        # another route, so the file is only validated when it exists.
        functions_file = test_project / "functions.txt"
        if functions_file.exists():
            content = functions_file.read_text()
            print(f"\n[函数清单文件]:\n{content}\n")
            assert len(content) > 0, "Functions file should not be empty"

    @pytest.mark.asyncio
    async def test_complex_multi_step_task(self, real_provider, test_project):
        """Test a complex multi-step CLI task.

        Task:
        1. Find all Python files using shell
        2. Search for TODO comments using grep
        3. Read files with TODOs
        4. Create a summary report
        """
        # Add some TODO comments
        todo_file = test_project / "src" / "utils" / "todo_items.py"
        todo_file.write_text('''"""Module with TODO items."""

# TODO: Implement error handling
def risky_operation(data):
    """Perform a risky operation."""
    return data / 0  # This will fail

# TODO: Add caching mechanism
def expensive_computation(n):
    """Perform expensive computation."""
    return sum(range(n))

# FIXME: Memory leak in this function
def process_large_file(path):
    """Process a large file."""
    with open(path) as f:
        return f.read()
''')

        # Create comprehensive toolset
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                file_tools=ToolSuiteConfig().file_tools,
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )

        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个项目维护助手。"
                "使用 Shell 工具执行命令(如 find、grep、ls 等)。"
                "使用 ReadFile 读取文件内容。"
                "使用 WriteFile 创建新文件。"
                "请一步一步完成任务。"
            ),
        )

        response = await agent.run(
            f"请完成以下任务:\n"
            f"1. 使用 'find' 命令找出项目 {test_project} 中所有的 .py 文件\n"
            f"2. 使用 'grep' 命令搜索所有包含 'TODO' 或 'FIXME' 的行\n"
            f"3. 读取包含 TODO 的文件内容\n"
            f"4. 创建一个 TODO 报告文件,保存到 {test_project}/todo_report.txt"
        )

        assert response, "Agent should return a response"
        print(f"\n[复杂任务结果]:\n{response}\n")

        # Verify report was created (optional — see note above about models
        # choosing alternate routes; only inspect when present).
        report_file = test_project / "todo_report.txt"
        if report_file.exists():
            content = report_file.read_text()
            print(f"\n[TODO 报告]:\n{content}\n")

    @pytest.mark.asyncio
    async def test_shell_pipes_and_chains(self, real_provider, test_project):
        """Test complex shell commands with pipes and chains.

        Task: Use shell pipes to perform complex data processing.
        """
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )

        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个 Shell 命令专家。"
                "使用复杂的 Shell 命令(管道、重定向、条件执行等)来完成任务。"
            ),
        )

        response = await agent.run(
            f"在项目目录 {test_project} 中执行以下操作:\n"
            f"1. 使用 'find . -name \"*.py\" | xargs wc -l' 统计所有 Python 文件的总行数\n"
            f'2. 使用 \'grep -r "def " --include="*.py" | wc -l\' 统计函数定义数量\n'
            f"3. 使用 'ls -la' 查看目录详情\n"
            f"报告你的发现。"
        )

        assert response, "Agent should return a response"
        print(f"\n[Shell 管道命令结果]:\n{response}\n")

        # Verify response contains relevant information
        response_lower = response.lower()
        assert any(
            word in response_lower for word in ["行", "line", "函数", "function", "文件", "file"]
        ), "Response should mention analysis results"
|
||||
374
agentlite/tests/scenarios/test_file_operations.py
Normal file
374
agentlite/tests/scenarios/test_file_operations.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""End-to-end scenario test for file operations.
|
||||
|
||||
This test simulates a realistic scenario where an agent:
|
||||
1. Reads a file
|
||||
2. Explains its content
|
||||
3. Creates a new file with analysis results
|
||||
|
||||
This is a meaningful e2e test that demonstrates the agent's ability to
|
||||
orchestrate multiple tool calls in sequence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import Agent, TextPart, tool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Operation Tools
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@tool()
async def read_file(file_path: str) -> str:
    """Return the full text content of the file at *file_path*.

    Args:
        file_path: Path to the file to read.

    Returns:
        The content of the file as a string.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    return Path(file_path).read_text()
|
||||
|
||||
|
||||
@tool()
async def write_file(file_path: str, content: str) -> str:
    """Write *content* to *file_path*, creating it if it doesn't exist.

    Args:
        file_path: Path to the file to write.
        content: Content to write to the file.

    Returns:
        Success message confirming the file was written.
    """
    target = Path(file_path)

    # Ensure the parent directory chain exists before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content)

    return f"File successfully written to {file_path}"
|
||||
|
||||
|
||||
@tool()
async def list_files(directory: str) -> str:
    """List all files in a directory.

    Args:
        directory: Path to the directory to list.

    Returns:
        A newline-separated list of file names in the directory.
    """
    return "\n".join(os.listdir(directory))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# E2E Test
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.scenario
class TestFileOperationsScenario:
    """End-to-end test for file read/write operations.

    The LLM side is fully scripted via the shared ``mock_provider``
    fixture, so only agent/tool orchestration is exercised — no network.
    """

    @pytest.mark.asyncio
    async def test_read_explain_and_write(self, mock_provider):
        """Test a complete workflow: read file -> explain -> write results."""
        # Setup: Create a temporary file with content
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a source file to read
            source_file = os.path.join(tmpdir, "source.txt")
            source_content = """Project Overview
================

This is a sample project document for testing.

Features:
- Feature A: Does something useful
- Feature B: Does something else
- Feature C: The most important feature

Conclusion: This project demonstrates file operations.
"""
            with open(source_file, "w") as f:
                f.write(source_content)

            # Configure mock provider responses
            # The agent should:
            # 1. Read the file
            # 2. Summarize it
            # 3. Write the summary to a new file
            mock_provider.add_text_response(
                f"I'll read the file at {source_file} and analyze it for you."
            )

            # Create agent with file tools
            tools = [read_file, write_file, list_files]
            agent = Agent(
                provider=mock_provider,
                tools=tools,
                system_prompt="You are a helpful file analysis assistant.",
            )

            # Step 1: Agent reads and analyzes the file
            # (clear_responses drops the response queued above before the run)
            mock_provider.clear_responses()
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": source_file},
                source_content,
            )

            # Agent analyzes the content
            mock_provider.add_text_response(
                "I've read the file. It's a project overview document with 3 features. "
                "Let me create a summary file."
            )

            # Step 2: Agent writes summary to a new file
            summary_file = os.path.join(tmpdir, "summary.txt")
            expected_summary = """Project Summary
================

This is a sample project with 3 main features:
- Feature A, - Feature B, - Feature C

The most important feature is Feature C.
"""

            mock_provider.clear_responses()
            mock_provider.add_tool_call(
                "write_file",
                {
                    "file_path": summary_file,
                    "content": expected_summary,
                },
                f"File successfully written to {summary_file}",
            )
            mock_provider.add_text_response(f"I've created a summary at {summary_file}")

            # Execute the agent
            response = await agent.run(
                f"Please read {source_file}, analyze it, and create a summary file at {summary_file}"
            )

            # Verify the interaction
            assert "summary" in response.lower()

            # Verify the provider was called correctly
            assert len(mock_provider.calls) >= 1

    @pytest.mark.asyncio
    async def test_list_files_scenario(self, mock_provider):
        """Test listing files in a directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create some test files
            for i in range(3):
                with open(os.path.join(tmpdir, f"file{i}.txt"), "w") as f:
                    f.write(f"Content {i}")

            # Configure agent to list files
            mock_provider.add_tool_call(
                "list_files",
                {"directory": tmpdir},
                "file0.txt\nfile1.txt\nfile2.txt",
            )
            mock_provider.add_text_response(
                f"I found 3 files in {tmpdir}: file0.txt, file1.txt, file2.txt"
            )

            agent = Agent(
                provider=mock_provider,
                tools=[list_files],
                system_prompt="You are a file system assistant.",
            )

            response = await agent.run(f"List all files in {tmpdir}")

            assert "3 files" in response

    @pytest.mark.asyncio
    async def test_multi_step_file_workflow(self, mock_provider):
        """Test a complex multi-step file workflow.

        Scenario:
        1. List files in directory
        2. Read each file
        3. Create a combined report
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create test files
            files_content = {
                "report1.txt": "Sales increased by 20%",
                "report2.txt": "Customer satisfaction at 85%",
                "report3.txt": "Bug fixes: 15 resolved",
            }

            for name, content in files_content.items():
                with open(os.path.join(tmpdir, name), "w") as f:
                    f.write(content)

            # Configure agent responses for multi-step workflow
            tools = [read_file, write_file, list_files]

            # Step 1: List files
            mock_provider.add_tool_call(
                "list_files",
                {"directory": tmpdir},
                "report1.txt\nreport2.txt\nreport3.txt",
            )

            # Step 2: Read all files (queued in the order the agent will call)
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": os.path.join(tmpdir, "report1.txt")},
                "Sales increased by 20%",
            )
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": os.path.join(tmpdir, "report2.txt")},
                "Customer satisfaction at 85%",
            )
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": os.path.join(tmpdir, "report3.txt")},
                "Bug fixes: 15 resolved",
            )

            # Step 3: Write combined report
            combined_report = """Combined Report
================

1. Sales: Increased by 20%
2. Customer Satisfaction: 85%
3. Development: 15 bugs resolved
"""
            mock_provider.add_tool_call(
                "write_file",
                {
                    "file_path": os.path.join(tmpdir, "combined_report.txt"),
                    "content": combined_report,
                },
                f"File successfully written to {os.path.join(tmpdir, 'combined_report.txt')}",
            )

            mock_provider.add_text_response(
                "I've created a combined report summarizing all three reports."
            )

            agent = Agent(
                provider=mock_provider,
                tools=tools,
                system_prompt="You are a report analyst assistant.",
            )

            response = await agent.run(
                f"List all files in {tmpdir}, read them all, and create a combined report at combined_report.txt"
            )

            assert "combined report" in response.lower()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Additional Tools for Extended Scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@tool()
async def count_words(file_path: str) -> str:
    """Count the number of words in a file.

    Args:
        file_path: Path to the file to analyze.

    Returns:
        The word count as a human-readable string.
    """
    with open(file_path) as handle:
        text = handle.read()
    # Whitespace-delimited token count.
    return f"Word count: {len(text.split())}"
|
||||
|
||||
|
||||
@tool()
async def append_to_file(file_path: str, content: str) -> str:
    """Append content to an existing file.

    Args:
        file_path: Path to the file to append to.
        content: Content to append.

    Returns:
        Success message.
    """
    # A leading newline keeps the appended text separated from what is
    # already in the file.
    with open(file_path, "a") as handle:
        handle.write("\n" + content)
    return f"Content appended to {file_path}"
|
||||
|
||||
|
||||
@pytest.mark.scenario
class TestExtendedFileOperations:
    """Extended scenarios with more file operations."""

    @pytest.mark.asyncio
    async def test_read_count_and_append(self, mock_provider):
        """Read a file, count its words, then append an analysis note."""
        with tempfile.TemporaryDirectory() as tmpdir:
            source_file = os.path.join(tmpdir, "document.txt")
            with open(source_file, "w") as f:
                f.write("This is a test document with several words in it.")

            tools = [read_file, write_file, count_words, append_to_file]

            # Scripted provider turns: read -> count -> append -> final text.
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": source_file},
                "This is a test document with several words in it.",
            )
            mock_provider.add_tool_call(
                "count_words",
                {"file_path": source_file},
                "Word count: 10",
            )
            mock_provider.add_tool_call(
                "append_to_file",
                {
                    "file_path": source_file,
                    "content": "\n\n[Analysis] This document contains 10 words.",
                },
                f"Content appended to {source_file}",
            )
            mock_provider.add_text_response(
                "I've analyzed the document and appended the word count analysis."
            )

            agent = Agent(
                provider=mock_provider,
                tools=tools,
                system_prompt="You are a document analysis assistant.",
            )

            response = await agent.run(
                f"Read {source_file}, count its words, and append the word count as an analysis note"
            )

            assert "analyzed" in response.lower()
|
||||
226
agentlite/tests/scenarios/test_file_operations_real_api.py
Normal file
226
agentlite/tests/scenarios/test_file_operations_real_api.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""End-to-end scenario test for file operations with real API.
|
||||
|
||||
This test simulates a realistic scenario where an agent:
|
||||
1. Reads a file
|
||||
2. Explains its content
|
||||
3. Creates a new file with analysis results
|
||||
|
||||
Uses real SiliconFlow qwen3.5-397B API (requires SILICONFLOW_API_KEY env var).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import Agent, OpenAIProvider, tool
|
||||
|
||||
# =============================================================================
|
||||
# Configuration from model_config.toml
|
||||
# =============================================================================
|
||||
|
||||
# SiliconFlow API configuration (matches qwen35_397b in model_config.toml)
|
||||
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
|
||||
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
|
||||
|
||||
|
||||
def get_siliconflow_provider() -> OpenAIProvider | None:
    """Create an OpenAIProvider pointed at the SiliconFlow API.

    Returns:
        A configured provider, or None when SILICONFLOW_API_KEY is unset.
    """
    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if api_key:
        return OpenAIProvider(
            api_key=api_key,
            base_url=SILICONFLOW_BASE_URL,
            model=SILICONFLOW_MODEL,
        )
    return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# File Operation Tools
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@tool()
async def read_file(file_path: str) -> str:
    """Read the content of a file.

    Args:
        file_path: Path to the file to read.

    Returns:
        The content of the file as a string.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    with open(file_path) as handle:
        return handle.read()
|
||||
|
||||
|
||||
@tool()
async def write_file(file_path: str, content: str) -> str:
    """Write content to a file, creating it if it doesn't exist.

    Args:
        file_path: Path to the file to write.
        content: Content to write to the file.

    Returns:
        Success message confirming the file was written.
    """
    target = Path(file_path)
    # Make sure the target directory exists before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content)
    return f"File successfully written to {file_path}"
|
||||
|
||||
|
||||
@tool()
async def list_files(directory: str) -> str:
    """List all files in a directory.

    Args:
        directory: Path to the directory to list.

    Returns:
        A newline-separated, alphabetically sorted list of entry names
        in the directory.
    """
    # os.listdir returns entries in arbitrary, platform-dependent order;
    # sort so the tool output (and agent transcripts) are deterministic.
    return "\n".join(sorted(os.listdir(directory)))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Real API E2E Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def real_provider():
    """Provide a real SiliconFlow provider.

    Skips the requesting test when SILICONFLOW_API_KEY is not set.
    """
    provider = get_siliconflow_provider()
    if provider is not None:
        return provider
    # pytest.skip raises, so this never falls through.
    pytest.skip("SILICONFLOW_API_KEY not set, skipping real API tests")
|
||||
|
||||
|
||||
@pytest.mark.scenario
@pytest.mark.expensive
class TestFileOperationsWithRealAPI:
    """End-to-end tests against the real SiliconFlow API."""

    @pytest.mark.asyncio
    async def test_read_and_summarize(self, real_provider):
        """Read a file and produce a summary file via the real API."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Seed a source file with meaningful content.
            source_file = os.path.join(tmpdir, "source.txt")
            source_content = """AgentLite 项目概述
================

AgentLite 是一个轻量级的 Agent 组件库,主要特点:
- 异步优先设计
- OpenAI 兼容 API
- 工具系统 (支持 MCP)
- 流式响应支持

使用示例:
```python
from agentlite import Agent, OpenAIProvider

provider = OpenAIProvider(api_key="...", model="gpt-4")
agent = Agent(provider=provider)
response = await agent.run("Hello!")
```
"""
            with open(source_file, "w") as f:
                f.write(source_content)

            # Wire the file tools into the agent.
            tools = [read_file, write_file, list_files]
            agent = Agent(
                provider=real_provider,
                tools=tools,
                system_prompt="你是一个文件分析助手。请使用工具来完成任务。",
            )

            # Ask the agent to read, analyze, and write a summary.
            output_file = os.path.join(tmpdir, "summary.txt")
            response = await agent.run(
                f"请读取 {source_file} 文件,分析其内容,并创建一个摘要文件保存到 {output_file}。"
            )

            assert response, "Agent should return a response"
            print(f"\n[Agent 响应]:\n{response}\n")

            # If the agent produced the summary file, it must be non-empty.
            if os.path.exists(output_file):
                with open(output_file) as f:
                    output_content = f.read()
                print(f"\n[输出文件内容]:\n{output_content}\n")
                assert len(output_content) > 0, "Output file should not be empty"

    @pytest.mark.asyncio
    async def test_list_files_and_combine(self, real_provider):
        """List files, read each, and write a combined report."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Seed several small data files.
            seed_files = {
                "sales.txt": "销售额增长了 20%",
                "users.txt": "用户满意度达到 85%",
                "bugs.txt": "修复了 15 个问题",
            }
            for name, content in seed_files.items():
                with open(os.path.join(tmpdir, name), "w") as f:
                    f.write(content)

            # Wire the file tools into the agent.
            tools = [read_file, write_file, list_files]
            agent = Agent(
                provider=real_provider,
                tools=tools,
                system_prompt="你是一个数据分析助手。请使用工具来完成任务。",
            )

            report_file = os.path.join(tmpdir, "report.txt")
            response = await agent.run(
                f"列出 {tmpdir} 目录中的所有文件,读取每个文件的内容,然后创建一份综合报告保存到 {report_file}。"
            )

            assert response, "Agent should return a response"
            print(f"\n[Agent 响应]:\n{response}\n")

            # Dump the report for debugging when the agent created it.
            if os.path.exists(report_file):
                with open(report_file) as f:
                    report_content = f.read()
                print(f"\n[报告文件内容]:\n{report_content}\n")

    @pytest.mark.asyncio
    async def test_simple_conversation(self, real_provider):
        """Basic conversation without any tools attached."""
        agent = Agent(
            provider=real_provider,
            system_prompt="你是一个有帮助的助手。请用中文回答。",
        )

        response = await agent.run("你好!请简单介绍一下你自己。")

        assert response, "Agent should return a response"
        print(f"\n[Agent 自我介绍]:\n{response}\n")
        assert len(response) > 10, "Response should be meaningful"
|
||||
521
agentlite/tests/tools/test_document_kg_tools.py
Normal file
521
agentlite/tests/tools/test_document_kg_tools.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""工具专项测试 - 文档提取和知识图谱工具
|
||||
|
||||
本模块测试基于数据基底的工具功能,包括:
|
||||
1. 文档读取和解析工具
|
||||
2. 实体提取工具
|
||||
3. 知识图谱查询工具
|
||||
4. 推理工具
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from agentlite import Agent, Message, TextPart, tool
|
||||
|
||||
|
||||
def tool_output(result: Any) -> Any:
    """Normalize legacy plain return values and ToolResult objects.

    Returns ``result.output`` when the object exposes an ``output``
    attribute, otherwise returns ``result`` unchanged.
    """
    if hasattr(result, "output"):
        return result.output
    return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 数据加载 fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def data_dir() -> Path:
    """Return the path of the shared test-data directory."""
    # tests/<subdir>/this_file.py -> tests/data
    return Path(__file__).parents[1] / "data"
|
||||
|
||||
|
||||
@pytest.fixture
def sample_article(data_dir: Path) -> str:
    """Load the sample article document."""
    article_path = data_dir / "documents" / "sample_article.md"
    return article_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
def technical_spec(data_dir: Path) -> str:
    """Load the technical specification document."""
    spec_path = data_dir / "documents" / "technical_spec.md"
    return spec_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
def meeting_notes(data_dir: Path) -> str:
    """Load the meeting notes document."""
    notes_path = data_dir / "documents" / "meeting_notes.txt"
    return notes_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
def knowledge_graph_entities(data_dir: Path) -> dict[str, Any]:
    """Load the knowledge-graph entity data."""
    entities_path = data_dir / "knowledge_base" / "entities.json"
    return json.loads(entities_path.read_text())
|
||||
|
||||
|
||||
@pytest.fixture
def knowledge_graph_relations(data_dir: Path) -> dict[str, Any]:
    """Load the knowledge-graph relation data."""
    relations_path = data_dir / "knowledge_base" / "relations.json"
    return json.loads(relations_path.read_text())
|
||||
|
||||
|
||||
@pytest.fixture
def graph_queries(data_dir: Path) -> list[dict[str, Any]]:
    """Load the graph-query test cases."""
    source = (data_dir / "knowledge_base" / "graph_queries.yaml").read_text()
    data = yaml.safe_load(source)
    return data.get("queries", [])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 知识图谱工具实现
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KnowledgeGraph:
    """In-memory knowledge-graph store with simple secondary indexes."""

    def __init__(self, entities: list[dict], relations: list[dict]):
        """Build the store and its lookup indexes.

        Args:
            entities: Entity dicts; each must carry an "id" key and
                usually "type" and "name" keys.
            relations: Relation dicts with "from", "to" and "type" keys.
        """
        self._entities = {e["id"]: e for e in entities}
        self._relations = relations
        # Secondary indexes: type -> [entity ids], name -> entity id.
        self._index_by_type: dict[str, list[str]] = {}
        self._index_by_name: dict[str, str] = {}

        for entity_id, entity in self._entities.items():
            entity_type = entity.get("type", "Unknown")
            self._index_by_type.setdefault(entity_type, []).append(entity_id)

            name = entity.get("name", "")
            if name:
                self._index_by_name[name] = entity_id

    def get_entity(self, entity_id: str) -> dict | None:
        """Return the entity with *entity_id*, or None if unknown."""
        return self._entities.get(entity_id)

    def get_entity_by_name(self, name: str) -> dict | None:
        """Return the entity whose "name" equals *name*, or None."""
        entity_id = self._index_by_name.get(name)
        if entity_id:
            return self._entities.get(entity_id)
        return None

    def get_entities_by_type(self, entity_type: str) -> list[dict]:
        """Return all entities of the given type (empty list if none)."""
        entity_ids = self._index_by_type.get(entity_type, [])
        return [self._entities[eid] for eid in entity_ids if eid in self._entities]

    def get_relations(
        self, from_id: str | None = None, to_id: str | None = None, relation_type: str | None = None
    ) -> list[dict]:
        """Return relations matching every provided filter (None = wildcard)."""
        results = []
        for rel in self._relations:
            if from_id and rel.get("from") != from_id:
                continue
            if to_id and rel.get("to") != to_id:
                continue
            if relation_type and rel.get("type") != relation_type:
                continue
            results.append(rel)
        return results

    def get_neighbors(self, entity_id: str, relation_type: str | None = None) -> list[dict]:
        """Return outgoing neighbors as {"entity": ..., "relation": ...} dicts."""
        relations = self.get_relations(from_id=entity_id, relation_type=relation_type)
        neighbors = []
        for rel in relations:
            target_id = rel.get("to")
            if target_id and target_id in self._entities:
                neighbors.append({"entity": self._entities[target_id], "relation": rel})
        return neighbors

    def find_path(self, start_id: str, end_id: str, max_depth: int = 3) -> list[list[str]] | None:
        """Find all simple directed paths from *start_id* to *end_id*.

        Args:
            start_id: Id of the starting entity.
            end_id: Id of the target entity.
            max_depth: Maximum number of hops (edges) to traverse.

        Returns:
            A list of paths (each a list of entity ids), or None when no
            path exists within *max_depth* hops.
        """
        if start_id == end_id:
            return [[start_id]]

        if max_depth <= 0:
            return None

        from collections import deque

        # BFS over partial paths.  Cycle detection is per-path rather
        # than via a global visited set: a global set would wrongly
        # prune distinct simple paths that merely share an intermediate
        # node (e.g. a->c->d found, a->b->c->d silently dropped).
        queue = deque([[start_id]])
        all_paths: list[list[str]] = []

        while queue:
            path = queue.popleft()

            if len(path) > max_depth + 1:
                continue

            for rel in self.get_relations(from_id=path[-1]):
                next_id = rel.get("to")
                # Skip malformed relations and cycles within this path.
                if not next_id or next_id in path:
                    continue

                new_path = path + [next_id]

                if next_id == end_id:
                    all_paths.append(new_path)
                elif len(new_path) <= max_depth:
                    queue.append(new_path)

        return all_paths if all_paths else None
|
||||
|
||||
|
||||
@pytest.fixture
def knowledge_graph(knowledge_graph_entities, knowledge_graph_relations) -> KnowledgeGraph:
    """Build a KnowledgeGraph instance from the JSON fixture data."""
    entity_list = knowledge_graph_entities.get("entities", [])
    relation_list = knowledge_graph_relations.get("relations", [])
    return KnowledgeGraph(entities=entity_list, relations=relation_list)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 工具定义
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@tool()
async def read_document(file_path: str) -> str:
    """Read a document's content.

    Args:
        file_path: Path of the document.

    Returns:
        The document text, or an "Error: ..." string on failure.
    """
    doc_path = Path(file_path)
    if not doc_path.exists():
        return f"Error: File not found: {file_path}"
    try:
        return doc_path.read_text(encoding="utf-8")
    except Exception as e:
        # Best-effort tool contract: report failures as text, never raise.
        return f"Error reading file: {e}"
|
||||
|
||||
|
||||
@tool()
async def extract_entities(text: str) -> str:
    """Extract entities from text.

    Args:
        text: Input text.

    Returns:
        A JSON-encoded list of {"type", "name"} entity dicts.
    """
    # Simplified rule-based extraction -- a real implementation would
    # use an NLP model.
    import re

    entities: list[dict[str, str]] = []

    # People: naive CJK-run matching filtered against a known-name list
    # (or any three-character run, a rough heuristic for Chinese names).
    potential_names = re.findall(r"[\u4e00-\u9fa5]{2,4}", text)
    common_names = ["张三", "李四", "王五", "赵六", "李飞飞", "吴恩达", "Yann LeCun"]
    for name in potential_names:
        if name in common_names or len(name) == 3:
            entities.append({"type": "Person", "name": name})

    # Organizations: fixed whitelist, deduplicated.
    for org in set(re.findall(r"(TechCorp|OpenAI|GitHub|Google)", text)):
        entities.append({"type": "Organization", "name": org})

    # Technology terms: fixed whitelist, deduplicated.
    for tech in set(re.findall(r"(Python|TensorFlow|PyTorch|GPT-4|AI|LLM)", text)):
        entities.append({"type": "Technology", "name": tech})

    return json.dumps(entities, ensure_ascii=False)
|
||||
|
||||
|
||||
@tool()
async def query_knowledge_graph(query_type: str, params: str) -> str:
    """Query the knowledge graph.

    Args:
        query_type: Query kind (person_relations, company_employees,
            technology_users, etc.)
        params: JSON-encoded query parameters.

    Returns:
        JSON-encoded query result.
    """
    # NOTE: a real implementation would consult an injected
    # KnowledgeGraph instance; this stub only echoes the parsed request.
    try:
        params_dict = json.loads(params)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON params"})

    result = {"query_type": query_type, "params": params_dict, "results": []}
    return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
@tool()
async def reason_about_path(start_entity: str, end_entity: str) -> str:
    """Reason about the relation path between two entities.

    Args:
        start_entity: Name of the starting entity.
        end_entity: Name of the target entity.

    Returns:
        JSON-encoded reasoning stub (path is always empty here).
    """
    payload = {
        "start": start_entity,
        "end": end_entity,
        "reasoning": f"分析 {start_entity} 到 {end_entity} 的关系链...",
        "path": [],
    }
    return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.tools
class TestDocumentTools:
    """Tests for the document tools."""

    @pytest.mark.asyncio
    async def test_read_document(self, data_dir: Path, sample_article: str):
        """read_document returns the article's text."""
        doc_path = data_dir / "documents" / "sample_article.md"
        result = tool_output(await read_document(str(doc_path)))

        assert "人工智能" in result
        assert "GitHub Copilot" in result
        assert "张三" in result

    @pytest.mark.asyncio
    async def test_read_document_not_found(self):
        """Reading a missing document yields an error string."""
        result = tool_output(await read_document("/nonexistent/file.md"))

        assert "Error" in result
        assert "not found" in result.lower()

    @pytest.mark.asyncio
    async def test_extract_entities_from_article(self, sample_article: str):
        """Entities are extracted from the sample article."""
        result = tool_output(await extract_entities(sample_article))
        entities = json.loads(result)

        # At least one entity must be found.
        assert len(entities) > 0

        # Known names must be among the extracted entities.
        entity_names = [e["name"] for e in entities]
        assert "张三" in entity_names
        assert "TechCorp" in entity_names or "OpenAI" in entity_names
|
||||
|
||||
|
||||
@pytest.mark.tools
class TestKnowledgeGraphTools:
    """Tests for the knowledge-graph tools."""

    def test_knowledge_graph_initialization(self, knowledge_graph: KnowledgeGraph):
        """The graph is built and entities are resolvable by name."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None
        assert person["type"] == "Person"

        company = knowledge_graph.get_entity_by_name("TechCorp")
        assert company is not None
        assert company["type"] == "Company"

    def test_get_entities_by_type(self, knowledge_graph: KnowledgeGraph):
        """Entities can be fetched by type."""
        persons = knowledge_graph.get_entities_by_type("Person")
        assert len(persons) >= 3  # e.g. 张三, 李四, 李飞飞

        technologies = knowledge_graph.get_entities_by_type("Technology")
        assert len(technologies) >= 2  # e.g. Python, OpenAI API

    def test_get_relations(self, knowledge_graph: KnowledgeGraph):
        """Outgoing relations of an entity are returned."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None

        relations = knowledge_graph.get_relations(from_id=person["id"])
        assert len(relations) >= 2  # works_for, uses

        relation_types = [r["type"] for r in relations]
        assert "works_for" in relation_types
        assert "uses" in relation_types

    def test_get_neighbors(self, knowledge_graph: KnowledgeGraph):
        """Neighbor entities are reachable through outgoing relations."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None

        neighbors = knowledge_graph.get_neighbors(person["id"])
        assert len(neighbors) >= 2

        neighbor_names = [n["entity"]["name"] for n in neighbors]
        assert "TechCorp" in neighbor_names

    def test_find_path(self, knowledge_graph: KnowledgeGraph):
        """A direct path exists between a person and their employer."""
        person = knowledge_graph.get_entity_by_name("张三")
        company = knowledge_graph.get_entity_by_name("TechCorp")

        assert person is not None
        assert company is not None

        paths = knowledge_graph.find_path(person["id"], company["id"])
        assert paths is not None
        assert len(paths) > 0

        # Shortest path is one hop: person -> company.
        assert len(paths[0]) == 2

    @pytest.mark.asyncio
    async def test_query_knowledge_graph(self):
        """The query tool echoes the parsed request."""
        params = json.dumps({"entity_name": "张三"})
        result = tool_output(await query_knowledge_graph("person_relations", params))

        data = json.loads(result)
        assert data["query_type"] == "person_relations"
        assert "params" in data

    @pytest.mark.asyncio
    async def test_reason_about_path(self):
        """The reasoning tool reports start, end and a reasoning field."""
        result = tool_output(await reason_about_path("张三", "OpenAI"))

        data = json.loads(result)
        assert data["start"] == "张三"
        assert data["end"] == "OpenAI"
        assert "reasoning" in data
|
||||
|
||||
|
||||
@pytest.mark.tools
class TestDataIntegrity:
    """Sanity checks on the bundled test data."""

    def test_entities_json_valid(self, knowledge_graph_entities: dict):
        """The entities JSON has the expected shape."""
        assert "entities" in knowledge_graph_entities
        entities = knowledge_graph_entities["entities"]
        assert len(entities) > 0

        # Every entity must carry the mandatory fields.
        for entity in entities:
            assert "id" in entity
            assert "type" in entity
            assert "name" in entity

    def test_relations_json_valid(
        self, knowledge_graph_relations: dict, knowledge_graph_entities: dict
    ):
        """Relations are well-formed and reference existing entities."""
        assert "relations" in knowledge_graph_relations
        relations = knowledge_graph_relations["relations"]

        entity_ids = {e["id"] for e in knowledge_graph_entities["entities"]}

        for relation in relations:
            assert "from" in relation
            assert "to" in relation
            assert "type" in relation

            # Both endpoints must be declared entities.
            assert relation["from"] in entity_ids, f"Entity {relation['from']} not found"
            assert relation["to"] in entity_ids, f"Entity {relation['to']} not found"

    def test_graph_queries_yaml_valid(self, graph_queries: list):
        """Each query case declares the required keys."""
        assert len(graph_queries) > 0

        for query in graph_queries:
            assert "id" in query
            assert "description" in query
            assert "query" in query
            assert "expected_results" in query

    def test_documents_exist(self, data_dir: Path):
        """All test documents exist and are non-empty."""
        docs_dir = data_dir / "documents"

        for filename in ("sample_article.md", "technical_spec.md", "meeting_notes.txt"):
            doc = docs_dir / filename
            assert doc.exists()
            assert doc.stat().st_size > 0
|
||||
|
||||
|
||||
@pytest.mark.tools
class TestAgentWithTools:
    """Integration tests wiring the tools into an Agent."""

    @pytest.mark.asyncio
    async def test_agent_with_document_tools(self, mock_provider, data_dir: Path):
        """An Agent equipped with the document tool produces a response."""
        mock_provider.add_text_response("我已经读取了文档")

        agent = Agent(
            provider=mock_provider,
            tools=[read_document],
            system_prompt="你是一个文档助手,可以读取和分析文档。",
        )

        response = await agent.run(f"请读取文档 {data_dir / 'documents' / 'sample_article.md'}")

        assert "文档" in response or "读取" in response

    @pytest.mark.asyncio
    async def test_agent_with_kg_tools(self, mock_provider):
        """An Agent equipped with the KG tools produces a response."""
        mock_provider.add_text_response("张三在 TechCorp 工作")

        agent = Agent(
            provider=mock_provider,
            tools=[query_knowledge_graph, reason_about_path],
            system_prompt="你是一个知识图谱助手,可以查询实体关系。",
        )

        response = await agent.run("张三在哪里工作?")

        assert response is not None
        assert len(response) > 0
|
||||
0
agentlite/tests/unit/__init__.py
Normal file
0
agentlite/tests/unit/__init__.py
Normal file
330
agentlite/tests/unit/test_config.py
Normal file
330
agentlite/tests/unit/test_config.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Unit tests for configuration models.
|
||||
|
||||
This module tests all Pydantic configuration models including
|
||||
ProviderConfig, ModelConfig, ToolConfig, and AgentConfig.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from agentlite import ProviderConfig, ModelConfig, AgentConfig
|
||||
|
||||
|
||||
class TestProviderConfig:
    """Tests for ProviderConfig."""

    def test_provider_config_valid(self):
        """A fully specified ProviderConfig is accepted."""
        cfg = ProviderConfig(
            type="openai",
            base_url="https://api.openai.com/v1",
            api_key="sk-test123",
        )
        assert cfg.type == "openai"
        assert cfg.base_url == "https://api.openai.com/v1"
        assert cfg.api_key.get_secret_value() == "sk-test123"

    def test_provider_config_default_type(self):
        """The provider type defaults to openai."""
        cfg = ProviderConfig(
            base_url="https://api.openai.com/v1",
            api_key="sk-test",
        )
        assert cfg.type == "openai"

    def test_provider_config_default_url(self):
        """The base_url defaults to the OpenAI endpoint."""
        cfg = ProviderConfig(api_key="sk-test")
        assert cfg.base_url == "https://api.openai.com/v1"

    def test_provider_config_invalid_url_http(self):
        """A non-HTTP(S) URL scheme is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            ProviderConfig(
                type="openai",
                base_url="ftp://invalid.com",
                api_key="sk-test",
            )
        assert "base_url must start with http:// or https://" in str(exc_info.value)

    def test_provider_config_invalid_url_no_scheme(self):
        """A URL without any scheme is rejected."""
        with pytest.raises(ValidationError):
            ProviderConfig(
                base_url="api.openai.com/v1",
                api_key="sk-test",
            )

    def test_provider_config_custom_headers(self):
        """Custom headers are stored verbatim."""
        cfg = ProviderConfig(
            api_key="sk-test",
            headers={"X-Custom": "value"},
        )
        assert cfg.headers == {"X-Custom": "value"}

    def test_provider_config_default_headers(self):
        """Headers default to an empty dict."""
        cfg = ProviderConfig(api_key="sk-test")
        assert cfg.headers == {}

    def test_provider_config_timeout(self):
        """An explicit timeout is stored."""
        cfg = ProviderConfig(
            api_key="sk-test",
            timeout=30.0,
        )
        assert cfg.timeout == 30.0

    def test_provider_config_default_timeout(self):
        """The timeout defaults to 60 seconds."""
        cfg = ProviderConfig(api_key="sk-test")
        assert cfg.timeout == 60.0

    def test_provider_config_api_key_is_secret_str(self):
        """The api_key is a SecretStr that never leaks via str()."""
        cfg = ProviderConfig(api_key="sk-secret")
        # SecretStr masks the value in repr/str ...
        assert "sk-secret" not in str(cfg.api_key)
        # ... but the value is still retrievable explicitly.
        assert cfg.api_key.get_secret_value() == "sk-secret"
|
||||
|
||||
|
||||
class TestModelConfig:
    """Tests for ModelConfig."""

    def test_model_config_valid(self):
        """A minimal ModelConfig is accepted."""
        cfg = ModelConfig(
            provider="openai",
            model="gpt-4",
        )
        assert cfg.provider == "openai"
        assert cfg.model == "gpt-4"

    def test_model_config_with_all_fields(self):
        """All optional fields round-trip unchanged."""
        cfg = ModelConfig(
            provider="openai",
            model="gpt-4",
            max_tokens=1000,
            temperature=0.7,
            top_p=0.9,
            capabilities={"streaming", "tool_calling"},
        )
        assert cfg.max_tokens == 1000
        assert cfg.temperature == 0.7
        assert cfg.top_p == 0.9
        assert cfg.capabilities == {"streaming", "tool_calling"}

    def test_model_config_empty_provider(self):
        """An empty provider name is rejected."""
        with pytest.raises(ValidationError) as exc_info:
            ModelConfig(
                provider="",
                model="gpt-4",
            )
        assert "provider must not be empty" in str(exc_info.value)

    def test_model_config_temperature_bounds(self):
        """Temperature must lie in [0.0, 2.0]."""
        # Boundary values are accepted.
        assert ModelConfig(provider="openai", model="gpt-4", temperature=0.0).temperature == 0.0
        assert ModelConfig(provider="openai", model="gpt-4", temperature=2.0).temperature == 2.0

        # Values just outside the range are rejected.
        with pytest.raises(ValidationError):
            ModelConfig(provider="openai", model="gpt-4", temperature=-0.1)
        with pytest.raises(ValidationError):
            ModelConfig(provider="openai", model="gpt-4", temperature=2.1)

    def test_model_config_top_p_bounds(self):
        """top_p must lie in [0.0, 1.0]."""
        # Boundary values are accepted.
        assert ModelConfig(provider="openai", model="gpt-4", top_p=0.0).top_p == 0.0
        assert ModelConfig(provider="openai", model="gpt-4", top_p=1.0).top_p == 1.0

        # Values just outside the range are rejected.
        with pytest.raises(ValidationError):
            ModelConfig(provider="openai", model="gpt-4", top_p=-0.1)
        with pytest.raises(ValidationError):
            ModelConfig(provider="openai", model="gpt-4", top_p=1.1)

    def test_model_config_default_capabilities(self):
        """Capabilities default to an empty set."""
        cfg = ModelConfig(provider="openai", model="gpt-4")
        assert cfg.capabilities == set()
|
||||
|
||||
|
||||
class TestAgentConfig:
    """Unit tests for AgentConfig construction, validation, and lookups."""

    @staticmethod
    def _openai_provider():
        """Build the single-provider mapping shared by most cases."""
        return {"openai": ProviderConfig(api_key="sk-test")}

    def test_agent_config_minimal(self):
        """Only providers and models are required; other fields default."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
        )
        assert cfg.name == "agent"
        assert cfg.system_prompt == "You are a helpful assistant."
        assert cfg.default_model == "default"

    def test_agent_config_full(self):
        """Every field round-trips when supplied explicitly."""
        cfg = AgentConfig(
            name="my_agent",
            system_prompt="Custom system prompt",
            providers=self._openai_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
            max_history=50,
        )
        assert cfg.name == "my_agent"
        assert cfg.system_prompt == "Custom system prompt"
        assert cfg.default_model == "gpt4"
        assert cfg.max_history == 50

    def test_agent_config_missing_default_model(self):
        """default_model must name an entry in models."""
        with pytest.raises(ValidationError) as exc_info:
            AgentConfig(
                providers=self._openai_provider(),
                models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
                default_model="nonexistent",
            )
        assert "not found in models" in str(exc_info.value)

    def test_agent_config_unknown_provider(self):
        """A model may only reference a configured provider."""
        with pytest.raises(ValidationError) as exc_info:
            AgentConfig(
                providers=self._openai_provider(),
                models={"default": ModelConfig(provider="unknown", model="gpt-4")},
            )
        assert "unknown provider" in str(exc_info.value)

    def test_agent_config_get_provider_config(self):
        """get_provider_config resolves the provider behind a model key."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        provider_cfg = cfg.get_provider_config("gpt4")
        assert provider_cfg.api_key.get_secret_value() == "sk-test"

    def test_agent_config_get_provider_config_default(self):
        """With no argument, get_provider_config uses default_model."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        provider_cfg = cfg.get_provider_config()
        assert provider_cfg.api_key.get_secret_value() == "sk-test"

    def test_agent_config_get_provider_config_not_found(self):
        """Unknown model keys raise a descriptive ValueError."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
        )
        with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
            cfg.get_provider_config("nonexistent")

    def test_agent_config_get_model_config(self):
        """get_model_config returns the ModelConfig behind a key."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        assert cfg.get_model_config("gpt4").model == "gpt-4"

    def test_agent_config_get_model_config_default(self):
        """With no argument, get_model_config uses default_model."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        assert cfg.get_model_config().model == "gpt-4"

    def test_agent_config_get_model_config_not_found(self):
        """Unknown model keys raise a descriptive ValueError."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
        )
        with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
            cfg.get_model_config("nonexistent")

    def test_agent_config_multiple_providers(self):
        """Several providers and models can coexist in one config."""
        cfg = AgentConfig(
            providers={
                "openai": ProviderConfig(api_key="sk-openai"),
                "anthropic": ProviderConfig(
                    type="anthropic",
                    base_url="https://api.anthropic.com/v1",
                    api_key="sk-anthropic",
                ),
            },
            models={
                "default": ModelConfig(provider="openai", model="gpt-4"),
                "claude": ModelConfig(provider="anthropic", model="claude-3"),
            },
        )
        assert len(cfg.providers) == 2
        assert len(cfg.models) == 2

    def test_agent_config_max_history_validation(self):
        """max_history must be a positive integer (minimum 1)."""
        cfg = AgentConfig(
            providers=self._openai_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
            max_history=1,
        )
        assert cfg.max_history == 1
        # Zero and negative values are both rejected.
        for bad_history in (0, -1):
            with pytest.raises(ValidationError):
                AgentConfig(
                    providers=self._openai_provider(),
                    models={"default": ModelConfig(provider="openai", model="gpt-4")},
                    max_history=bad_history,
                )
|
||||
297
agentlite/tests/unit/test_message.py
Normal file
297
agentlite/tests/unit/test_message.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""Unit tests for message types.
|
||||
|
||||
This module tests all message-related types including ContentPart,
|
||||
Message, ToolCall, and their various subclasses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite import (
|
||||
ContentPart,
|
||||
Message,
|
||||
TextPart,
|
||||
ImageURLPart,
|
||||
AudioURLPart,
|
||||
ToolCall,
|
||||
ToolCallPart,
|
||||
)
|
||||
|
||||
|
||||
class TestContentPart:
    """Tests for the ContentPart base class and its subclass registry."""

    def test_content_part_registry_auto_registers_subclasses(self):
        """Defining a ContentPart subclass registers it under its type tag."""
        # Reach through the name-mangled private registry on purpose.
        registry = ContentPart._ContentPart__content_part_registry
        for tag in ("text", "image_url", "audio_url"):
            assert tag in registry

    def test_text_part_creation(self):
        """TextPart stores its text and reports type 'text'."""
        part = TextPart(text="Hello, world!")
        assert part.type == "text"
        assert part.text == "Hello, world!"

    def test_text_part_model_dump(self):
        """TextPart serializes to a plain type/text mapping."""
        assert TextPart(text="Hello").model_dump() == {"type": "text", "text": "Hello"}

    def test_text_part_merge_success(self):
        """Merging two TextParts concatenates text into the receiver."""
        left = TextPart(text="Hello ")
        right = TextPart(text="world!")
        assert left.merge_in_place(right) is True
        assert left.text == "Hello world!"

    def test_text_part_merge_failure(self):
        """Merging a non-TextPart fails and leaves the receiver untouched."""
        part = TextPart(text="Hello")
        assert part.merge_in_place("not a part") is False
        assert part.text == "Hello"
|
||||
|
||||
|
||||
class TestImageURLPart:
    """Tests for ImageURLPart."""

    # URL shared by every case in this class.
    URL = "https://example.com/image.png"

    def test_image_url_part_creation(self):
        """ImageURLPart wraps a URL and reports type 'image_url'."""
        part = ImageURLPart(image_url=ImageURLPart.ImageURL(url=self.URL))
        assert part.type == "image_url"
        assert part.image_url.url == self.URL

    def test_image_url_part_with_detail(self):
        """The optional detail hint is preserved."""
        inner = ImageURLPart.ImageURL(url=self.URL, detail="high")
        assert ImageURLPart(image_url=inner).image_url.detail == "high"

    def test_image_url_part_default_detail(self):
        """detail defaults to None when not supplied."""
        part = ImageURLPart(image_url=ImageURLPart.ImageURL(url=self.URL))
        assert part.image_url.detail is None
|
||||
|
||||
|
||||
class TestAudioURLPart:
    """Tests for AudioURLPart."""

    def test_audio_url_part_creation(self):
        """AudioURLPart wraps a URL and reports type 'audio_url'."""
        url = "https://example.com/audio.mp3"
        part = AudioURLPart(audio_url=AudioURLPart.AudioURL(url=url))
        assert part.type == "audio_url"
        assert part.audio_url.url == url
|
||||
|
||||
|
||||
class TestToolCall:
    """Tests for ToolCall construction and streaming merges."""

    def test_tool_call_creation(self):
        """ToolCall stores id and function body; type is 'function'."""
        body = ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
        call = ToolCall(id="call_123", function=body)
        assert call.type == "function"
        assert call.id == "call_123"
        assert call.function.name == "add"
        assert call.function.arguments == '{"a": 1, "b": 2}'

    def test_tool_call_merge_with_part(self):
        """A ToolCallPart appends its fragment onto the call's arguments."""
        call = ToolCall(
            id="call_123",
            function=ToolCall.FunctionBody(name="add", arguments='{"a": 1'),
        )
        fragment = ToolCallPart(arguments_part=', "b": 2}')
        assert call.merge_in_place(fragment) is True
        assert call.function.arguments == '{"a": 1, "b": 2}'

    def test_tool_call_merge_failure(self):
        """Merging an incompatible value is rejected."""
        call = ToolCall(
            id="call_123",
            function=ToolCall.FunctionBody(name="add", arguments="{}"),
        )
        assert call.merge_in_place("not a part") is False
|
||||
|
||||
|
||||
class TestToolCallPart:
    """Tests for ToolCallPart streaming fragments."""

    def test_tool_call_part_creation(self):
        """A fragment stores its arguments text."""
        assert ToolCallPart(arguments_part='{"a": 1}').arguments_part == '{"a": 1}'

    def test_tool_call_part_none(self):
        """arguments_part may be None (an empty fragment)."""
        assert ToolCallPart(arguments_part=None).arguments_part is None

    def test_tool_call_part_merge(self):
        """Merging concatenates fragment text onto the receiver."""
        head = ToolCallPart(arguments_part='{"a":')
        assert head.merge_in_place(ToolCallPart(arguments_part=" 1}")) is True
        assert head.arguments_part == '{"a": 1}'

    def test_tool_call_part_merge_none(self):
        """Merging into a None fragment adopts the incoming text."""
        head = ToolCallPart(arguments_part=None)
        assert head.merge_in_place(ToolCallPart(arguments_part='{"a": 1}')) is True
        assert head.arguments_part == '{"a": 1}'
|
||||
|
||||
|
||||
class TestMessage:
    """Tests for Message construction, text extraction, and tool calls."""

    def test_message_string_content_coercion(self):
        """Plain string content is wrapped in a single TextPart."""
        msg = Message(role="user", content="Hello!")
        assert len(msg.content) == 1
        assert isinstance(msg.content[0], TextPart)
        assert msg.content[0].text == "Hello!"

    def test_message_part_content(self):
        """A single ContentPart is wrapped in a one-element list."""
        msg = Message(role="user", content=TextPart(text="Hello!"))
        assert len(msg.content) == 1
        assert msg.content[0].text == "Hello!"

    def test_message_list_content(self):
        """A list of parts is stored as-is."""
        pieces = [TextPart(text="Hello"), TextPart(text=" world!")]
        assert len(Message(role="user", content=pieces).content) == 2

    def test_message_extract_text(self):
        """extract_text returns the concatenated text content."""
        msg = Message(role="user", content="Hello world!")
        assert msg.extract_text() == "Hello world!"

    def test_message_extract_text_with_separator(self):
        """extract_text joins parts with the given separator."""
        pieces = [TextPart(text="Hello"), TextPart(text="world!")]
        msg = Message(role="user", content=pieces)
        assert msg.extract_text(sep=" ") == "Hello world!"
        assert msg.extract_text(sep="-") == "Hello-world!"

    def test_message_has_tool_calls_false(self):
        """Omitting tool_calls means has_tool_calls() is False."""
        assert Message(role="assistant", content="Hello!").has_tool_calls() is False

    def test_message_has_tool_calls_true(self):
        """A non-empty tool_calls list makes has_tool_calls() True."""
        call = ToolCall(
            id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}")
        )
        msg = Message(
            role="assistant", content="Let me calculate that.", tool_calls=[call]
        )
        assert msg.has_tool_calls() is True

    def test_message_has_tool_calls_empty_list(self):
        """An explicitly empty tool_calls list still counts as no calls."""
        msg = Message(role="assistant", content="Hello!", tool_calls=[])
        assert msg.has_tool_calls() is False

    def test_message_tool_response(self):
        """Tool responses carry role='tool' plus the originating call id."""
        msg = Message(role="tool", content="Result: 42", tool_call_id="call_123")
        assert msg.role == "tool"
        assert msg.tool_call_id == "call_123"

    def test_message_serialization(self):
        """model_dump emits at least the role and content keys."""
        dumped = Message(role="user", content="Hello!").model_dump()
        assert dumped["role"] == "user"
        assert "content" in dumped

    def test_message_all_roles(self):
        """Every supported role constructs without error."""
        for role in ("system", "user", "assistant", "tool"):
            assert Message(role=role, content="Test").role == role
|
||||
|
||||
|
||||
class TestPolymorphicContentPart:
    """Tests for polymorphic validation dispatched on the 'type' tag."""

    def test_polymorphic_validation_text(self):
        """'text' payloads validate into TextPart instances."""
        part = ContentPart.model_validate({"type": "text", "text": "Hello"})
        assert isinstance(part, TextPart)
        assert part.text == "Hello"

    def test_polymorphic_validation_image(self):
        """'image_url' payloads validate into ImageURLPart instances."""
        payload = {
            "type": "image_url",
            "image_url": {"url": "https://example.com/image.png"},
        }
        part = ContentPart.model_validate(payload)
        assert isinstance(part, ImageURLPart)
        assert part.image_url.url == "https://example.com/image.png"

    def test_polymorphic_validation_unknown_type(self):
        """An unregistered type tag is rejected with a clear message."""
        with pytest.raises(ValueError, match="Unknown content part type"):
            ContentPart.model_validate({"type": "unknown_type", "content": "test"})

    def test_polymorphic_validation_no_type(self):
        """A payload without a type tag cannot be validated."""
        with pytest.raises(ValueError):
            ContentPart.model_validate({"content": "test"})
|
||||
|
||||
|
||||
class TestMessageEdgeCases:
    """Tests for edge cases in Message handling."""

    def test_empty_string_content(self):
        """An empty string still coerces to a single (empty) TextPart."""
        msg = Message(role="user", content="")
        assert msg.content[0].text == ""

    def test_message_with_name(self):
        """The optional name field is stored on the message."""
        msg = Message(role="user", content="Hello", name="user1")
        assert msg.name == "user1"

    def test_message_content_is_live_list(self):
        """Message.content is a live, mutable list -- no defensive copy.

        The original docstring claimed modifications "don't affect" the
        message, but the assertion checks the opposite: appending to
        ``msg.content`` is visible on the message itself.  Callers that
        need isolation must copy the list before mutating it.
        """
        msg = Message(role="user", content="Hello")
        msg.content.append(TextPart(text="Extra"))
        # Same list object, so the append is observable on the message.
        assert len(msg.content) == 2

    # Backward-compatible alias: keep the old (misleadingly named) test id
    # so external tooling that selects it by name still finds a test.
    test_message_history_isolation = test_message_content_is_live_list
|
||||
168
agentlite/tests/unit/test_provider.py
Normal file
168
agentlite/tests/unit/test_provider.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Unit tests for provider protocol and exceptions.
|
||||
|
||||
This module tests the ChatProvider protocol, StreamedMessage protocol,
|
||||
and all exception types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite.provider import (
|
||||
TokenUsage,
|
||||
ChatProviderError,
|
||||
APIConnectionError,
|
||||
APITimeoutError,
|
||||
APIStatusError,
|
||||
APIEmptyResponseError,
|
||||
ChatProvider,
|
||||
StreamedMessage,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenUsage:
    """Tests for the TokenUsage accounting record."""

    def test_token_usage_creation(self):
        """input/output counts are stored; cached defaults to zero."""
        usage = TokenUsage(input_tokens=100, output_tokens=50)
        assert usage.input_tokens == 100
        assert usage.output_tokens == 50
        assert usage.cached_tokens == 0

    def test_token_usage_with_cached(self):
        """An explicit cached_tokens value is preserved."""
        usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
        assert usage.cached_tokens == 20

    def test_token_usage_total(self):
        """total is the sum of input and output tokens."""
        assert TokenUsage(input_tokens=100, output_tokens=50).total == 150

    def test_token_usage_total_with_cached(self):
        """cached_tokens is tracked separately and excluded from total."""
        usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
        assert usage.total == 150
|
||||
|
||||
|
||||
class TestChatProviderError:
    """Tests for the provider exception hierarchy."""

    def test_base_error_creation(self):
        """The base error exposes its message via attribute and str()."""
        err = ChatProviderError("Something went wrong")
        assert err.message == "Something went wrong"
        assert str(err) == "Something went wrong"

    def test_api_connection_error(self):
        """APIConnectionError is a ChatProviderError carrying a message."""
        err = APIConnectionError("Connection failed")
        assert isinstance(err, ChatProviderError)
        assert err.message == "Connection failed"

    def test_api_timeout_error(self):
        """APITimeoutError is a ChatProviderError carrying a message."""
        err = APITimeoutError("Request timed out")
        assert isinstance(err, ChatProviderError)
        assert err.message == "Request timed out"

    def test_api_status_error(self):
        """APIStatusError records the HTTP status code and message."""
        err = APIStatusError(429, "Rate limit exceeded")
        assert isinstance(err, ChatProviderError)
        assert err.status_code == 429
        assert err.message == "Rate limit exceeded"

    def test_api_status_error_different_codes(self):
        """status_code round-trips for common client and server codes."""
        for code in (400, 401, 403, 404, 429, 500, 502, 503):
            assert APIStatusError(code, f"Error {code}").status_code == code

    def test_api_empty_response_error(self):
        """APIEmptyResponseError is a ChatProviderError with a message."""
        err = APIEmptyResponseError("Empty response from API")
        assert isinstance(err, ChatProviderError)
        assert err.message == "Empty response from API"

    def test_exception_hierarchy(self):
        """Every concrete error derives from ChatProviderError."""
        concrete = (
            APIConnectionError("test"),
            APITimeoutError("test"),
            APIStatusError(500, "test"),
            APIEmptyResponseError("test"),
        )
        for err in concrete:
            assert isinstance(err, ChatProviderError)
|
||||
|
||||
|
||||
class TestChatProviderProtocol:
    """Tests for the ChatProvider structural protocol."""

    def test_protocol_is_runtime_checkable(self):
        """ChatProvider must be decorated with @runtime_checkable.

        runtime_checkable protocols grow a ``__protocol_attrs__``
        attribute, so its presence is the observable marker.  (The
        previously unused ``from typing import runtime_checkable``
        import has been removed.)
        """
        assert hasattr(ChatProvider, "__protocol_attrs__")

    def test_mock_provider_implements_protocol(self, mock_provider):
        """The conftest MockProvider satisfies isinstance(ChatProvider)."""
        assert isinstance(mock_provider, ChatProvider)

    def test_mock_provider_has_model_name(self, mock_provider):
        """Providers expose a string model_name property."""
        assert hasattr(mock_provider, "model_name")
        assert isinstance(mock_provider.model_name, str)

    def test_mock_provider_has_generate_method(self, mock_provider):
        """Providers expose a callable generate method."""
        assert hasattr(mock_provider, "generate")
        assert callable(mock_provider.generate)
|
||||
|
||||
|
||||
class TestStreamedMessageProtocol:
    """Tests for the StreamedMessage structural protocol."""

    @staticmethod
    def _make_stream():
        """Build the one-part mock stream shared by these tests.

        The imports are local (as in the original tests) so collection
        does not fail if conftest changes; this helper removes the
        four-fold duplication of the same setup.
        """
        from tests.conftest import MockStreamedMessage
        from agentlite import TextPart

        return MockStreamedMessage([TextPart(text="Hello")])

    def test_protocol_is_runtime_checkable(self):
        """StreamedMessage must be decorated with @runtime_checkable."""
        assert hasattr(StreamedMessage, "__protocol_attrs__")

    def test_mock_streamed_message_implements_protocol(self):
        """MockStreamedMessage satisfies isinstance(StreamedMessage)."""
        assert isinstance(self._make_stream(), StreamedMessage)

    def test_streamed_message_has_id_property(self):
        """Streams expose a stable id."""
        stream = self._make_stream()
        assert hasattr(stream, "id")
        assert stream.id == "mock-msg-123"

    def test_streamed_message_has_usage_property(self):
        """Streams expose a non-None TokenUsage."""
        stream = self._make_stream()
        assert hasattr(stream, "usage")
        assert stream.usage is not None
        assert isinstance(stream.usage, TokenUsage)

    def test_streamed_message_is_async_iterable(self):
        """Streams support async iteration."""
        assert hasattr(self._make_stream(), "__aiter__")
|
||||
210
agentlite/tests/unit/test_tool.py
Normal file
210
agentlite/tests/unit/test_tool.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Unit tests for tool decorator and CallableTool.
|
||||
|
||||
This module tests the @tool() decorator and related tool functionality.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentlite.tool import tool, CallableTool, ToolOk, ToolError
|
||||
|
||||
|
||||
class TestToolDecorator:
    """Tests for schema generation performed by the @tool() decorator."""

    def test_tool_decorator_basic(self):
        """A plain async function yields a complete JSON-schema tool."""

        @tool()
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        assert isinstance(add, CallableTool)
        assert add.name == "add"
        assert add.description == "Add two numbers."
        schema = add.parameters
        assert schema["type"] == "object"
        for param in ("a", "b"):
            assert param in schema["properties"]
            assert schema["properties"][param]["type"] == "number"
        assert schema["required"] == ["a", "b"]

    def test_tool_decorator_with_default_params(self):
        """Parameters with defaults are omitted from 'required'."""

        @tool()
        async def greet(name: str, greeting: str = "Hello") -> str:
            """Greet someone."""
            return f"{greeting}, {name}!"

        assert greet.name == "greet"
        required = greet.parameters["required"]
        assert "name" in required
        assert "greeting" not in required

    def test_tool_decorator_custom_name(self):
        """An explicit name overrides the function's own name."""

        @tool(name="custom_add")
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        assert add.name == "custom_add"

    def test_tool_decorator_custom_description(self):
        """An explicit description overrides the docstring."""

        @tool(description="Custom description")
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        assert add.description == "Custom description"

    def test_tool_decorator_no_docstring(self):
        """A missing docstring falls back to a placeholder description."""

        @tool()
        async def no_doc(a: float) -> float:
            return a

        assert no_doc.description == "No description provided"

    def test_tool_decorator_param_types(self):
        """Python annotations map onto the matching JSON-schema types."""

        @tool()
        async def multi_types(
            s: str,
            i: int,
            f: float,
            b: bool,
        ) -> dict:
            """Multiple types."""
            return {"s": s, "i": i, "f": f, "b": b}

        props = multi_types.parameters["properties"]
        expected = {"s": "string", "i": "integer", "f": "number", "b": "boolean"}
        for param, json_type in expected.items():
            assert props[param]["type"] == json_type

    def test_tool_decorator_no_type_hints(self):
        """Unannotated parameters default to the string type."""

        @tool()
        async def no_types(param) -> str:
            """No type hints."""
            return str(param)

        assert no_types.parameters["properties"]["param"]["type"] == "string"
|
||||
|
||||
|
||||
class TestToolDecoratorExecution:
    """Tests for invoking decorated tools."""

    @pytest.mark.asyncio
    async def test_tool_execution_success(self):
        """A successful call returns ToolOk with stringified output."""

        @tool()
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        outcome = await add(1.0, 2.0)
        assert isinstance(outcome, ToolOk)
        assert outcome.output == "3.0"

    @pytest.mark.asyncio
    async def test_tool_execution_error(self):
        """An exception inside the tool surfaces as ToolError."""

        @tool()
        async def divide(a: float, b: float) -> float:
            """Divide two numbers."""
            return a / b

        outcome = await divide(1.0, 0.0)
        assert isinstance(outcome, ToolError)
        assert "division by zero" in outcome.message

    @pytest.mark.asyncio
    async def test_tool_execution_with_kwargs(self):
        """Tools accept keyword arguments like regular functions."""

        @tool()
        async def greet(name: str, greeting: str = "Hello") -> str:
            """Greet someone."""
            return f"{greeting}, {name}!"

        outcome = await greet(name="World", greeting="Hi")
        assert isinstance(outcome, ToolOk)
        assert outcome.output == "Hi, World!"
|
||||
|
||||
|
||||
class TestToolDecoratorMemorixBug:
    """Tests for the specific bug reported by Memorix project."""

    def test_tool_decorator_memorix_case(self):
        """Test the exact case from Memorix bug report.

        This test verifies that the @tool() decorator works correctly
        with async functions that have string and float parameters.
        """

        @tool()
        async def add_memory(content: str, importance: float = 0.5) -> dict:
            # NOTE: this non-ASCII docstring ("store memory") is asserted
            # verbatim below -- do not translate or reformat it.
            """存储记忆"""
            return {"status": "ok"}

        assert isinstance(add_memory, CallableTool)
        assert add_memory.name == "add_memory"
        # Description is taken verbatim from the function's docstring.
        assert add_memory.description == "存储记忆"

        # Check parameters schema
        params = add_memory.parameters
        assert params["type"] == "object"
        assert "content" in params["properties"]
        assert "importance" in params["properties"]
        assert params["properties"]["content"]["type"] == "string"
        assert params["properties"]["importance"]["type"] == "number"

        # content is required (no default), importance is optional
        assert "content" in params["required"]
        assert "importance" not in params["required"]

    @pytest.mark.asyncio
    async def test_tool_decorator_memorix_execution(self):
        """Test execution of the Memorix case."""

        @tool()
        async def add_memory(content: str, importance: float = 0.5) -> dict:
            """存储记忆"""
            return {"status": "ok", "content": content, "importance": importance}

        # Positional args map onto (content, importance); the ToolOk
        # output is the stringified return value.
        result = await add_memory("test content", 0.8)
        assert isinstance(result, ToolOk)
        assert "ok" in result.output

    def test_tool_decorator_can_be_used_in_agent(self):
        """Test that decorated tools can be used with Agent.

        This is an integration-style test to ensure the decorated tool
        has all required attributes for Agent usage.
        """
        # NOTE(review): Agent and OpenAIProvider are imported but never
        # used below -- presumably an import-availability smoke check;
        # confirm before removing.
        from agentlite import Agent, OpenAIProvider

        @tool()
        async def add_memory(content: str, importance: float = 0.5) -> dict:
            """存储记忆"""
            return {"status": "ok"}

        # Verify the tool has the base property required by Agent
        assert hasattr(add_memory, "base")
        base_tool = add_memory.base
        assert base_tool.name == "add_memory"
        assert base_tool.description == "存储记忆"
        assert base_tool.parameters == add_memory.parameters
|
||||
98
agentlite/tests/utils.py
Normal file
98
agentlite/tests/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Test utilities and helpers for AgentLite tests.
|
||||
|
||||
This module provides utility functions and helpers used across test modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import asyncio
from collections.abc import Coroutine
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def run_async(coro: Coroutine[Any, Any, T]) -> T:
    """Await *coro* and return its result.

    Thin helper for tests that already run inside an event loop and want
    an explicit await point. Uses ``collections.abc.Coroutine`` — the
    old ``asyncio.Coroutine`` alias was removed in Python 3.12 and only
    avoided a crash here because annotations are lazy in this module.

    Args:
        coro: The coroutine to await.

    Returns:
        The result of the coroutine.
    """
    return await coro
|
||||
|
||||
|
||||
def run_sync(coro: Coroutine[Any, Any, T]) -> T:
    """Run an async coroutine synchronously on a fresh event loop.

    Uses ``collections.abc.Coroutine`` — the old ``asyncio.Coroutine``
    alias was removed in Python 3.12 and only avoided a crash here
    because annotations are lazy in this module.

    Args:
        coro: The coroutine to run.

    Returns:
        The result of the coroutine.
    """
    return asyncio.run(coro)
|
||||
|
||||
|
||||
async def collect_stream(stream) -> list[Any]:
    """Drain *stream* and return every yielded item as a list.

    Args:
        stream: Any async iterable.

    Returns:
        All items from the stream, in arrival order.
    """
    return [element async for element in stream]
|
||||
|
||||
|
||||
async def collect_stream_text(stream) -> str:
    """Concatenate the textual content of an async stream.

    ``TextPart`` items contribute their ``.text``; plain strings are
    taken as-is; any other item type is silently ignored.

    Args:
        stream: The async stream to drain.

    Returns:
        The concatenated text.
    """
    # Imported lazily so merely importing this utils module does not
    # require the full agentlite package.
    from agentlite import TextPart

    chunks: list[str] = []
    async for piece in stream:
        if isinstance(piece, TextPart):
            chunks.append(piece.text)
        elif isinstance(piece, str):
            chunks.append(piece)
    return "".join(chunks)
|
||||
|
||||
|
||||
def create_tool_schema(
    name: str,
    description: str,
    properties: dict[str, Any],
    required: list[str] | None = None,
) -> dict[str, Any]:
    """Build a JSON-schema parameters object for a tool.

    Args:
        name: Tool name.
        description: Tool description.
        properties: JSON schema properties.
        required: List of required property names.

    Returns:
        JSON schema for the tool's parameters.
    """
    # NOTE(review): ``name`` and ``description`` are accepted but never
    # embedded in the returned schema — confirm whether callers expect
    # them to appear in the output.
    schema: dict[str, Any] = {"type": "object", "properties": properties}
    # An empty ``required`` list is treated the same as None: omitted.
    if required:
        schema["required"] = required
    return schema
|
||||
Reference in New Issue
Block a user