feat: add a subagent frame

tcmofashi
2026-04-03 22:15:53 +08:00
parent ce580d1f8b
commit 185361f2c3
72 changed files with 13062 additions and 0 deletions

agentlite/tests/conftest.py Normal file
@@ -0,0 +1,331 @@
"""Test configuration and shared fixtures for AgentLite tests.
This module provides pytest configuration and fixtures that are shared
across all test modules.
"""
from __future__ import annotations
import asyncio
import json
from collections.abc import AsyncIterator, Sequence
from typing import Any, Optional
import pytest
from agentlite import (
Agent,
ContentPart,
Message,
TextPart,
ToolCall,
ToolOk,
ToolError,
tool,
)
from agentlite.provider import ChatProvider, StreamedMessage, TokenUsage
from agentlite.tool import Tool, ToolResult
# =============================================================================
# pytest Configuration
# =============================================================================
def pytest_configure(config):
"""Configure pytest with custom markers."""
config.addinivalue_line("markers", "unit: Unit tests")
config.addinivalue_line("markers", "integration: Integration tests")
config.addinivalue_line("markers", "scenario: Real-world scenario tests")
config.addinivalue_line("markers", "slow: Slow tests that may take time")
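# The markers registered above can be applied per test and selected from the
# command line with pytest's -m option, e.g. `pytest -m "integration and not slow"`.
# A minimal sketch of a marked test (hypothetical name, not part of this commit):
#
#   @pytest.mark.integration
#   @pytest.mark.slow
#   def test_example_marked():
#       assert True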
# =============================================================================
# Mock Provider Implementation
# =============================================================================
class MockStreamedMessage:
"""Mock streamed message for testing."""
def __init__(self, parts: list[ContentPart]):
self._parts = parts
self._id = "mock-msg-123"
self._usage = TokenUsage(input_tokens=10, output_tokens=5)
def __aiter__(self) -> AsyncIterator[ContentPart]:
"""Return async iterator over parts."""
return self._iter_parts()
async def _iter_parts(self) -> AsyncIterator[ContentPart]:
"""Iterate over parts."""
for part in self._parts:
yield part
@property
def id(self) -> Optional[str]:
"""Message ID."""
return self._id
@property
def usage(self) -> Optional[TokenUsage]:
"""Token usage."""
return self._usage
class MockProvider:
"""Mock provider for testing AgentLite without real API calls.
This provider simulates OpenAI API responses and allows:
- Configuring response sequences
- Simulating tool calls
- Simulating errors
- Tracking all calls for verification
Example:
provider = MockProvider()
provider.add_text_response("Hello!")
provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
agent = Agent(provider=provider)
response = await agent.run("Hi")
# Verify calls
assert len(provider.calls) == 1
assert provider.calls[0]["system_prompt"] == "You are a helpful assistant."
"""
def __init__(self):
self.responses: list[dict[str, Any]] = []
self.calls: list[dict[str, Any]] = []
self.model = "mock-model"
def add_text_response(self, text: str) -> None:
"""Add a text response to the queue."""
self.responses.append({"type": "text", "content": text})
def add_text_responses(self, *texts: str) -> None:
"""Add multiple text responses to the queue."""
for text in texts:
self.add_text_response(text)
def add_tool_call(self, name: str, arguments: dict[str, Any], result: str) -> None:
"""Add a tool call response to the queue."""
self.responses.append(
{"type": "tool_call", "name": name, "arguments": arguments, "result": result}
)
def add_tool_calls(self, calls: list[dict[str, Any]]) -> None:
"""Add multiple tool calls to the queue."""
for call in calls:
self.add_tool_call(call["name"], call["arguments"], call.get("result", ""))
def add_error(self, error: Exception) -> None:
"""Add an error response to the queue."""
self.responses.append({"type": "error", "error": error})
def clear_responses(self) -> None:
"""Clear all pending responses."""
self.responses.clear()
@property
def model_name(self) -> str:
"""Model name."""
return self.model
async def generate(
self,
system_prompt: str,
tools: Sequence[Tool],
history: Sequence[Message],
) -> StreamedMessage:
"""Generate a mock response."""
self.calls.append(
{
"system_prompt": system_prompt,
"tools": list(tools),
"history": list(history),
}
)
if not self.responses:
return MockStreamedMessage([TextPart(text="Mock response")])
response = self.responses.pop(0)
if response["type"] == "error":
raise response["error"]
elif response["type"] == "text":
return MockStreamedMessage([TextPart(text=response["content"])])
elif response["type"] == "tool_call":
return MockStreamedMessage(
[
ToolCall(
id="call_123",
function=ToolCall.FunctionBody(
name=response["name"], arguments=json.dumps(response["arguments"])
),
)
]
)
else:
return MockStreamedMessage([TextPart(text="Unknown response type")])
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_provider():
"""Create a mock provider with no responses configured."""
return MockProvider()
@pytest.fixture
def mock_provider_with_response():
"""Create a mock provider that returns a simple text response."""
provider = MockProvider()
provider.add_text_response("Hello!")
return provider
@pytest.fixture
def mock_provider_with_sequence():
"""Create a mock provider with multiple responses configured."""
provider = MockProvider()
provider.add_text_responses("Response 1", "Response 2", "Response 3")
return provider
# =============================================================================
# Message Fixtures
# =============================================================================
@pytest.fixture
def sample_text_message():
"""Create a sample text message."""
return Message(role="user", content="Hello!")
@pytest.fixture
def sample_assistant_message():
"""Create a sample assistant message."""
return Message(role="assistant", content="Hi there!")
@pytest.fixture
def sample_tool_call():
"""Create a sample tool call."""
return ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
)
@pytest.fixture
def sample_tool_message():
"""Create a sample tool response message."""
return Message(role="tool", content="3", tool_call_id="call_123")
# =============================================================================
# Tool Fixtures
# =============================================================================
@pytest.fixture
def add_tool():
"""Create a simple add tool."""
@tool()
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
return add
@pytest.fixture
def multiply_tool():
"""Create a multiply tool."""
@tool()
async def multiply(a: float, b: float) -> float:
"""Multiply two numbers."""
return a * b
return multiply
@pytest.fixture
def error_tool():
"""Create a tool that raises an error."""
@tool()
async def error() -> str:
"""Always raises an error."""
raise ValueError("Test error")
return error
@pytest.fixture
def slow_tool():
"""Create a tool that takes some time."""
@tool()
async def slow_operation(duration: float = 0.1) -> str:
"""Simulate a slow operation."""
await asyncio.sleep(duration)
return f"Completed after {duration}s"
return slow_operation
# =============================================================================
# Agent Fixtures
# =============================================================================
@pytest.fixture
async def simple_agent(mock_provider):
"""Create a simple agent with mocked provider."""
return Agent(provider=mock_provider)
@pytest.fixture
async def agent_with_tools(mock_provider, add_tool):
"""Create an agent with tools."""
return Agent(provider=mock_provider, tools=[add_tool])
@pytest.fixture
async def agent_with_multiple_tools(mock_provider, add_tool, multiply_tool):
"""Create an agent with multiple tools."""
return Agent(provider=mock_provider, tools=[add_tool, multiply_tool])
# =============================================================================
# Utility Fixtures
# =============================================================================
@pytest.fixture
def sample_conversation():
"""Create a sample conversation history."""
return [
Message(role="user", content="Hello!"),
Message(role="assistant", content="Hi there! How can I help?"),
Message(role="user", content="What is 2+2?"),
Message(role="assistant", content="2+2=4"),
]
@pytest.fixture
def event_loop():
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()

@@ -0,0 +1,286 @@
"""Integration tests for Agent class.
This module tests the Agent class with mocked providers to verify
core functionality without making real API calls.
"""
from __future__ import annotations
import pytest
from agentlite import Agent, TextPart
@pytest.mark.integration
class TestAgentInitialization:
"""Tests for Agent initialization."""
def test_agent_initialization(self, mock_provider):
"""Test basic agent creation."""
agent = Agent(provider=mock_provider)
assert agent.provider is mock_provider
assert agent.system_prompt == "You are a helpful assistant."
assert agent.max_iterations == 80
assert agent.history == []
def test_agent_with_custom_system_prompt(self, mock_provider):
"""Test agent creation with custom system prompt."""
agent = Agent(provider=mock_provider, system_prompt="You are a specialized assistant.")
assert agent.system_prompt == "You are a specialized assistant."
def test_agent_with_tools(self, mock_provider, add_tool):
"""Test agent creation with tools."""
agent = Agent(provider=mock_provider, tools=[add_tool])
assert len(agent.tools.tools) == 1
assert agent.tools.tools[0].name == "add"
def test_agent_with_custom_max_iterations(self, mock_provider):
"""Test agent with custom max_iterations."""
agent = Agent(provider=mock_provider, max_iterations=5)
assert agent.max_iterations == 5
@pytest.mark.integration
class TestAgentRun:
"""Tests for Agent.run() method."""
@pytest.mark.asyncio
async def test_agent_run_simple(self, mock_provider):
"""Test simple non-streaming run."""
mock_provider.add_text_response("Hello there!")
agent = Agent(provider=mock_provider)
response = await agent.run("Hi")
assert response == "Hello there!"
@pytest.mark.asyncio
async def test_agent_run_adds_to_history(self, mock_provider):
"""Test that run adds messages to history."""
mock_provider.add_text_response("Response!")
agent = Agent(provider=mock_provider)
await agent.run("Hello")
# History should have user message and assistant response
assert len(agent.history) == 2
assert agent.history[0].role == "user"
assert agent.history[0].extract_text() == "Hello"
assert agent.history[1].role == "assistant"
@pytest.mark.asyncio
async def test_agent_run_multiple_messages(self, mock_provider):
"""Test multiple runs accumulate history."""
mock_provider.add_text_responses("Response 1", "Response 2")
agent = Agent(provider=mock_provider)
await agent.run("Message 1")
await agent.run("Message 2")
# Should have 4 messages total
assert len(agent.history) == 4
assert agent.history[0].role == "user"
assert agent.history[1].role == "assistant"
assert agent.history[2].role == "user"
assert agent.history[3].role == "assistant"
@pytest.mark.asyncio
async def test_agent_run_tracks_calls(self, mock_provider):
"""Test that provider.generate is called during run."""
mock_provider.add_text_response("Response!")
agent = Agent(provider=mock_provider)
await agent.run("Hello")
assert len(mock_provider.calls) == 1
call = mock_provider.calls[0]
assert call["system_prompt"] == "You are a helpful assistant."
assert len(call["history"]) == 1 # User message
@pytest.mark.integration
class TestAgentGenerate:
"""Tests for Agent.generate() method."""
@pytest.mark.asyncio
async def test_agent_generate_returns_message(self, mock_provider):
"""Test that generate returns a Message."""
mock_provider.add_text_response("Generated response")
agent = Agent(provider=mock_provider)
message = await agent.generate("Hello")
assert message.role == "assistant"
assert message.extract_text() == "Generated response"
@pytest.mark.asyncio
async def test_agent_generate_without_tool_loop(self, mock_provider):
"""Test that generate doesn't do tool calling loop."""
# Add tool call response
mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
agent = Agent(provider=mock_provider, tools=[])
message = await agent.generate("Calculate 1+2")
# Should return the tool call without executing it
assert message.has_tool_calls()
assert len(message.tool_calls) == 1
assert message.tool_calls[0].function.name == "add"
@pytest.mark.asyncio
async def test_agent_generate_adds_to_history(self, mock_provider):
"""Test that generate adds response to history."""
mock_provider.add_text_response("Response!")
agent = Agent(provider=mock_provider)
await agent.generate("Hello")
assert len(agent.history) == 2
assert agent.history[1].role == "assistant"
@pytest.mark.integration
class TestAgentHistory:
"""Tests for Agent history management."""
@pytest.mark.asyncio
async def test_agent_history_property_returns_copy(self, mock_provider):
"""Test that history property returns a copy."""
mock_provider.add_text_response("Response!")
agent = Agent(provider=mock_provider)
await agent.run("Hello")
history = agent.history
history.clear() # Modify the copy
# Original should still have messages
assert len(agent.history) == 2
@pytest.mark.asyncio
async def test_agent_clear_history(self, mock_provider):
"""Test clearing history."""
mock_provider.add_text_response("Response!")
agent = Agent(provider=mock_provider)
await agent.run("Hello")
agent.clear_history()
assert agent.history == []
@pytest.mark.asyncio
async def test_agent_add_message(self, mock_provider):
"""Test manually adding a message."""
agent = Agent(provider=mock_provider)
from agentlite import Message
agent.add_message(Message(role="user", content="Manual message"))
assert len(agent.history) == 1
assert agent.history[0].extract_text() == "Manual message"
@pytest.mark.integration
class TestAgentWithTools:
"""Tests for Agent with tools."""
@pytest.mark.asyncio
async def test_agent_with_tools_initialization(self, mock_provider, add_tool):
"""Test agent initialization with tools."""
agent = Agent(
provider=mock_provider, tools=[add_tool], system_prompt="You have access to tools."
)
assert len(agent.tools.tools) == 1
# Run to verify tools are passed to provider
mock_provider.add_text_response("I have tools available")
await agent.run("Hello")
# Check that tools were passed to provider
assert len(mock_provider.calls) == 1
assert len(mock_provider.calls[0]["tools"]) == 1
@pytest.mark.asyncio
async def test_agent_tool_call_execution(self, mock_provider, add_tool):
"""Test that agent executes tool calls."""
# First response: tool call
mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
# Second response: text after tool result
mock_provider.add_text_response("The sum is 3")
agent = Agent(provider=mock_provider, tools=[add_tool])
response = await agent.run("What is 1+2?")
assert "3" in response
# Should have made 2 calls to provider
assert len(mock_provider.calls) == 2
@pytest.mark.integration
class TestAgentMaxIterations:
"""Tests for max_iterations behavior."""
@pytest.mark.asyncio
async def test_agent_respects_max_iterations(self, mock_provider, add_tool):
"""Test that agent stops after max_iterations."""
# Always return tool calls to trigger iteration limit
for _ in range(10):
mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
agent = Agent(provider=mock_provider, tools=[add_tool], max_iterations=3)
response = await agent.run("Calculate")
# Should stop after max_iterations
assert len(mock_provider.calls) <= 3
assert "Maximum tool call iterations reached" in response
@pytest.mark.asyncio
async def test_agent_no_iterations_for_simple_response(self, mock_provider):
"""Test that simple responses don't count as iterations."""
mock_provider.add_text_response("Simple response")
agent = Agent(provider=mock_provider, max_iterations=1)
response = await agent.run("Hello")
assert response == "Simple response"
@pytest.mark.integration
class TestAgentStreaming:
"""Tests for streaming mode."""
@pytest.mark.asyncio
async def test_agent_run_streaming(self, mock_provider):
"""Test streaming run."""
mock_provider.add_text_response("Streamed response")
agent = Agent(provider=mock_provider)
stream = await agent.run("Hello", stream=True)
# Collect stream
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert "".join(chunks) == "Streamed response"
@pytest.mark.asyncio
async def test_agent_streaming_adds_to_history(self, mock_provider):
"""Test that streaming adds messages to history."""
mock_provider.add_text_response("Response")
agent = Agent(provider=mock_provider)
stream = await agent.run("Hello", stream=True)
async for _ in stream:
pass
assert len(agent.history) == 2
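# The tool tests above exercise Agent.run()'s generate/execute loop. A condensed
# sketch of that loop, assuming the Message/ToolCall types used in conftest.py
# (simplified; the actual implementation lives in agentlite.Agent):
#
#   for _ in range(agent.max_iterations):
#       stream = await provider.generate(system_prompt, tools, history)
#       parts, tool_calls = [], []
#       async for part in stream:
#           (tool_calls if isinstance(part, ToolCall) else parts).append(part)
#       history.append(Message(role="assistant", content=parts,
#                              tool_calls=tool_calls or None))
#       if not tool_calls:
#           break                                  # plain text answer: done
#       for call in tool_calls:                    # run each tool, then ask the model again
#           result = await agent.tools.handle(call)
#           history.append(Message(role="tool", content=result.output,
#                                  tool_call_id=call.id))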

@@ -0,0 +1,348 @@
"""Integration tests for AgentLite with real API.
This script runs comprehensive tests against the real OpenAI API.
Requires OPENAI_API_KEY environment variable to be set.
Usage:
export OPENAI_API_KEY="sk-..."
python tests/integration/test_with_api.py
"""
import asyncio
import os
import sys
from pathlib import Path
import pytest
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
from agentlite import Agent, OpenAIProvider, LLMClient, llm_complete
from agentlite.skills import discover_skills, SkillTool, index_skills_by_name
from agentlite.tools import ConfigurableToolset
# Test configuration
TEST_MODEL = "gpt-4o-mini" # Use mini for cost efficiency
HAS_OPENAI_API_KEY = bool(os.environ.get("OPENAI_API_KEY"))
pytestmark = pytest.mark.skipif(
not HAS_OPENAI_API_KEY, reason="OPENAI_API_KEY is required to run integration tests"
)
def get_provider():
"""Get OpenAI provider with API key."""
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
print("❌ OPENAI_API_KEY not set!")
print("Please set your OpenAI API key:")
print(" export OPENAI_API_KEY='sk-...'")
sys.exit(1)
return OpenAIProvider(api_key=api_key, model=TEST_MODEL)
async def test_basic_agent():
"""Test 1: Basic Agent functionality."""
print("\n" + "=" * 60)
print("Test 1: Basic Agent Functionality")
print("=" * 60)
try:
provider = get_provider()
agent = Agent(
provider=provider,
system_prompt="You are a helpful assistant. Be concise.",
)
response = await agent.run("What is 2+2?")
print(f"✅ Agent responded: {response[:100]}...")
assert "4" in response, "Expected '4' in response"
print("✅ Basic Agent test PASSED")
return True
except Exception as e:
print(f"❌ Basic Agent test FAILED: {e}")
return False
async def test_agent_with_tools():
"""Test 2: Agent with tool suite."""
print("\n" + "=" * 60)
print("Test 2: Agent with Tool Suite")
print("=" * 60)
try:
from agentlite.tools import ToolSuiteConfig, ReadFile, Glob
provider = get_provider()
# Create toolset with file tools
config = ToolSuiteConfig()
toolset = ConfigurableToolset(config, work_dir=Path.cwd())
agent = Agent(
provider=provider,
system_prompt="You are a helpful assistant with file access.",
tools=toolset.tools,
)
print(f"✅ Agent created with {len(agent.tools.tools)} tools")
# Test simple query (without requiring file access)
response = await agent.run("List the Python files in the current directory")
print(f"✅ Agent with tools responded: {response[:100]}...")
print("✅ Agent with Tools test PASSED")
return True
except Exception as e:
print(f"❌ Agent with Tools test FAILED: {e}")
import traceback
traceback.print_exc()
return False
async def test_llm_client():
"""Test 3: LLMClient functionality."""
print("\n" + "=" * 60)
print("Test 3: LLMClient Functionality")
print("=" * 60)
try:
provider = get_provider()
client = LLMClient(provider=provider)
response = await client.complete(
user_prompt="What is the capital of France?",
system_prompt="You are a helpful assistant. Be concise.",
)
print(f"✅ LLMClient responded: {response.content[:100]}...")
assert "Paris" in response.content, "Expected 'Paris' in response"
print("✅ LLMClient test PASSED")
return True
except Exception as e:
print(f"❌ LLMClient test FAILED: {e}")
import traceback
traceback.print_exc()
return False
async def test_llm_streaming():
"""Test 4: LLM streaming."""
print("\n" + "=" * 60)
print("Test 4: LLM Streaming")
print("=" * 60)
try:
provider = get_provider()
client = LLMClient(provider=provider)
chunks = []
async for chunk in client.stream(
user_prompt="Count from 1 to 3",
system_prompt="You are a helpful assistant.",
):
chunks.append(chunk)
print(f" Chunk: {chunk[:20]}...")
full_response = "".join(chunks)
print(f"✅ Streamed response: {full_response[:100]}...")
print("✅ LLM Streaming test PASSED")
return True
except Exception as e:
print(f"❌ LLM Streaming test FAILED: {e}")
import traceback
traceback.print_exc()
return False
async def test_subagents():
"""Test 5: Subagent functionality."""
print("\n" + "=" * 60)
print("Test 5: Subagent Functionality")
print("=" * 60)
try:
from agentlite.labor_market import LaborMarket
from agentlite.tools.multiagent.task import Task
provider = get_provider()
# Create parent agent
parent = Agent(
provider=provider,
system_prompt="You are a coordinator agent.",
name="coordinator",
)
# Create subagent
coder = Agent(
provider=provider,
system_prompt="You are a coding specialist. Write clean, simple code.",
name="coder",
)
# Add subagent to parent
parent.add_subagent("coder", coder, "Writes code")
# Add Task tool
parent.tools.add(Task(labor_market=parent.labor_market))
print(f"✅ Created parent with {len(parent.labor_market)} subagent(s)")
print(f" Subagents: {parent.labor_market.list_subagents()}")
print("✅ Subagent test PASSED")
return True
except Exception as e:
print(f"❌ Subagent test FAILED: {e}")
import traceback
traceback.print_exc()
return False
async def test_skills():
"""Test 6: Skills functionality."""
print("\n" + "=" * 60)
print("Test 6: Skills Functionality")
print("=" * 60)
try:
# Discover example skills
skills_dir = Path(__file__).parent.parent.parent / "examples" / "skills"
if not skills_dir.exists():
print("⚠️ Skills directory not found, skipping")
return True
skills = discover_skills(skills_dir)
print(f"✅ Discovered {len(skills)} skill(s)")
for skill in skills:
print(f" - {skill.name} ({skill.type})")
if len(skills) == 0:
print("⚠️ No skills found, skipping further tests")
return True
# Test with agent
provider = get_provider()
agent = Agent(
provider=provider,
system_prompt="You are a helpful assistant.",
)
skill_index = index_skills_by_name(skills)
skill_tool = SkillTool(skill_index, parent_agent=agent)
agent.tools.add(skill_tool)
print(f"✅ Added SkillTool to agent")
print("✅ Skills test PASSED")
return True
except Exception as e:
print(f"❌ Skills test FAILED: {e}")
import traceback
traceback.print_exc()
return False
async def test_conversation_history():
"""Test 7: Conversation history."""
print("\n" + "=" * 60)
print("Test 7: Conversation History")
print("=" * 60)
try:
provider = get_provider()
agent = Agent(
provider=provider,
system_prompt="You are a helpful assistant.",
)
# First message
response1 = await agent.run("My name is Alice")
print(f"✅ Response 1: {response1[:50]}...")
# Second message (should remember context)
response2 = await agent.run("What is my name?")
print(f"✅ Response 2: {response2[:50]}...")
assert "Alice" in response2, "Expected agent to remember name"
print("✅ Conversation History test PASSED")
return True
except Exception as e:
print(f"❌ Conversation History test FAILED: {e}")
import traceback
traceback.print_exc()
return False
async def run_all_tests():
"""Run all integration tests."""
print("\n" + "=" * 60)
print("AgentLite Integration Tests with Real API")
print("=" * 60)
print(f"Model: {TEST_MODEL}")
# Check API key
if not os.environ.get("OPENAI_API_KEY"):
print("\n❌ OPENAI_API_KEY not set!")
print("\nTo run these tests, set your OpenAI API key:")
print(" export OPENAI_API_KEY='sk-...'")
print("\nGet your API key from: https://platform.openai.com/api-keys")
return []
results = []
# Run all tests
results.append(("Basic Agent", await test_basic_agent()))
results.append(("Agent with Tools", await test_agent_with_tools()))
results.append(("LLMClient", await test_llm_client()))
results.append(("LLM Streaming", await test_llm_streaming()))
results.append(("Subagents", await test_subagents()))
results.append(("Skills", await test_skills()))
results.append(("Conversation History", await test_conversation_history()))
# Print summary
print("\n" + "=" * 60)
print("Test Summary")
print("=" * 60)
passed = sum(1 for _, result in results if result)
total = len(results)
for name, result in results:
status = "✅ PASSED" if result else "❌ FAILED"
print(f"{status}: {name}")
print(f"\n{passed}/{total} tests passed")
if passed == total:
print("\n🎉 All tests passed!")
else:
print(f"\n⚠️ {total - passed} test(s) failed")
return results
if __name__ == "__main__":
results = asyncio.run(run_all_tests())
# Exit with error code if any tests failed
if results and not all(r for _, r in results):
sys.exit(1)

@@ -0,0 +1,141 @@
"""Debug script to find CLI test hang cause."""
from __future__ import annotations
import os
import sys
import asyncio
import signal
sys.path.insert(0, "/home/tcmofashi/proj/l2d_backend/agentlite/src")
from agentlite import Agent, OpenAIProvider
from agentlite.tools.shell.shell import Shell, Params
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
async def test_shell_directly():
"""Test shell tool without agent."""
print("\n=== Test 1: Shell tool directly ===")
shell = Shell(timeout=10)
# Use Params dataclass
result = await shell(Params(command="echo 'Hello'", timeout=5))
print(f"Result: {result}")
print(f"Output: {result.output if hasattr(result, 'output') else result}")
return True
async def test_agent_no_tools():
"""Test agent without tools."""
print("\n=== Test 2: Agent without tools ===")
api_key = os.environ.get("SILICONFLOW_API_KEY")
if not api_key:
print("SILICONFLOW_API_KEY not set")
return False
provider = OpenAIProvider(
api_key=api_key,
base_url=SILICONFLOW_BASE_URL,
model=SILICONFLOW_MODEL,
timeout=30.0,
)
agent = Agent(
provider=provider,
system_prompt="Reply briefly in one word.",
max_iterations=3,
)
print("Sending message to LLM...")
try:
response = await asyncio.wait_for(
agent.run("Say hello."),
timeout=60.0,
)
print(f"Response: {response[:100]}...")
return True
except asyncio.TimeoutError:
print("TIMEOUT in agent without tools!")
return False
async def test_agent_with_shell():
"""Test agent with shell tool - the problematic case."""
print("\n=== Test 3: Agent WITH shell tool ===")
api_key = os.environ.get("SILICONFLOW_API_KEY")
if not api_key:
print("SILICONFLOW_API_KEY not set")
return False
provider = OpenAIProvider(
api_key=api_key,
base_url=SILICONFLOW_BASE_URL,
model=SILICONFLOW_MODEL,
timeout=60.0,
)
agent = Agent(
provider=provider,
system_prompt="You are a shell assistant. Execute commands when asked. Keep responses brief.",
tools=[Shell(timeout=10)],
max_iterations=5, # Limit iterations
)
print("Sending message with tool request...")
print("This is where it might hang...")
try:
response = await asyncio.wait_for(
agent.run("Run 'echo test' and tell me the result."),
timeout=120.0,
)
print(f"Response: {response}")
return True
except asyncio.TimeoutError:
print("TIMEOUT! Agent hung for 120 seconds")
# Check history to see what happened
print(f"\nHistory length: {len(agent.history)}")
for i, msg in enumerate(agent.history[-5:]):
content_preview = str(msg.content)[:100] if msg.content else "None"
print(f" [{i}] {msg.role}: {content_preview}...")
return False
async def main():
"""Run all tests."""
print("=" * 60)
print("CLI Debug Test - Finding the hang cause")
print("=" * 60)
results = []
# Test 1: Shell directly
r1 = await test_shell_directly()
results.append(("Shell directly", r1))
print(f"Result: {'PASS' if r1 else 'FAIL'}")
# Test 2: Agent without tools
r2 = await test_agent_no_tools()
results.append(("Agent no tools", r2))
print(f"Result: {'PASS' if r2 else 'FAIL'}")
# Test 3: Agent with shell (the problem)
r3 = await test_agent_with_shell()
results.append(("Agent with shell", r3))
print(f"Result: {'PASS' if r3 else 'FAIL'}")
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for name, passed in results:
status = "✅ PASS" if passed else "❌ FAIL"
print(f" {name}: {status}")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(main())

@@ -0,0 +1,221 @@
"""Debug script with detailed logging to find CLI test hang cause."""
from __future__ import annotations
import os
import sys
import asyncio
import logging
import time
sys.path.insert(0, "/home/tcmofashi/proj/l2d_backend/agentlite/src")
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger("debug")
# SiliconFlow DeepSeek-V3 (known good function calling support)
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Pro/deepseek-ai/DeepSeek-V3.2"
SILICONFLOW_API_KEY = "sk-eaxfgkkcuatochftxpevkyvltghigsrclzjzalybmaqycual"
async def main():
from agentlite import Agent, OpenAIProvider
from agentlite.tools.shell.shell import Shell
from agentlite.message import Message
logger.info("=" * 60)
logger.info("CLI Debug Test with DeepSeek-V3 (SiliconFlow)")
logger.info("=" * 60)
api_key = os.environ.get("SILICONFLOW_API_KEY") or SILICONFLOW_API_KEY
if not api_key:
logger.error("SILICONFLOW_API_KEY not set")
return
logger.info(f"Using model: {SILICONFLOW_MODEL}")
provider = OpenAIProvider(
api_key=api_key,
base_url=SILICONFLOW_BASE_URL,
model=SILICONFLOW_MODEL,
timeout=30.0,
)
agent = Agent(
provider=provider,
system_prompt="You are a shell assistant. Execute commands when asked. Reply briefly.",
tools=[Shell(timeout=10)],
max_iterations=5,
)
start_time = time.time()
message = "Run 'echo test' and tell me the result."
logger.info(f"\n=== Starting Agent Run ===")
logger.info(f"Message: {message}")
logger.info(f"Max iterations: {agent.max_iterations}")
logger.info(f"Tools: {[t.name for t in agent.tools.tools]}")
agent._history.append(Message(role="user", content=message))
iterations = 0
final_response = None
while iterations < agent.max_iterations:
iterations += 1
elapsed = time.time() - start_time
logger.info(f"\n{'=' * 50}")
logger.info(f"ITERATION {iterations}/{agent.max_iterations} (elapsed: {elapsed:.1f}s)")
logger.info(f"{'=' * 50}")
# Step 1: Call Provider
logger.info(">>> Step 1: Calling provider.generate()...")
step_start = time.time()
try:
stream = await asyncio.wait_for(
provider.generate(
system_prompt=agent.system_prompt,
tools=agent.tools.tools,
history=agent._history,
),
timeout=60.0,
)
logger.info(f"<<< Provider returned stream in {time.time() - step_start:.2f}s")
except asyncio.TimeoutError:
logger.error("!!! Provider call TIMEOUT after 60s")
final_response = "ERROR: Provider timeout"
break
# Step 2: Collect stream parts
logger.info(">>> Step 2: Collecting stream parts...")
step_start = time.time()
from agentlite.message import TextPart, ToolCall, ContentPart
response_parts = []
tool_calls = []
chunk_count = 0
try:
async for part in stream:
chunk_count += 1
if chunk_count % 10 == 0:
logger.debug(f" Received chunk #{chunk_count}")
if isinstance(part, ToolCall):
tool_calls.append(part)
logger.info(
f" ToolCall received: {part.function.name if hasattr(part, 'function') else part}"
)
elif isinstance(part, ContentPart):
response_parts.append(part)
if isinstance(part, TextPart):
logger.debug(f" Text: {part.text[:50]}...")
logger.info(
f"<<< Stream finished in {time.time() - step_start:.2f}s, {chunk_count} chunks"
)
except asyncio.TimeoutError:
logger.error("!!! Stream reading TIMEOUT")
final_response = "ERROR: Stream timeout"
break
except Exception as e:
logger.error(f"!!! Stream error: {type(e).__name__}: {e}")
final_response = f"ERROR: Stream error - {e}"
break
# Extract text
response_text = ""
for part in response_parts:
if isinstance(part, TextPart):
response_text += part.text
logger.info(f"Response text ({len(response_text)} chars): {response_text[:100]}...")
logger.info(f"Tool calls: {len(tool_calls)}")
# Add to history
agent._history.append(
Message(
role="assistant",
content=response_parts,
tool_calls=tool_calls if tool_calls else None,
)
)
# Step 3: Check if done
if not tool_calls:
elapsed = time.time() - start_time
logger.info(f"\n=== Agent completed in {elapsed:.2f}s, {iterations} iterations ===")
final_response = response_text
break
# Step 4: Execute tool calls
logger.info(f"\n>>> Step 3: Executing {len(tool_calls)} tool calls...")
step_start = time.time()
for i, tc in enumerate(tool_calls):
func_name = tc.function.name if hasattr(tc, "function") else str(tc)
func_args = tc.function.arguments if hasattr(tc, "function") else ""
logger.info(f" Tool #{i + 1}: {func_name}")
logger.info(f" Args: {func_args[:200]}...")
try:
result = await asyncio.wait_for(
agent.tools.handle(tc),
timeout=30.0,
)
output = result.output if hasattr(result, "output") else str(result)
is_error = result.is_error if hasattr(result, "is_error") else False
logger.info(
f" Result: is_error={is_error}, output_len={len(output) if output else 0}"
)
output_preview = output[:100] if output else "None"
logger.info(f" Output preview: {output_preview}...")
except asyncio.TimeoutError:
logger.error(f" !!! Tool execution TIMEOUT")
output = "Tool execution timed out"
is_error = True
except Exception as e:
logger.error(f" !!! Tool error: {type(e).__name__}: {e}")
output = str(e)
is_error = True
# Add tool result to history
agent._history.append(
Message(
role="tool",
content=output,
tool_call_id=tc.id if hasattr(tc, "id") else f"tc_{i}",
)
)
logger.info(f"<<< Tool execution finished in {time.time() - step_start:.2f}s")
# Check overall timeout
elapsed = time.time() - start_time
if elapsed > 90:
logger.warning(f"!!! Overall timeout approaching ({elapsed:.1f}s)")
final_response = f"Timeout after {iterations} iterations"
break
if iterations >= agent.max_iterations:
logger.warning(f"!!! Max iterations reached ({agent.max_iterations})")
final_response = f"Max iterations ({agent.max_iterations}) reached"
logger.info(f"\n{'=' * 60}")
logger.info(f"FINAL RESULT:")
logger.info(f"{'=' * 60}")
logger.info(f"{final_response}")
logger.info(f"Total iterations: {iterations}")
logger.info(f"Total time: {time.time() - start_time:.2f}s")
logger.info(f"History length: {len(agent._history)}")
if __name__ == "__main__":
asyncio.run(main())

@@ -0,0 +1,349 @@
"""End-to-end test for complex CLI operations with real API.
This test simulates a realistic complex CLI task where an agent:
1. Explores project structure using shell commands
2. Searches for specific patterns using grep/glob
3. Reads relevant files
4. Creates analysis reports
Uses real SiliconFlow qwen3.5-397B API (requires SILICONFLOW_API_KEY env var).
"""
from __future__ import annotations
import asyncio
import os
import tempfile
from pathlib import Path
import pytest
from agentlite import Agent, OpenAIProvider
from agentlite.tools import (
ConfigurableToolset,
ToolSuiteConfig,
Shell,
ReadFile,
WriteFile,
Glob,
Grep,
)
# =============================================================================
# Configuration from model_config.toml
# =============================================================================
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
def get_siliconflow_provider() -> OpenAIProvider | None:
"""Create OpenAIProvider for SiliconFlow API."""
api_key = os.environ.get("SILICONFLOW_API_KEY")
if not api_key:
return None
return OpenAIProvider(
api_key=api_key,
base_url=SILICONFLOW_BASE_URL,
model=SILICONFLOW_MODEL,
)
@pytest.fixture
def real_provider():
"""Create real SiliconFlow provider."""
provider = get_siliconflow_provider()
if provider is None:
pytest.skip("SILICONFLOW_API_KEY not set")
return provider
@pytest.fixture
def test_project():
"""Create a mock project structure for testing."""
with tempfile.TemporaryDirectory() as tmpdir:
project_dir = Path(tmpdir) / "test_project"
project_dir.mkdir()
# Create project structure
(project_dir / "src").mkdir()
(project_dir / "src" / "utils").mkdir()
(project_dir / "tests").mkdir()
(project_dir / "docs").mkdir()
# Create source files
(project_dir / "src" / "main.py").write_text('''"""Main module."""
from src.utils.helper import process_data
from src.utils.logger import setup_logger
def main():
"""Main entry point."""
logger = setup_logger()
data = [1, 2, 3, 4, 5]
result = process_data(data)
logger.info(f"Result: {result}")
return result
if __name__ == "__main__":
main()
''')
(project_dir / "src" / "__init__.py").write_text('"""Source package."""')
(project_dir / "src" / "utils" / "helper.py").write_text('''"""Helper utilities."""
def process_data(data: list) -> list:
"""Process input data."""
return [x * 2 for x in data]
def validate_data(data: list) -> bool:
"""Validate data format."""
return all(isinstance(x, (int, float)) for x in data)
''')
(project_dir / "src" / "utils" / "logger.py").write_text('''"""Logging utilities."""
import logging
def setup_logger(name: str = "app") -> logging.Logger:
"""Setup application logger."""
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
return logger
''')
(project_dir / "src" / "utils" / "__init__.py").write_text('"""Utils package."""')
# Create test files
(project_dir / "tests" / "test_helper.py").write_text('''"""Tests for helper module."""
from src.utils.helper import process_data, validate_data
def test_process_data():
assert process_data([1, 2, 3]) == [2, 4, 6]
def test_validate_data():
assert validate_data([1, 2, 3]) == True
assert validate_data(["a", "b"]) == False
''')
# Create documentation
(project_dir / "docs" / "README.md").write_text("""# Test Project
A sample project for testing CLI operations.
## Structure
- `src/` - Source code
- `tests/` - Unit tests
- `docs/` - Documentation
""")
(project_dir / "README.md").write_text("""# Test Project
Simple data processing project.
## Usage
```bash
python -m src.main
```
""")
yield project_dir
@pytest.mark.scenario
@pytest.mark.slow
class TestComplexCLITasks:
"""End-to-end tests with complex CLI operations."""
@pytest.mark.asyncio
async def test_explore_project_structure(self, real_provider, test_project):
"""Test exploring project structure using CLI tools.
Task: Use shell commands to explore the project structure,
then summarize what files exist.
"""
# Create toolset with Shell tool
toolset = ConfigurableToolset(
config=ToolSuiteConfig(
shell_tools=ToolSuiteConfig().shell_tools,
),
work_dir=str(test_project),
)
agent = Agent(
provider=real_provider,
tools=toolset.tools,
system_prompt=(
"你是一个项目分析助手。使用 Shell 工具执行命令来探索项目结构。"
"请使用 find、ls、tree 等命令来了解项目。"
),
max_iterations=5, # Limit iterations to prevent hanging
)
# Add overall timeout to prevent infinite hanging
try:
response = await asyncio.wait_for(
agent.run(
f"探索项目目录 {test_project} 的结构,列出所有文件和目录,并总结项目的组织方式。"
),
timeout=120.0, # 2 minute overall timeout
)
except asyncio.TimeoutError:
pytest.fail("Agent timed out after 120 seconds - possible infinite loop")
assert response, "Agent should return a response"
print(f"\n[项目结构探索结果]:\n{response}\n")
# Verify response mentions key files
response_lower = response.lower()
assert any(
word in response_lower for word in ["src", "tests", "main.py", "helper", "logger"]
), "Response should mention project files"
@pytest.mark.asyncio
async def test_search_and_analyze_code(self, real_provider, test_project):
"""Test searching for patterns and analyzing code.
Task: Use grep/glob to find specific patterns,
read the files, and create an analysis report.
"""
# Create toolset with all file tools
toolset = ConfigurableToolset(
config=ToolSuiteConfig(
file_tools=ToolSuiteConfig().file_tools,
shell_tools=ToolSuiteConfig().shell_tools,
),
work_dir=str(test_project),
)
agent = Agent(
provider=real_provider,
tools=toolset.tools,
system_prompt=(
"你是一个代码分析助手。使用 Glob、Grep、ReadFile 等工具来搜索和分析代码。"
"请使用 Shell 工具执行 grep、find 等命令。"
),
)
response = await agent.run(
f"在项目 {test_project} 中搜索所有包含 'def ' 的 Python 文件,"
f"列出找到的函数定义,并创建一个函数清单文件保存到 {test_project}/functions.txt。"
)
assert response, "Agent should return a response"
print(f"\n[代码搜索分析结果]:\n{response}\n")
# Check if analysis file was created
functions_file = test_project / "functions.txt"
if functions_file.exists():
content = functions_file.read_text()
print(f"\n[函数清单文件]:\n{content}\n")
assert len(content) > 0, "Functions file should not be empty"
@pytest.mark.asyncio
async def test_complex_multi_step_task(self, real_provider, test_project):
"""Test a complex multi-step CLI task.
Task:
1. Find all Python files using shell
2. Search for TODO comments using grep
3. Read files with TODOs
4. Create a summary report
"""
# Add some TODO comments
todo_file = test_project / "src" / "utils" / "todo_items.py"
todo_file.write_text('''"""Module with TODO items."""
# TODO: Implement error handling
def risky_operation(data):
"""Perform a risky operation."""
return data / 0 # This will fail
# TODO: Add caching mechanism
def expensive_computation(n):
"""Perform expensive computation."""
return sum(range(n))
# FIXME: Memory leak in this function
def process_large_file(path):
"""Process a large file."""
with open(path) as f:
return f.read()
''')
# Create comprehensive toolset
toolset = ConfigurableToolset(
config=ToolSuiteConfig(
file_tools=ToolSuiteConfig().file_tools,
shell_tools=ToolSuiteConfig().shell_tools,
),
work_dir=str(test_project),
)
agent = Agent(
provider=real_provider,
tools=toolset.tools,
system_prompt=(
"你是一个项目维护助手。"
"使用 Shell 工具执行命令(如 find、grep、ls 等)。"
"使用 ReadFile 读取文件内容。"
"使用 WriteFile 创建新文件。"
"请一步一步完成任务。"
),
)
response = await agent.run(
f"请完成以下任务:\n"
f"1. 使用 'find' 命令找出项目 {test_project} 中所有的 .py 文件\n"
f"2. 使用 'grep' 命令搜索所有包含 'TODO''FIXME' 的行\n"
f"3. 读取包含 TODO 的文件内容\n"
f"4. 创建一个 TODO 报告文件,保存到 {test_project}/todo_report.txt"
)
assert response, "Agent should return a response"
print(f"\n[复杂任务结果]:\n{response}\n")
# Verify report was created
report_file = test_project / "todo_report.txt"
if report_file.exists():
content = report_file.read_text()
print(f"\n[TODO 报告]:\n{content}\n")
@pytest.mark.asyncio
async def test_shell_pipes_and_chains(self, real_provider, test_project):
"""Test complex shell commands with pipes and chains.
Task: Use shell pipes to perform complex data processing.
"""
toolset = ConfigurableToolset(
config=ToolSuiteConfig(
shell_tools=ToolSuiteConfig().shell_tools,
),
work_dir=str(test_project),
)
agent = Agent(
provider=real_provider,
tools=toolset.tools,
system_prompt=(
"你是一个 Shell 命令专家。"
"使用复杂的 Shell 命令(管道、重定向、条件执行等)来完成任务。"
),
)
response = await agent.run(
f"在项目目录 {test_project} 中执行以下操作:\n"
f"1. 使用 'find . -name \"*.py\" | xargs wc -l' 统计所有 Python 文件的总行数\n"
f'2. 使用 \'grep -r "def " --include="*.py" | wc -l\' 统计函数定义数量\n'
f"3. 使用 'ls -la' 查看目录详情\n"
f"报告你的发现。"
)
assert response, "Agent should return a response"
print(f"\n[Shell 管道命令结果]:\n{response}\n")
# Verify response contains relevant information
response_lower = response.lower()
assert any(
word in response_lower for word in ["行", "line", "函数", "function", "文件", "file"]
), "Response should mention analysis results"

@@ -0,0 +1,374 @@
"""End-to-end scenario test for file operations.
This test simulates a realistic scenario where an agent:
1. Reads a file
2. Explains its content
3. Creates a new file with analysis results
This is a meaningful e2e test that demonstrates the agent's ability to
orchestrate multiple tool calls in sequence.
"""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
import pytest
from agentlite import Agent, TextPart, tool
# =============================================================================
# File Operation Tools
# =============================================================================
@tool()
async def read_file(file_path: str) -> str:
"""Read the content of a file.
Args:
file_path: Path to the file to read.
Returns:
The content of the file as a string.
Raises:
FileNotFoundError: If the file does not exist.
"""
with open(file_path) as f:
return f.read()
@tool()
async def write_file(file_path: str, content: str) -> str:
"""Write content to a file, creating it if it doesn't exist.
Args:
file_path: Path to the file to write.
content: Content to write to the file.
Returns:
Success message confirming the file was written.
"""
# Create parent directories if they don't exist
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(content)
return f"File successfully written to {file_path}"
@tool()
async def list_files(directory: str) -> str:
"""List all files in a directory.
Args:
directory: Path to the directory to list.
Returns:
A newline-separated list of file names in the directory.
"""
files = os.listdir(directory)
return "\n".join(files)
# =============================================================================
# E2E Test
# =============================================================================
@pytest.mark.scenario
class TestFileOperationsScenario:
"""End-to-end test for file read/write operations."""
@pytest.mark.asyncio
async def test_read_explain_and_write(self, mock_provider):
"""Test a complete workflow: read file -> explain -> write results."""
# Setup: Create a temporary file with content
with tempfile.TemporaryDirectory() as tmpdir:
# Create a source file to read
source_file = os.path.join(tmpdir, "source.txt")
source_content = """Project Overview
================
This is a sample project document for testing.
Features:
- Feature A: Does something useful
- Feature B: Does something else
- Feature C: The most important feature
Conclusion: This project demonstrates file operations.
"""
with open(source_file, "w") as f:
f.write(source_content)
# Configure mock provider responses
# The agent should:
# 1. Read the file
# 2. Summarize it
# 3. Write the summary to a new file
mock_provider.add_text_response(
f"I'll read the file at {source_file} and analyze it for you."
)
# Create agent with file tools
tools = [read_file, write_file, list_files]
agent = Agent(
provider=mock_provider,
tools=tools,
system_prompt="You are a helpful file analysis assistant.",
)
# Step 1: Agent reads and analyzes the file
mock_provider.clear_responses()
mock_provider.add_tool_call(
"read_file",
{"file_path": source_file},
source_content,
)
# Agent analyzes the content
mock_provider.add_text_response(
"I've read the file. It's a project overview document with 3 features. "
"Let me create a summary file."
)
# Step 2: Agent writes summary to a new file
summary_file = os.path.join(tmpdir, "summary.txt")
expected_summary = """Project Summary
================
This is a sample project with 3 main features:
- Feature A, - Feature B, - Feature C
The most important feature is Feature C.
"""
mock_provider.clear_responses()
mock_provider.add_tool_call(
"write_file",
{
"file_path": summary_file,
"content": expected_summary,
},
f"File successfully written to {summary_file}",
)
mock_provider.add_text_response(f"I've created a summary at {summary_file}")
# Execute the agent
response = await agent.run(
f"Please read {source_file}, analyze it, and create a summary file at {summary_file}"
)
# Verify the interaction
assert "summary" in response.lower()
# Verify the provider was called correctly
assert len(mock_provider.calls) >= 1
@pytest.mark.asyncio
async def test_list_files_scenario(self, mock_provider):
"""Test listing files in a directory."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create some test files
for i in range(3):
with open(os.path.join(tmpdir, f"file{i}.txt"), "w") as f:
f.write(f"Content {i}")
# Configure agent to list files
mock_provider.add_tool_call(
"list_files",
{"directory": tmpdir},
"file0.txt\nfile1.txt\nfile2.txt",
)
mock_provider.add_text_response(
f"I found 3 files in {tmpdir}: file0.txt, file1.txt, file2.txt"
)
agent = Agent(
provider=mock_provider,
tools=[list_files],
system_prompt="You are a file system assistant.",
)
response = await agent.run(f"List all files in {tmpdir}")
assert "3 files" in response
@pytest.mark.asyncio
async def test_multi_step_file_workflow(self, mock_provider):
"""Test a complex multi-step file workflow.
Scenario:
1. List files in directory
2. Read each file
3. Create a combined report
"""
with tempfile.TemporaryDirectory() as tmpdir:
# Create test files
files_content = {
"report1.txt": "Sales increased by 20%",
"report2.txt": "Customer satisfaction at 85%",
"report3.txt": "Bug fixes: 15 resolved",
}
for name, content in files_content.items():
with open(os.path.join(tmpdir, name), "w") as f:
f.write(content)
# Configure agent responses for multi-step workflow
tools = [read_file, write_file, list_files]
# Step 1: List files
mock_provider.add_tool_call(
"list_files",
{"directory": tmpdir},
"report1.txt\nreport2.txt\nreport3.txt",
)
# Step 2: Read all files
mock_provider.add_tool_call(
"read_file",
{"file_path": os.path.join(tmpdir, "report1.txt")},
"Sales increased by 20%",
)
mock_provider.add_tool_call(
"read_file",
{"file_path": os.path.join(tmpdir, "report2.txt")},
"Customer satisfaction at 85%",
)
mock_provider.add_tool_call(
"read_file",
{"file_path": os.path.join(tmpdir, "report3.txt")},
"Bug fixes: 15 resolved",
)
# Step 3: Write combined report
combined_report = """Combined Report
================
1. Sales: Increased by 20%
2. Customer Satisfaction: 85%
3. Development: 15 bugs resolved
"""
mock_provider.add_tool_call(
"write_file",
{
"file_path": os.path.join(tmpdir, "combined_report.txt"),
"content": combined_report,
},
f"File successfully written to {os.path.join(tmpdir, 'combined_report.txt')}",
)
mock_provider.add_text_response(
"I've created a combined report summarizing all three reports."
)
agent = Agent(
provider=mock_provider,
tools=tools,
system_prompt="You are a report analyst assistant.",
)
response = await agent.run(
f"List all files in {tmpdir}, read them all, and create a combined report at combined_report.txt"
)
assert "combined report" in response.lower()
# =============================================================================
# Additional Tools for Extended Scenarios
# =============================================================================
@tool()
async def count_words(file_path: str) -> str:
"""Count the number of words in a file.
Args:
file_path: Path to the file to analyze.
Returns:
The word count as a string.
"""
with open(file_path) as f:
content = f.read()
word_count = len(content.split())
return f"Word count: {word_count}"
@tool()
async def append_to_file(file_path: str, content: str) -> str:
"""Append content to an existing file.
Args:
file_path: Path to the file to append to.
content: Content to append.
Returns:
Success message.
"""
with open(file_path, "a") as f:
f.write("\n" + content)
return f"Content appended to {file_path}"
@pytest.mark.scenario
class TestExtendedFileOperations:
"""Extended scenarios with more file operations."""
@pytest.mark.asyncio
async def test_read_count_and_append(self, mock_provider):
"""Test reading a file, counting words, and appending a note."""
with tempfile.TemporaryDirectory() as tmpdir:
source_file = os.path.join(tmpdir, "document.txt")
with open(source_file, "w") as f:
f.write("This is a test document with several words in it.")
tools = [read_file, write_file, count_words, append_to_file]
# Step 1: Read file
mock_provider.add_tool_call(
"read_file",
{"file_path": source_file},
"This is a test document with several words in it.",
)
# Step 2: Count words
mock_provider.add_tool_call(
"count_words",
{"file_path": source_file},
"Word count: 10",
)
# Step 3: Append analysis
mock_provider.add_tool_call(
"append_to_file",
{
"file_path": source_file,
"content": "\n\n[Analysis] This document contains 10 words.",
},
f"Content appended to {source_file}",
)
mock_provider.add_text_response(
"I've analyzed the document and appended the word count analysis."
)
agent = Agent(
provider=mock_provider,
tools=tools,
system_prompt="You are a document analysis assistant.",
)
response = await agent.run(
f"Read {source_file}, count its words, and append the word count as an analysis note"
)
assert "analyzed" in response.lower()

@@ -0,0 +1,226 @@
"""End-to-end scenario test for file operations with real API.
This test simulates a realistic scenario where an agent:
1. Reads a file
2. Explains its content
3. Creates a new file with analysis results
Uses real SiliconFlow qwen3.5-397B API (requires SILICONFLOW_API_KEY env var).
"""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
import pytest
from agentlite import Agent, OpenAIProvider, tool
# =============================================================================
# Configuration from model_config.toml
# =============================================================================
# SiliconFlow API configuration (matches qwen35_397b in model_config.toml)
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
def get_siliconflow_provider() -> OpenAIProvider | None:
"""Create OpenAIProvider for SiliconFlow API.
Returns None if SILICONFLOW_API_KEY is not set.
"""
api_key = os.environ.get("SILICONFLOW_API_KEY")
if not api_key:
return None
return OpenAIProvider(
api_key=api_key,
base_url=SILICONFLOW_BASE_URL,
model=SILICONFLOW_MODEL,
)
# =============================================================================
# File Operation Tools
# =============================================================================
@tool()
async def read_file(file_path: str) -> str:
"""Read the content of a file.
Args:
file_path: Path to the file to read.
Returns:
The content of the file as a string.
Raises:
FileNotFoundError: If the file does not exist.
"""
with open(file_path) as f:
return f.read()
@tool()
async def write_file(file_path: str, content: str) -> str:
"""Write content to a file, creating it if it doesn't exist.
Args:
file_path: Path to the file to write.
content: Content to write to the file.
Returns:
Success message confirming the file was written.
"""
# Create parent directories if they don't exist
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(content)
return f"File successfully written to {file_path}"
@tool()
async def list_files(directory: str) -> str:
"""List all files in a directory.
Args:
directory: Path to the directory to list.
Returns:
A newline-separated list of file names in the directory.
"""
files = os.listdir(directory)
return "\n".join(files)
# =============================================================================
# Real API E2E Tests
# =============================================================================
@pytest.fixture
def real_provider():
"""Create a real SiliconFlow provider.
Skip tests if SILICONFLOW_API_KEY is not set.
"""
provider = get_siliconflow_provider()
if provider is None:
pytest.skip("SILICONFLOW_API_KEY not set, skipping real API tests")
return provider
@pytest.mark.scenario
@pytest.mark.expensive
class TestFileOperationsWithRealAPI:
"""End-to-end tests with real SiliconFlow API."""
@pytest.mark.asyncio
async def test_read_and_summarize(self, real_provider):
"""Test reading a file and creating a summary with real API."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create a source file with meaningful content
source_file = os.path.join(tmpdir, "source.txt")
source_content = """AgentLite 项目概述
================
AgentLite 是一个轻量级的 Agent 组件库,主要特点:
- 异步优先设计
- OpenAI 兼容 API
- 工具系统 (支持 MCP)
- 流式响应支持
使用示例:
```python
from agentlite import Agent, OpenAIProvider
provider = OpenAIProvider(api_key="...", model="gpt-4")
agent = Agent(provider=provider)
response = await agent.run("Hello!")
```
"""
with open(source_file, "w") as f:
f.write(source_content)
# Create agent with file tools
tools = [read_file, write_file, list_files]
agent = Agent(
provider=real_provider,
tools=tools,
system_prompt="你是一个文件分析助手。请使用工具来完成任务。",
)
# Run the agent to read, analyze, and write summary
output_file = os.path.join(tmpdir, "summary.txt")
response = await agent.run(
f"请读取 {source_file} 文件,分析其内容,并创建一个摘要文件保存到 {output_file}"
)
# Verify the agent responded
assert response, "Agent should return a response"
print(f"\n[Agent 响应]:\n{response}\n")
# Verify the output file was created
if os.path.exists(output_file):
with open(output_file) as f:
output_content = f.read()
print(f"\n[输出文件内容]:\n{output_content}\n")
assert len(output_content) > 0, "Output file should not be empty"
@pytest.mark.asyncio
async def test_list_files_and_combine(self, real_provider):
"""Test listing files, reading them, and creating combined report."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create multiple files
files = {
"sales.txt": "销售额增长了 20%",
"users.txt": "用户满意度达到 85%",
"bugs.txt": "修复了 15 个问题",
}
for name, content in files.items():
with open(os.path.join(tmpdir, name), "w") as f:
f.write(content)
# Create agent with file tools
tools = [read_file, write_file, list_files]
agent = Agent(
provider=real_provider,
tools=tools,
system_prompt="你是一个数据分析助手。请使用工具来完成任务。",
)
# Run the agent
report_file = os.path.join(tmpdir, "report.txt")
response = await agent.run(
f"列出 {tmpdir} 目录中的所有文件,读取每个文件的内容,然后创建一份综合报告保存到 {report_file}"
)
# Verify the agent responded
assert response, "Agent should return a response"
print(f"\n[Agent 响应]:\n{response}\n")
# The agent should have created the report file
if os.path.exists(report_file):
with open(report_file) as f:
report_content = f.read()
print(f"\n[报告文件内容]:\n{report_content}\n")
@pytest.mark.asyncio
async def test_simple_conversation(self, real_provider):
"""Test basic conversation without tools."""
agent = Agent(
provider=real_provider,
system_prompt="你是一个有帮助的助手。请用中文回答。",
)
response = await agent.run("你好!请简单介绍一下你自己。")
assert response, "Agent should return a response"
print(f"\n[Agent 自我介绍]:\n{response}\n")
assert len(response) > 10, "Response should be meaningful"

View File

@@ -0,0 +1,521 @@
"""工具专项测试 - 文档提取和知识图谱工具
本模块测试基于数据基底的工具功能,包括:
1. 文档读取和解析工具
2. 实体提取工具
3. 知识图谱查询工具
4. 推理工具
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import pytest
import yaml
from agentlite import Agent, Message, TextPart, tool
def tool_output(result: Any) -> Any:
"""兼容旧式返回值和 ToolResult 返回值."""
return getattr(result, "output", result)
# =============================================================================
# Data-loading fixtures
# =============================================================================
@pytest.fixture
def data_dir() -> Path:
"""返回测试数据目录路径."""
return Path(__file__).parent.parent / "data"
@pytest.fixture
def sample_article(data_dir: Path) -> str:
"""加载样例文章."""
return (data_dir / "documents" / "sample_article.md").read_text(encoding="utf-8")
@pytest.fixture
def technical_spec(data_dir: Path) -> str:
"""加载技术规范文档."""
return (data_dir / "documents" / "technical_spec.md").read_text(encoding="utf-8")
@pytest.fixture
def meeting_notes(data_dir: Path) -> str:
"""加载会议记录."""
return (data_dir / "documents" / "meeting_notes.txt").read_text(encoding="utf-8")
@pytest.fixture
def knowledge_graph_entities(data_dir: Path) -> dict[str, Any]:
"""加载知识图谱实体数据."""
with open(data_dir / "knowledge_base" / "entities.json") as f:
return json.load(f)
@pytest.fixture
def knowledge_graph_relations(data_dir: Path) -> dict[str, Any]:
"""加载知识图谱关系数据."""
with open(data_dir / "knowledge_base" / "relations.json") as f:
return json.load(f)
@pytest.fixture
def graph_queries(data_dir: Path) -> list[dict[str, Any]]:
"""加载图谱查询测试用例."""
with open(data_dir / "knowledge_base" / "graph_queries.yaml") as f:
data = yaml.safe_load(f)
return data.get("queries", [])
# =============================================================================
# Knowledge Graph Tool Implementation
# =============================================================================
class KnowledgeGraph:
"""知识图谱内存存储."""
def __init__(self, entities: list[dict], relations: list[dict]):
self._entities = {e["id"]: e for e in entities}
self._relations = relations
self._index_by_type: dict[str, list[str]] = {}
self._index_by_name: dict[str, str] = {}
# Build indexes by entity type and by name
for entity_id, entity in self._entities.items():
entity_type = entity.get("type", "Unknown")
if entity_type not in self._index_by_type:
self._index_by_type[entity_type] = []
self._index_by_type[entity_type].append(entity_id)
name = entity.get("name", "")
if name:
self._index_by_name[name] = entity_id
def get_entity(self, entity_id: str) -> dict | None:
"""获取实体."""
return self._entities.get(entity_id)
def get_entity_by_name(self, name: str) -> dict | None:
"""通过名称获取实体."""
entity_id = self._index_by_name.get(name)
if entity_id:
return self._entities.get(entity_id)
return None
def get_entities_by_type(self, entity_type: str) -> list[dict]:
"""获取特定类型的所有实体."""
entity_ids = self._index_by_type.get(entity_type, [])
return [self._entities[eid] for eid in entity_ids if eid in self._entities]
def get_relations(
self, from_id: str | None = None, to_id: str | None = None, relation_type: str | None = None
) -> list[dict]:
"""获取关系."""
results = []
for rel in self._relations:
if from_id and rel.get("from") != from_id:
continue
if to_id and rel.get("to") != to_id:
continue
if relation_type and rel.get("type") != relation_type:
continue
results.append(rel)
return results
def get_neighbors(self, entity_id: str, relation_type: str | None = None) -> list[dict]:
"""获取邻居实体."""
relations = self.get_relations(from_id=entity_id, relation_type=relation_type)
neighbors = []
for rel in relations:
target_id = rel.get("to")
if target_id and target_id in self._entities:
neighbors.append({"entity": self._entities[target_id], "relation": rel})
return neighbors
def find_path(self, start_id: str, end_id: str, max_depth: int = 3) -> list[list[str]] | None:
"""查找两个实体之间的路径."""
if start_id == end_id:
return [[start_id]]
if max_depth <= 0:
return None
# BFS
from collections import deque
queue = deque([(start_id, [start_id])])
visited = {start_id}
all_paths = []
while queue:
current_id, path = queue.popleft()
if len(path) > max_depth + 1:
continue
relations = self.get_relations(from_id=current_id)
for rel in relations:
next_id = rel.get("to")
if not next_id:
continue
new_path = path + [next_id]
if next_id == end_id:
all_paths.append(new_path)
elif next_id not in visited and len(new_path) <= max_depth:
visited.add(next_id)
queue.append((next_id, new_path))
return all_paths if all_paths else None
@pytest.fixture
def knowledge_graph(knowledge_graph_entities, knowledge_graph_relations) -> KnowledgeGraph:
"""创建知识图谱实例."""
return KnowledgeGraph(
entities=knowledge_graph_entities.get("entities", []),
relations=knowledge_graph_relations.get("relations", []),
)
# =============================================================================
# Tool Definitions
# =============================================================================
@tool()
async def read_document(file_path: str) -> str:
"""读取文档内容.
Args:
file_path: 文档路径
Returns:
文档内容
"""
path = Path(file_path)
if not path.exists():
return f"Error: File not found: {file_path}"
try:
return path.read_text(encoding="utf-8")
except Exception as e:
return f"Error reading file: {e}"
@tool()
async def extract_entities(text: str) -> str:
"""从文本中提取实体.
Args:
text: 输入文本
Returns:
JSON 格式的实体列表
"""
# Simplified entity extraction - a real implementation should use an NLP model
import re
entities = []
# Extract person names (naive matching of Chinese names)
person_pattern = r"[\u4e00-\u9fa5]{2,4}"
potential_names = re.findall(person_pattern, text)
common_names = ["张三", "李四", "王五", "赵六", "李飞飞", "吴恩达", "Yann LeCun"]
for name in potential_names:
if name in common_names or len(name) == 3:
entities.append({"type": "Person", "name": name})
# Extract company/organization names
org_pattern = r"(TechCorp|OpenAI|GitHub|Google)"
orgs = re.findall(org_pattern, text)
for org in set(orgs):
entities.append({"type": "Organization", "name": org})
# Extract technology terms
tech_pattern = r"(Python|TensorFlow|PyTorch|GPT-4|AI|LLM)"
techs = re.findall(tech_pattern, text)
for tech in set(techs):
entities.append({"type": "Technology", "name": tech})
return json.dumps(entities, ensure_ascii=False)
@tool()
async def query_knowledge_graph(query_type: str, params: str) -> str:
"""查询知识图谱.
Args:
query_type: 查询类型 (person_relations, company_employees, technology_users, etc.)
params: JSON 格式的查询参数
Returns:
查询结果
"""
# A global kg instance is assumed here; in practice it should be injected when the Agent is initialized
try:
params_dict = json.loads(params)
except json.JSONDecodeError:
return json.dumps({"error": "Invalid JSON params"})
# Simplified implementation - a real version should query the knowledge graph
result = {"query_type": query_type, "params": params_dict, "results": []}
return json.dumps(result, ensure_ascii=False)
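# The tool above deliberately returns an empty result set so the tests stay
# hermetic. A hedged sketch of a KnowledgeGraph-backed variant, with the graph
# injected via a closure instead of a global (make_kg_query_tool is a
# hypothetical helper, not part of AgentLite):
def make_kg_query_tool(kg: KnowledgeGraph):
    """Build a query tool bound to a specific KnowledgeGraph instance."""
    @tool(name="query_knowledge_graph")
    async def _query(query_type: str, params: str) -> str:
        """Query the injected knowledge graph."""
        try:
            args = json.loads(params)
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON params"})
        if query_type == "person_relations":
            entity = kg.get_entity_by_name(args.get("entity_name", ""))
            relations = kg.get_relations(from_id=entity["id"]) if entity else []
            return json.dumps({"query_type": query_type, "results": relations}, ensure_ascii=False)
        return json.dumps({"error": f"Unsupported query_type: {query_type}"})
    return _query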
@tool()
async def reason_about_path(start_entity: str, end_entity: str) -> str:
"""推理两个实体之间的关系路径.
Args:
start_entity: 起始实体名称
end_entity: 目标实体名称
Returns:
推理结果
"""
return json.dumps(
{
"start": start_entity,
"end": end_entity,
"reasoning": f"分析 {start_entity}{end_entity} 的关系链...",
"path": [],
},
ensure_ascii=False,
)
# =============================================================================
# Test Cases
# =============================================================================
@pytest.mark.tools
class TestDocumentTools:
"""文档工具测试."""
@pytest.mark.asyncio
async def test_read_document(self, data_dir: Path, sample_article: str):
"""测试文档读取工具."""
result = tool_output(await read_document(str(data_dir / "documents" / "sample_article.md")))
assert "人工智能" in result
assert "GitHub Copilot" in result
assert "张三" in result
@pytest.mark.asyncio
async def test_read_document_not_found(self):
"""测试读取不存在的文档."""
result = tool_output(await read_document("/nonexistent/file.md"))
assert "Error" in result
assert "not found" in result.lower()
@pytest.mark.asyncio
async def test_extract_entities_from_article(self, sample_article: str):
"""测试从文章中提取实体."""
result = tool_output(await extract_entities(sample_article))
entities = json.loads(result)
# Verify that entities were extracted
assert len(entities) > 0
# Verify the extracted entity names
entity_names = [e["name"] for e in entities]
assert "张三" in entity_names
assert "TechCorp" in entity_names or "OpenAI" in entity_names
@pytest.mark.tools
class TestKnowledgeGraphTools:
"""知识图谱工具测试."""
def test_knowledge_graph_initialization(self, knowledge_graph: KnowledgeGraph):
"""测试知识图谱初始化."""
# Verify a known person entity exists
entity = knowledge_graph.get_entity_by_name("张三")
assert entity is not None
assert entity["type"] == "Person"
# Verify the company entity
company = knowledge_graph.get_entity_by_name("TechCorp")
assert company is not None
assert company["type"] == "Company"
def test_get_entities_by_type(self, knowledge_graph: KnowledgeGraph):
"""测试按类型获取实体."""
persons = knowledge_graph.get_entities_by_type("Person")
assert len(persons) >= 3 # 张三、李四、李飞飞
technologies = knowledge_graph.get_entities_by_type("Technology")
assert len(technologies) >= 2 # Python、OpenAI API
def test_get_relations(self, knowledge_graph: KnowledgeGraph):
"""测试获取关系."""
# Get all relations originating from 张三
zhangsan = knowledge_graph.get_entity_by_name("张三")
assert zhangsan is not None
relations = knowledge_graph.get_relations(from_id=zhangsan["id"])
assert len(relations) >= 2 # works_for, uses
# Verify the relation types
relation_types = [r["type"] for r in relations]
assert "works_for" in relation_types
assert "uses" in relation_types
def test_get_neighbors(self, knowledge_graph: KnowledgeGraph):
"""测试获取邻居节点."""
zhangsan = knowledge_graph.get_entity_by_name("张三")
assert zhangsan is not None
neighbors = knowledge_graph.get_neighbors(zhangsan["id"])
assert len(neighbors) >= 2
# Verify that the neighbors include TechCorp
neighbor_names = [n["entity"]["name"] for n in neighbors]
assert "TechCorp" in neighbor_names
def test_find_path(self, knowledge_graph: KnowledgeGraph):
"""测试查找路径."""
zhangsan = knowledge_graph.get_entity_by_name("张三")
techcorp = knowledge_graph.get_entity_by_name("TechCorp")
assert zhangsan is not None
assert techcorp is not None
paths = knowledge_graph.find_path(zhangsan["id"], techcorp["id"])
assert paths is not None
assert len(paths) > 0
# Verify the path length
first_path = paths[0]
assert len(first_path) == 2 # 张三 -> TechCorp
@pytest.mark.asyncio
async def test_query_knowledge_graph(self):
"""测试知识图谱查询工具."""
params = json.dumps({"entity_name": "张三"})
result = tool_output(await query_knowledge_graph("person_relations", params))
data = json.loads(result)
assert data["query_type"] == "person_relations"
assert "params" in data
@pytest.mark.asyncio
async def test_reason_about_path(self):
"""测试路径推理工具."""
result = tool_output(await reason_about_path("张三", "OpenAI"))
data = json.loads(result)
assert data["start"] == "张三"
assert data["end"] == "OpenAI"
assert "reasoning" in data
@pytest.mark.tools
class TestDataIntegrity:
"""数据完整性测试."""
def test_entities_json_valid(self, knowledge_graph_entities: dict):
"""验证实体 JSON 格式正确."""
assert "entities" in knowledge_graph_entities
entities = knowledge_graph_entities["entities"]
assert len(entities) > 0
# Verify that every entity has the required fields
for entity in entities:
assert "id" in entity
assert "type" in entity
assert "name" in entity
def test_relations_json_valid(
self, knowledge_graph_relations: dict, knowledge_graph_entities: dict
):
"""验证关系 JSON 格式正确且引用的实体存在."""
assert "relations" in knowledge_graph_relations
relations = knowledge_graph_relations["relations"]
entity_ids = {e["id"] for e in knowledge_graph_entities["entities"]}
for relation in relations:
assert "from" in relation
assert "to" in relation
assert "type" in relation
# Verify the referenced entities exist
assert relation["from"] in entity_ids, f"Entity {relation['from']} not found"
assert relation["to"] in entity_ids, f"Entity {relation['to']} not found"
def test_graph_queries_yaml_valid(self, graph_queries: list):
"""验证查询 YAML 格式正确."""
assert len(graph_queries) > 0
for query in graph_queries:
assert "id" in query
assert "description" in query
assert "query" in query
assert "expected_results" in query
def test_documents_exist(self, data_dir: Path):
"""验证测试文档存在且非空."""
docs_dir = data_dir / "documents"
sample_article = docs_dir / "sample_article.md"
assert sample_article.exists()
assert sample_article.stat().st_size > 0
tech_spec = docs_dir / "technical_spec.md"
assert tech_spec.exists()
assert tech_spec.stat().st_size > 0
meeting_notes = docs_dir / "meeting_notes.txt"
assert meeting_notes.exists()
assert meeting_notes.stat().st_size > 0
@pytest.mark.tools
class TestAgentWithTools:
"""Agent 集成工具测试."""
@pytest.mark.asyncio
async def test_agent_with_document_tools(self, mock_provider, data_dir: Path):
"""测试带有文档工具的 Agent."""
mock_provider.add_text_response("我已经读取了文档")
agent = Agent(
provider=mock_provider,
tools=[read_document],
system_prompt="你是一个文档助手,可以读取和分析文档。",
)
response = await agent.run(f"请读取文档 {data_dir / 'documents' / 'sample_article.md'}")
assert "文档" in response or "读取" in response
@pytest.mark.asyncio
async def test_agent_with_kg_tools(self, mock_provider):
"""测试带有知识图谱工具的 Agent."""
mock_provider.add_text_response("张三在 TechCorp 工作")
agent = Agent(
provider=mock_provider,
tools=[query_knowledge_graph, reason_about_path],
system_prompt="你是一个知识图谱助手,可以查询实体关系。",
)
response = await agent.run("张三在哪里工作?")
assert response is not None
assert len(response) > 0

View File

View File

@@ -0,0 +1,330 @@
"""Unit tests for configuration models.
This module tests all Pydantic configuration models including
ProviderConfig, ModelConfig, ToolConfig, and AgentConfig.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from agentlite import ProviderConfig, ModelConfig, AgentConfig
class TestProviderConfig:
"""Tests for ProviderConfig."""
def test_provider_config_valid(self):
"""Test valid ProviderConfig creation."""
config = ProviderConfig(
type="openai",
base_url="https://api.openai.com/v1",
api_key="sk-test123",
)
assert config.type == "openai"
assert config.base_url == "https://api.openai.com/v1"
assert config.api_key.get_secret_value() == "sk-test123"
def test_provider_config_default_type(self):
"""Test ProviderConfig with default type."""
config = ProviderConfig(
base_url="https://api.openai.com/v1",
api_key="sk-test",
)
assert config.type == "openai"
def test_provider_config_default_url(self):
"""Test ProviderConfig with default base_url."""
config = ProviderConfig(
api_key="sk-test",
)
assert config.base_url == "https://api.openai.com/v1"
def test_provider_config_invalid_url_http(self):
"""Test ProviderConfig with invalid URL scheme."""
with pytest.raises(ValidationError) as exc_info:
ProviderConfig(
type="openai",
base_url="ftp://invalid.com",
api_key="sk-test",
)
assert "base_url must start with http:// or https://" in str(exc_info.value)
def test_provider_config_invalid_url_no_scheme(self):
"""Test ProviderConfig with URL without scheme."""
with pytest.raises(ValidationError):
ProviderConfig(
base_url="api.openai.com/v1",
api_key="sk-test",
)
def test_provider_config_custom_headers(self):
"""Test ProviderConfig with custom headers."""
config = ProviderConfig(
api_key="sk-test",
headers={"X-Custom": "value"},
)
assert config.headers == {"X-Custom": "value"}
def test_provider_config_default_headers(self):
"""Test ProviderConfig default headers."""
config = ProviderConfig(api_key="sk-test")
assert config.headers == {}
def test_provider_config_timeout(self):
"""Test ProviderConfig timeout."""
config = ProviderConfig(
api_key="sk-test",
timeout=30.0,
)
assert config.timeout == 30.0
def test_provider_config_default_timeout(self):
"""Test ProviderConfig default timeout."""
config = ProviderConfig(api_key="sk-test")
assert config.timeout == 60.0
def test_provider_config_api_key_is_secret_str(self):
"""Test that api_key is stored as SecretStr."""
config = ProviderConfig(api_key="sk-secret")
# SecretStr should not expose value in repr/str
assert "sk-secret" not in str(config.api_key)
# But can get value explicitly
assert config.api_key.get_secret_value() == "sk-secret"
class TestModelConfig:
"""Tests for ModelConfig."""
def test_model_config_valid(self):
"""Test valid ModelConfig creation."""
config = ModelConfig(
provider="openai",
model="gpt-4",
)
assert config.provider == "openai"
assert config.model == "gpt-4"
def test_model_config_with_all_fields(self):
"""Test ModelConfig with all optional fields."""
config = ModelConfig(
provider="openai",
model="gpt-4",
max_tokens=1000,
temperature=0.7,
top_p=0.9,
capabilities={"streaming", "tool_calling"},
)
assert config.max_tokens == 1000
assert config.temperature == 0.7
assert config.top_p == 0.9
assert config.capabilities == {"streaming", "tool_calling"}
def test_model_config_empty_provider(self):
"""Test ModelConfig with empty provider."""
with pytest.raises(ValidationError) as exc_info:
ModelConfig(
provider="",
model="gpt-4",
)
assert "provider must not be empty" in str(exc_info.value)
def test_model_config_temperature_bounds(self):
"""Test ModelConfig temperature validation bounds."""
# Valid: 0.0
config = ModelConfig(provider="openai", model="gpt-4", temperature=0.0)
assert config.temperature == 0.0
# Valid: 2.0
config = ModelConfig(provider="openai", model="gpt-4", temperature=2.0)
assert config.temperature == 2.0
# Invalid: < 0
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", temperature=-0.1)
# Invalid: > 2
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", temperature=2.1)
def test_model_config_top_p_bounds(self):
"""Test ModelConfig top_p validation bounds."""
# Valid: 0.0
config = ModelConfig(provider="openai", model="gpt-4", top_p=0.0)
assert config.top_p == 0.0
# Valid: 1.0
config = ModelConfig(provider="openai", model="gpt-4", top_p=1.0)
assert config.top_p == 1.0
# Invalid: < 0
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", top_p=-0.1)
# Invalid: > 1
with pytest.raises(ValidationError):
ModelConfig(provider="openai", model="gpt-4", top_p=1.1)
def test_model_config_default_capabilities(self):
"""Test ModelConfig default capabilities."""
config = ModelConfig(provider="openai", model="gpt-4")
assert config.capabilities == set()
class TestAgentConfig:
"""Tests for AgentConfig."""
def test_agent_config_minimal(self):
"""Test AgentConfig with minimal required fields."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
)
assert config.name == "agent"
assert config.system_prompt == "You are a helpful assistant."
assert config.default_model == "default"
def test_agent_config_full(self):
"""Test AgentConfig with all fields."""
config = AgentConfig(
name="my_agent",
system_prompt="Custom system prompt",
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
max_history=50,
)
assert config.name == "my_agent"
assert config.system_prompt == "Custom system prompt"
assert config.default_model == "gpt4"
assert config.max_history == 50
def test_agent_config_missing_default_model(self):
"""Test AgentConfig with non-existent default_model."""
with pytest.raises(ValidationError) as exc_info:
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="nonexistent",
)
assert "not found in models" in str(exc_info.value)
def test_agent_config_unknown_provider(self):
"""Test AgentConfig with model referencing unknown provider."""
with pytest.raises(ValidationError) as exc_info:
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="unknown", model="gpt-4")},
)
assert "unknown provider" in str(exc_info.value)
def test_agent_config_get_provider_config(self):
"""Test get_provider_config method."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
provider_config = config.get_provider_config("gpt4")
assert provider_config.api_key.get_secret_value() == "sk-test"
def test_agent_config_get_provider_config_default(self):
"""Test get_provider_config with default model."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
provider_config = config.get_provider_config()
assert provider_config.api_key.get_secret_value() == "sk-test"
def test_agent_config_get_provider_config_not_found(self):
"""Test get_provider_config with non-existent model."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
)
with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
config.get_provider_config("nonexistent")
def test_agent_config_get_model_config(self):
"""Test get_model_config method."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
model_config = config.get_model_config("gpt4")
assert model_config.model == "gpt-4"
def test_agent_config_get_model_config_default(self):
"""Test get_model_config with default."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
default_model="gpt4",
)
model_config = config.get_model_config()
assert model_config.model == "gpt-4"
def test_agent_config_get_model_config_not_found(self):
"""Test get_model_config with non-existent model."""
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
)
with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
config.get_model_config("nonexistent")
def test_agent_config_multiple_providers(self):
"""Test AgentConfig with multiple providers."""
config = AgentConfig(
providers={
"openai": ProviderConfig(api_key="sk-openai"),
"anthropic": ProviderConfig(
type="anthropic",
base_url="https://api.anthropic.com/v1",
api_key="sk-anthropic",
),
},
models={
"default": ModelConfig(provider="openai", model="gpt-4"),
"claude": ModelConfig(provider="anthropic", model="claude-3"),
},
)
assert len(config.providers) == 2
assert len(config.models) == 2
def test_agent_config_max_history_validation(self):
"""Test max_history validation."""
# Valid: min=1
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
max_history=1,
)
assert config.max_history == 1
# Invalid: 0
with pytest.raises(ValidationError):
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
max_history=0,
)
# Invalid: negative
with pytest.raises(ValidationError):
AgentConfig(
providers={"openai": ProviderConfig(api_key="sk-test")},
models={"default": ModelConfig(provider="openai", model="gpt-4")},
max_history=-1,
)
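# A hedged sketch of wiring the three config models together from plain data
# (for example, a parsed YAML/JSON file). model_validate is the standard
# Pydantic v2 entry point; the nested layout mirrors the fields exercised in
# the tests above, and the key/model names are illustrative.
def _example_agent_config() -> AgentConfig:
    return AgentConfig.model_validate(
        {
            "name": "docs_agent",
            "system_prompt": "You are a helpful assistant.",
            "providers": {
                "openai": {"api_key": "sk-test", "base_url": "https://api.openai.com/v1"},
            },
            "models": {
                "default": {"provider": "openai", "model": "gpt-4", "temperature": 0.7},
            },
            "default_model": "default",
        }
    )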

View File

@@ -0,0 +1,297 @@
"""Unit tests for message types.
This module tests all message-related types including ContentPart,
Message, ToolCall, and their various subclasses.
"""
from __future__ import annotations
import pytest
from agentlite import (
ContentPart,
Message,
TextPart,
ImageURLPart,
AudioURLPart,
ToolCall,
ToolCallPart,
)
class TestContentPart:
"""Tests for ContentPart base class and registry."""
def test_content_part_registry_auto_registers_subclasses(self):
"""Test that ContentPart subclasses are auto-registered."""
# All defined subclasses should be in registry
assert "text" in ContentPart._ContentPart__content_part_registry
assert "image_url" in ContentPart._ContentPart__content_part_registry
assert "audio_url" in ContentPart._ContentPart__content_part_registry
def test_text_part_creation(self):
"""Test basic TextPart creation."""
part = TextPart(text="Hello, world!")
assert part.type == "text"
assert part.text == "Hello, world!"
def test_text_part_model_dump(self):
"""Test TextPart serialization."""
part = TextPart(text="Hello")
dumped = part.model_dump()
assert dumped == {"type": "text", "text": "Hello"}
def test_text_part_merge_success(self):
"""Test successful text merge during streaming."""
part1 = TextPart(text="Hello ")
part2 = TextPart(text="world!")
result = part1.merge_in_place(part2)
assert result is True
assert part1.text == "Hello world!"
def test_text_part_merge_failure(self):
"""Test merge failure with incompatible types."""
text_part = TextPart(text="Hello")
# Try to merge with non-TextPart
result = text_part.merge_in_place("not a part")
assert result is False
assert text_part.text == "Hello" # Unchanged
class TestImageURLPart:
"""Tests for ImageURLPart."""
def test_image_url_part_creation(self):
"""Test ImageURLPart creation."""
part = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
assert part.type == "image_url"
assert part.image_url.url == "https://example.com/image.png"
def test_image_url_part_with_detail(self):
"""Test ImageURLPart with detail parameter."""
part = ImageURLPart(
image_url=ImageURLPart.ImageURL(url="https://example.com/image.png", detail="high")
)
assert part.image_url.detail == "high"
def test_image_url_part_default_detail(self):
"""Test ImageURLPart default detail is None."""
part = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
assert part.image_url.detail is None
class TestAudioURLPart:
"""Tests for AudioURLPart."""
def test_audio_url_part_creation(self):
"""Test AudioURLPart creation."""
part = AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3"))
assert part.type == "audio_url"
assert part.audio_url.url == "https://example.com/audio.mp3"
class TestToolCall:
"""Tests for ToolCall."""
def test_tool_call_creation(self):
"""Test ToolCall creation."""
call = ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
)
assert call.type == "function"
assert call.id == "call_123"
assert call.function.name == "add"
assert call.function.arguments == '{"a": 1, "b": 2}'
def test_tool_call_merge_with_part(self):
"""Test ToolCall merging with ToolCallPart."""
call = ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1')
)
part = ToolCallPart(arguments_part=', "b": 2}')
result = call.merge_in_place(part)
assert result is True
assert call.function.arguments == '{"a": 1, "b": 2}'
def test_tool_call_merge_failure(self):
"""Test ToolCall merge failure with incompatible types."""
call = ToolCall(id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}"))
result = call.merge_in_place("not a part")
assert result is False
class TestToolCallPart:
"""Tests for ToolCallPart."""
def test_tool_call_part_creation(self):
"""Test ToolCallPart creation."""
part = ToolCallPart(arguments_part='{"a": 1}')
assert part.arguments_part == '{"a": 1}'
def test_tool_call_part_none(self):
"""Test ToolCallPart with None arguments."""
part = ToolCallPart(arguments_part=None)
assert part.arguments_part is None
def test_tool_call_part_merge(self):
"""Test ToolCallPart merging."""
part1 = ToolCallPart(arguments_part='{"a":')
part2 = ToolCallPart(arguments_part=" 1}")
result = part1.merge_in_place(part2)
assert result is True
assert part1.arguments_part == '{"a": 1}'
def test_tool_call_part_merge_none(self):
"""Test ToolCallPart merge when self is None."""
part1 = ToolCallPart(arguments_part=None)
part2 = ToolCallPart(arguments_part='{"a": 1}')
result = part1.merge_in_place(part2)
assert result is True
assert part1.arguments_part == '{"a": 1}'
class TestMessage:
"""Tests for Message."""
def test_message_string_content_coercion(self):
"""Test that string content is coerced to TextPart."""
msg = Message(role="user", content="Hello!")
assert len(msg.content) == 1
assert isinstance(msg.content[0], TextPart)
assert msg.content[0].text == "Hello!"
def test_message_part_content(self):
"""Test Message with ContentPart content."""
part = TextPart(text="Hello!")
msg = Message(role="user", content=part)
assert len(msg.content) == 1
assert msg.content[0].text == "Hello!"
def test_message_list_content(self):
"""Test Message with list of ContentParts."""
parts = [TextPart(text="Hello"), TextPart(text=" world!")]
msg = Message(role="user", content=parts)
assert len(msg.content) == 2
def test_message_extract_text(self):
"""Test text extraction from message."""
msg = Message(role="user", content="Hello world!")
assert msg.extract_text() == "Hello world!"
def test_message_extract_text_with_separator(self):
"""Test text extraction with custom separator."""
parts = [TextPart(text="Hello"), TextPart(text="world!")]
msg = Message(role="user", content=parts)
assert msg.extract_text(sep=" ") == "Hello world!"
assert msg.extract_text(sep="-") == "Hello-world!"
def test_message_has_tool_calls_false(self):
"""Test has_tool_calls returns False when no tool calls."""
msg = Message(role="assistant", content="Hello!")
assert msg.has_tool_calls() is False
def test_message_has_tool_calls_true(self):
"""Test has_tool_calls returns True when tool calls present."""
tool_call = ToolCall(
id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}")
)
msg = Message(role="assistant", content="Let me calculate that.", tool_calls=[tool_call])
assert msg.has_tool_calls() is True
def test_message_has_tool_calls_empty_list(self):
"""Test has_tool_calls with empty tool_calls list."""
msg = Message(role="assistant", content="Hello!", tool_calls=[])
assert msg.has_tool_calls() is False
def test_message_tool_response(self):
"""Test message with tool response."""
msg = Message(role="tool", content="Result: 42", tool_call_id="call_123")
assert msg.role == "tool"
assert msg.tool_call_id == "call_123"
def test_message_serialization(self):
"""Test Message serialization with model_dump."""
msg = Message(role="user", content="Hello!")
dumped = msg.model_dump()
assert dumped["role"] == "user"
assert "content" in dumped
def test_message_all_roles(self):
"""Test Message creation with all valid roles."""
for role in ["system", "user", "assistant", "tool"]:
msg = Message(role=role, content="Test")
assert msg.role == role
class TestPolymorphicContentPart:
"""Tests for polymorphic ContentPart validation."""
def test_polymorphic_validation_text(self):
"""Test that text type validates to TextPart."""
data = {"type": "text", "text": "Hello"}
part = ContentPart.model_validate(data)
assert isinstance(part, TextPart)
assert part.text == "Hello"
def test_polymorphic_validation_image(self):
"""Test that image_url type validates to ImageURLPart."""
data = {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}
part = ContentPart.model_validate(data)
assert isinstance(part, ImageURLPart)
assert part.image_url.url == "https://example.com/image.png"
def test_polymorphic_validation_unknown_type(self):
"""Test validation with unknown type raises error."""
data = {"type": "unknown_type", "content": "test"}
with pytest.raises(ValueError, match="Unknown content part type"):
ContentPart.model_validate(data)
def test_polymorphic_validation_no_type(self):
"""Test validation without type raises error."""
data = {"content": "test"}
with pytest.raises(ValueError):
ContentPart.model_validate(data)
class TestMessageEdgeCases:
"""Tests for edge cases in Message handling."""
def test_empty_string_content(self):
"""Test Message with empty string content."""
msg = Message(role="user", content="")
assert msg.content[0].text == ""
def test_message_with_name(self):
"""Test Message with name field."""
msg = Message(role="user", content="Hello", name="user1")
assert msg.name == "user1"
def test_message_content_mutation(self):
"""Test that appending to a message's content mutates the message in place."""
msg = Message(role="user", content="Hello")
# Append directly to the content list
msg.content.append(TextPart(text="Extra"))
# The message reflects the new part (content is the same list object)
assert len(msg.content) == 2
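# A hedged sketch of how the merge_in_place contract composes during streaming:
# successive TextPart deltas are folded into the first part, mirroring
# test_text_part_merge_success above. The helper itself is illustrative, not an
# AgentLite API.
def merge_text_deltas(deltas: list[TextPart]) -> TextPart:
    merged = TextPart(text=deltas[0].text)
    for delta in deltas[1:]:
        assert merged.merge_in_place(delta), "TextPart deltas should always merge"
    return merged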

View File

@@ -0,0 +1,168 @@
"""Unit tests for provider protocol and exceptions.
This module tests the ChatProvider protocol, StreamedMessage protocol,
and all exception types.
"""
from __future__ import annotations
import pytest
from agentlite.provider import (
TokenUsage,
ChatProviderError,
APIConnectionError,
APITimeoutError,
APIStatusError,
APIEmptyResponseError,
ChatProvider,
StreamedMessage,
)
class TestTokenUsage:
"""Tests for TokenUsage."""
def test_token_usage_creation(self):
"""Test TokenUsage creation."""
usage = TokenUsage(input_tokens=100, output_tokens=50)
assert usage.input_tokens == 100
assert usage.output_tokens == 50
assert usage.cached_tokens == 0 # Default
def test_token_usage_with_cached(self):
"""Test TokenUsage with cached tokens."""
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
assert usage.cached_tokens == 20
def test_token_usage_total(self):
"""Test total token calculation."""
usage = TokenUsage(input_tokens=100, output_tokens=50)
assert usage.total == 150
def test_token_usage_total_with_cached(self):
"""Test total with cached tokens (not included in total)."""
usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
# Total is input + output, cached is tracked separately
assert usage.total == 150
class TestChatProviderError:
"""Tests for ChatProviderError hierarchy."""
def test_base_error_creation(self):
"""Test base ChatProviderError creation."""
error = ChatProviderError("Something went wrong")
assert error.message == "Something went wrong"
assert str(error) == "Something went wrong"
def test_api_connection_error(self):
"""Test APIConnectionError creation."""
error = APIConnectionError("Connection failed")
assert isinstance(error, ChatProviderError)
assert error.message == "Connection failed"
def test_api_timeout_error(self):
"""Test APITimeoutError creation."""
error = APITimeoutError("Request timed out")
assert isinstance(error, ChatProviderError)
assert error.message == "Request timed out"
def test_api_status_error(self):
"""Test APIStatusError creation."""
error = APIStatusError(429, "Rate limit exceeded")
assert isinstance(error, ChatProviderError)
assert error.status_code == 429
assert error.message == "Rate limit exceeded"
def test_api_status_error_different_codes(self):
"""Test APIStatusError with different status codes."""
codes = [400, 401, 403, 404, 429, 500, 502, 503]
for code in codes:
error = APIStatusError(code, f"Error {code}")
assert error.status_code == code
def test_api_empty_response_error(self):
"""Test APIEmptyResponseError creation."""
error = APIEmptyResponseError("Empty response from API")
assert isinstance(error, ChatProviderError)
assert error.message == "Empty response from API"
def test_exception_hierarchy(self):
"""Test that all exceptions inherit from ChatProviderError."""
errors = [
APIConnectionError("test"),
APITimeoutError("test"),
APIStatusError(500, "test"),
APIEmptyResponseError("test"),
]
for error in errors:
assert isinstance(error, ChatProviderError)
class TestChatProviderProtocol:
"""Tests for ChatProvider protocol."""
def test_protocol_is_runtime_checkable(self):
"""Test that ChatProvider is runtime checkable."""
# ChatProvider is declared with @runtime_checkable, so structural isinstance()
# checks (exercised in the tests below) are permitted against it
assert hasattr(ChatProvider, "__protocol_attrs__")
def test_mock_provider_implements_protocol(self, mock_provider):
"""Test that MockProvider implements ChatProvider."""
assert isinstance(mock_provider, ChatProvider)
def test_mock_provider_has_model_name(self, mock_provider):
"""Test that mock provider has model_name property."""
assert hasattr(mock_provider, "model_name")
assert isinstance(mock_provider.model_name, str)
def test_mock_provider_has_generate_method(self, mock_provider):
"""Test that mock provider has generate method."""
assert hasattr(mock_provider, "generate")
assert callable(mock_provider.generate)
class TestStreamedMessageProtocol:
"""Tests for StreamedMessage protocol."""
def test_protocol_is_runtime_checkable(self):
"""Test that StreamedMessage is runtime checkable."""
assert hasattr(StreamedMessage, "__protocol_attrs__")
def test_mock_streamed_message_implements_protocol(self):
"""Test that MockStreamedMessage implements StreamedMessage."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert isinstance(stream, StreamedMessage)
def test_streamed_message_has_id_property(self):
"""Test that streamed message has id property."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert hasattr(stream, "id")
assert stream.id == "mock-msg-123"
def test_streamed_message_has_usage_property(self):
"""Test that streamed message has usage property."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert hasattr(stream, "usage")
assert stream.usage is not None
assert isinstance(stream.usage, TokenUsage)
def test_streamed_message_is_async_iterable(self):
"""Test that streamed message is async iterable."""
from tests.conftest import MockStreamedMessage
from agentlite import TextPart
stream = MockStreamedMessage([TextPart(text="Hello")])
assert hasattr(stream, "__aiter__")

View File

@@ -0,0 +1,210 @@
"""Unit tests for tool decorator and CallableTool.
This module tests the @tool() decorator and related tool functionality.
"""
from __future__ import annotations
import pytest
from agentlite.tool import tool, CallableTool, ToolOk, ToolError
class TestToolDecorator:
"""Tests for the @tool() decorator."""
def test_tool_decorator_basic(self):
"""Test basic tool decorator functionality."""
@tool()
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
assert isinstance(add, CallableTool)
assert add.name == "add"
assert add.description == "Add two numbers."
assert add.parameters["type"] == "object"
assert "a" in add.parameters["properties"]
assert "b" in add.parameters["properties"]
assert add.parameters["properties"]["a"]["type"] == "number"
assert add.parameters["properties"]["b"]["type"] == "number"
assert add.parameters["required"] == ["a", "b"]
def test_tool_decorator_with_default_params(self):
"""Test tool decorator with default parameters."""
@tool()
async def greet(name: str, greeting: str = "Hello") -> str:
"""Greet someone."""
return f"{greeting}, {name}!"
assert greet.name == "greet"
assert "name" in greet.parameters["required"]
assert "greeting" not in greet.parameters["required"]
def test_tool_decorator_custom_name(self):
"""Test tool decorator with custom name."""
@tool(name="custom_add")
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
assert add.name == "custom_add"
def test_tool_decorator_custom_description(self):
"""Test tool decorator with custom description."""
@tool(description="Custom description")
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
assert add.description == "Custom description"
def test_tool_decorator_no_docstring(self):
"""Test tool decorator with no docstring."""
@tool()
async def no_doc(a: float) -> float:
return a
assert no_doc.description == "No description provided"
def test_tool_decorator_param_types(self):
"""Test tool decorator with various parameter types."""
@tool()
async def multi_types(
s: str,
i: int,
f: float,
b: bool,
) -> dict:
"""Multiple types."""
return {"s": s, "i": i, "f": f, "b": b}
props = multi_types.parameters["properties"]
assert props["s"]["type"] == "string"
assert props["i"]["type"] == "integer"
assert props["f"]["type"] == "number"
assert props["b"]["type"] == "boolean"
def test_tool_decorator_no_type_hints(self):
"""Test tool decorator with no type hints."""
@tool()
async def no_types(param) -> str:
"""No type hints."""
return str(param)
assert no_types.parameters["properties"]["param"]["type"] == "string"
class TestToolDecoratorExecution:
"""Tests for tool decorator execution."""
@pytest.mark.asyncio
async def test_tool_execution_success(self):
"""Test successful tool execution."""
@tool()
async def add(a: float, b: float) -> float:
"""Add two numbers."""
return a + b
result = await add(1.0, 2.0)
assert isinstance(result, ToolOk)
assert result.output == "3.0"
@pytest.mark.asyncio
async def test_tool_execution_error(self):
"""Test tool execution with error."""
@tool()
async def divide(a: float, b: float) -> float:
"""Divide two numbers."""
return a / b
result = await divide(1.0, 0.0)
assert isinstance(result, ToolError)
assert "division by zero" in result.message
@pytest.mark.asyncio
async def test_tool_execution_with_kwargs(self):
"""Test tool execution with keyword arguments."""
@tool()
async def greet(name: str, greeting: str = "Hello") -> str:
"""Greet someone."""
return f"{greeting}, {name}!"
result = await greet(name="World", greeting="Hi")
assert isinstance(result, ToolOk)
assert result.output == "Hi, World!"
class TestToolDecoratorMemorixBug:
"""Tests for the specific bug reported by Memorix project."""
def test_tool_decorator_memorix_case(self):
"""Test the exact case from Memorix bug report.
This test verifies that the @tool() decorator works correctly
with async functions that have string and float parameters.
"""
@tool()
async def add_memory(content: str, importance: float = 0.5) -> dict:
"""存储记忆"""
return {"status": "ok"}
assert isinstance(add_memory, CallableTool)
assert add_memory.name == "add_memory"
assert add_memory.description == "存储记忆"
# Check parameters schema
params = add_memory.parameters
assert params["type"] == "object"
assert "content" in params["properties"]
assert "importance" in params["properties"]
assert params["properties"]["content"]["type"] == "string"
assert params["properties"]["importance"]["type"] == "number"
# content is required (no default), importance is optional
assert "content" in params["required"]
assert "importance" not in params["required"]
@pytest.mark.asyncio
async def test_tool_decorator_memorix_execution(self):
"""Test execution of the Memorix case."""
@tool()
async def add_memory(content: str, importance: float = 0.5) -> dict:
"""存储记忆"""
return {"status": "ok", "content": content, "importance": importance}
result = await add_memory("test content", 0.8)
assert isinstance(result, ToolOk)
assert "ok" in result.output
def test_tool_decorator_can_be_used_in_agent(self):
"""Test that decorated tools can be used with Agent.
This is an integration-style test to ensure the decorated tool
has all required attributes for Agent usage.
"""
from agentlite import Agent, OpenAIProvider
@tool()
async def add_memory(content: str, importance: float = 0.5) -> dict:
"""存储记忆"""
return {"status": "ok"}
# Verify the tool has the base property required by Agent
assert hasattr(add_memory, "base")
base_tool = add_memory.base
assert base_tool.name == "add_memory"
assert base_tool.description == "存储记忆"
assert base_tool.parameters == add_memory.parameters

98
agentlite/tests/utils.py Normal file
View File

@@ -0,0 +1,98 @@
"""Test utilities and helpers for AgentLite tests.
This module provides utility functions and helpers used across test modules.
"""
from __future__ import annotations
import asyncio
from collections.abc import Coroutine
from typing import Any, TypeVar
T = TypeVar("T")
async def run_async(coro: Coroutine[Any, Any, T]) -> T:
"""Run an async coroutine and return the result.
This is a thin await wrapper; use run_sync below to run a coroutine synchronously.
Args:
coro: The coroutine to run.
Returns:
The result of the coroutine.
"""
return await coro
def run_sync(coro: Coroutine[Any, Any, T]) -> T:
"""Run an async coroutine synchronously.
Args:
coro: The coroutine to run.
Returns:
The result of the coroutine.
"""
return asyncio.run(coro)
async def collect_stream(stream) -> list[Any]:
"""Collect all items from an async stream into a list.
Args:
stream: The async stream to collect from.
Returns:
List of all items from the stream.
"""
items = []
async for item in stream:
items.append(item)
return items
async def collect_stream_text(stream) -> str:
"""Collect all text from an async text stream.
Args:
stream: The async stream to collect from.
Returns:
Concatenated text from all items.
"""
from agentlite import TextPart
text_parts = []
async for item in stream:
if isinstance(item, TextPart):
text_parts.append(item.text)
elif isinstance(item, str):
text_parts.append(item)
return "".join(text_parts)
def create_tool_schema(
name: str,
description: str,
properties: dict[str, Any],
required: list[str] | None = None,
) -> dict[str, Any]:
"""Create a JSON schema for a tool.
Args:
name: Tool name.
description: Tool description.
properties: JSON schema properties.
required: List of required property names.
Returns:
Parameters schema for the tool (the same shape as CallableTool.parameters).
"""
schema = {
"type": "object",
"properties": properties,
}
if required:
schema["required"] = required
return schema