Merge pull request #1578 from tcmofashi/rdev

feat: add a subagent frame
This commit is contained in:
tcmofashi
2026-04-03 23:26:06 +08:00
committed by GitHub
82 changed files with 13046 additions and 18 deletions

31
agentlite/CHANGELOG.md Normal file
View File

@@ -0,0 +1,31 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.0] - 2025-01-30
### Added
- Initial release of AgentLite
- Core Agent class with streaming and non-streaming interfaces
- OpenAI-compatible provider implementation
- Tool system with decorator and class-based tools
- MCP client for loading tools from MCP servers
- Pydantic-based configuration system
- Multi-agent support
- Full type hints and async/await throughout
- Comprehensive documentation and examples
### Features
- **Agent**: Main agent class with tool calling loop
- **OpenAIProvider**: OpenAI API integration with streaming support
- **MCPClient**: MCP server integration for external tools
- **Tool System**: Decorator (`@tool`) and class-based (`CallableTool`, `CallableTool2`) tools
- **Configuration**: Pydantic models for providers, models, and agent settings
- **Message Types**: ContentPart, Message, ToolCall with streaming merge support
[0.1.0]: https://github.com/yourusername/agentlite/releases/tag/v0.1.0

1013
agentlite/TEST_PLAN.md Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,279 @@
# LLM Client
Simple LLM client for direct LLM calls without agent overhead.
## Overview
The `LLMClient` provides a simple interface for making direct LLM calls, reusing the agentlite configuration system. This is useful when you don't need the full agent capabilities (tools, conversation history, etc.) and just want to call an LLM.
## Features
- **Simple Interface**: Just system prompt + user prompt → response
- **Configuration Reuse**: Uses existing `AgentConfig` for provider/model setup
- **Streaming Support**: Both non-streaming and streaming interfaces
- **Flexible Usage**: Use with config, direct provider, or simple functions
## Quick Start
### Method 1: Simple Function (Quickest)
```python
import asyncio
from agentlite import llm_complete
async def main():
response = await llm_complete(
user_prompt="What is Python?",
api_key="your-api-key",
model="gpt-4",
)
print(response)
asyncio.run(main())
```
### Method 2: Using Configuration
```python
import asyncio
from agentlite import LLMClient, AgentConfig, ProviderConfig, ModelConfig
async def main():
# Create configuration
config = AgentConfig(
providers={
"openai": ProviderConfig(api_key="your-api-key")
},
models={
"gpt4": ModelConfig(provider="openai", model="gpt-4")
},
default_model="gpt4",
)
# Create client
client = LLMClient(config)
# Make a call
response = await client.complete(
system_prompt="You are a helpful assistant.",
user_prompt="What is Python?"
)
print(response.content)
print(f"Model: {response.model}")
if response.usage:
print(f"Tokens: {response.usage.total}")
asyncio.run(main())
```
### Method 3: Direct Provider
```python
import asyncio
from agentlite import LLMClient, OpenAIProvider
async def main():
# Create provider directly
provider = OpenAIProvider(
api_key="your-api-key",
model="gpt-4",
temperature=0.8,
)
# Create client
client = LLMClient(provider=provider)
# Make a call
response = await client.complete(
user_prompt="Explain async/await",
system_prompt="You are a Python expert.",
)
print(response.content)
asyncio.run(main())
```
## Streaming
### Using the Client
```python
async for chunk in client.stream(
user_prompt="Write a poem about AI",
system_prompt="You are a creative writer.",
):
print(chunk, end="")
```
### Using the Function
```python
async for chunk in llm_stream(
user_prompt="Write a haiku",
api_key="your-api-key",
):
print(chunk, end="")
```
## API Reference
### LLMClient
```python
class LLMClient:
def __init__(
self,
config: Optional[AgentConfig] = None,
provider: Optional[ChatProvider] = None,
model: Optional[str] = None,
)
async def complete(
self,
user_prompt: str,
system_prompt: str = "You are a helpful assistant.",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> LLMResponse
async def stream(
self,
user_prompt: str,
system_prompt: str = "You are a helpful assistant.",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> AsyncIterator[str]
```
### LLMResponse
```python
class LLMResponse:
content: str # The response text
usage: TokenUsage | None # Token usage stats
model: str # Model name used
```
### Convenience Functions
```python
async def llm_complete(
user_prompt: str,
system_prompt: str = "You are a helpful assistant.",
api_key: Optional[str] = None,
model: str = "gpt-4",
base_url: str = "https://api.openai.com/v1",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> str
async def llm_stream(
user_prompt: str,
system_prompt: str = "You are a helpful assistant.",
api_key: Optional[str] = None,
model: str = "gpt-4",
base_url: str = "https://api.openai.com/v1",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> AsyncIterator[str]
```
## Configuration Options
### Temperature and Max Tokens
You can override temperature and max_tokens per call:
```python
response = await client.complete(
user_prompt="Creative writing task",
temperature=0.9, # More creative
max_tokens=500, # Limit response length
)
```
### Model Switching
When using `AgentConfig`, you can switch models:
```python
config = AgentConfig(
providers={"openai": ProviderConfig(api_key="...")},
models={
"gpt4": ModelConfig(provider="openai", model="gpt-4"),
"gpt35": ModelConfig(provider="openai", model="gpt-3.5-turbo"),
},
default_model="gpt4",
)
# Use default model (gpt4)
client = LLMClient(config)
# Use specific model
client_gpt35 = LLMClient(config, model="gpt35")
```
## Comparison with Agent
| Feature | LLMClient | Agent |
|---------|-----------|-------|
| Tools | ❌ No | ✅ Yes |
| Conversation History | ❌ No | ✅ Yes |
| System Prompt | ✅ Yes | ✅ Yes |
| Configuration | ✅ Reuses AgentConfig | ✅ AgentConfig |
| Streaming | ✅ Yes | ✅ Yes |
| Use Case | Simple LLM calls | Complex agent workflows |
## Examples
### Translation
```python
async def translate(text: str, target_language: str) -> str:
response = await llm_complete(
user_prompt=f"Translate to {target_language}: {text}",
system_prompt="You are a translator. Return only the translation.",
api_key="your-api-key",
)
return response
```
### Code Review
```python
async def review_code(code: str) -> str:
client = LLMClient(config)
response = await client.complete(
user_prompt=f"Review this code:\n\n```python\n{code}\n```",
system_prompt="You are a code reviewer. Provide constructive feedback.",
)
return response.content
```
### Streaming Chat
```python
async def chat_stream(user_message: str):
async for chunk in client.stream(
user_prompt=user_message,
system_prompt="You are a helpful chat assistant.",
):
yield chunk
```
## Error Handling
```python
from agentlite.provider import APIConnectionError, APITimeoutError, APIStatusError
try:
response = await client.complete(user_prompt="Hello")
except APIConnectionError:
print("Failed to connect to API")
except APITimeoutError:
print("Request timed out")
except APIStatusError as e:
print(f"API error {e.status_code}: {e.message}")
```

271
agentlite/docs/tools.md Normal file
View File

@@ -0,0 +1,271 @@
# AgentLite Tool Suite
A comprehensive tool suite for AgentLite, inspired by kimi-cli's tools, with configuration support for enabling/disabling individual tools.
## Overview
This tool suite provides:
- **File Operations**: Read, write, edit, search files
- **Shell Execution**: Execute shell commands
- **Web Access**: Fetch URLs and search the web
- **Multi-Agent**: Task delegation and subagent creation
- **Utilities**: Todo lists and thinking tools
- **Configuration**: Fine-grained control over which tools are available
## Installation
The tool suite is included with AgentLite. No additional installation required.
## Quick Start
```python
from agentlite.tools import ConfigurableToolset, ToolSuiteConfig
from agentlite import Agent, OpenAIProvider
# Create toolset with default config (all tools enabled)
toolset = ConfigurableToolset()
# Create agent with tools
provider = OpenAIProvider(api_key="your-key", model="gpt-4")
agent = Agent(
provider=provider,
system_prompt="You are a helpful assistant.",
tools=toolset.tools,
)
```
## Configuration
### Basic Configuration
```python
from agentlite.tools import (
ToolSuiteConfig,
FileToolsConfig,
ShellToolsConfig,
)
# Disable specific tools
config = ToolSuiteConfig(
file_tools=FileToolsConfig(
tools={"WriteFile": False, "StrReplaceFile": False}
)
)
toolset = ConfigurableToolset(config)
```
### Disable Entire Tool Groups
```python
# Disable all shell tools
config = ToolSuiteConfig(
shell_tools=ShellToolsConfig(enabled=False)
)
toolset = ConfigurableToolset(config)
```
### Custom Tool Settings
```python
config = ToolSuiteConfig(
file_tools=FileToolsConfig(
max_lines=500,
max_bytes=50 * 1024, # 50KB
allow_write_outside_work_dir=False,
),
shell_tools=ShellToolsConfig(
timeout=60,
blocked_commands=["rm -rf", "sudo"],
),
)
```
### Dynamic Configuration
```python
# Create toolset
config = ToolSuiteConfig()
toolset = ConfigurableToolset(config)
# Disable tools and reload
config.file_tools.disable_tool("WriteFile")
config.shell_tools.enabled = False
toolset.reload()
```
## Available Tools
### File Tools
| Tool | Description | Config Options |
|------|-------------|----------------|
| `ReadFile` | Read text files with line numbers | `max_lines`, `max_bytes` |
| `WriteFile` | Write or append to files | `allow_write_outside_work_dir` |
| `StrReplaceFile` | Edit files using string replacement | `allow_write_outside_work_dir` |
| `Glob` | Search files using glob patterns | `max_glob_matches` |
| `Grep` | Search file contents with regex | - |
| `ReadMediaFile` | Read images and videos | `max_size_mb` |
### Shell Tools
| Tool | Description | Config Options |
|------|-------------|----------------|
| `Shell` | Execute shell commands | `timeout`, `blocked_commands` |
### Web Tools
| Tool | Description | Config Options |
|------|-------------|----------------|
| `FetchURL` | Fetch web page content | `timeout`, `user_agent` |
| `SearchWeb` | Search the web | `timeout` |
### Multi-Agent Tools
| Tool | Description | Config Options |
|------|-------------|----------------|
| `Task` | Delegate tasks to subagents | `max_steps` |
| `CreateSubagent` | Create custom subagents | - |
### Utility Tools
| Tool | Description |
|------|-------------|
| `SetTodoList` | Manage todo lists |
| `Think` | Record thinking steps |
## Safety Features
### Path Security
- Files outside the working directory require absolute paths
- Optional restriction on writing outside working directory
- Path traversal protection
### Shell Security
- Configurable command timeout
- Blocked command list
- No shell injection (uses `execve` style execution)
### Resource Limits
- File size limits
- Line count limits
- Glob match limits
- HTTP content size limits
## Examples
### Safe Configuration for Untrusted Agents
```python
from agentlite.tools import ToolSuiteConfig, FileToolsConfig, ShellToolsConfig
# Safe config - read-only file access, no shell
safe_config = ToolSuiteConfig(
file_tools=FileToolsConfig(
allow_write_outside_work_dir=False,
),
shell_tools=ShellToolsConfig(enabled=False),
)
toolset = ConfigurableToolset(safe_config)
```
### Using Individual Tools
```python
from agentlite.tools.file import ReadFile, Glob
from pathlib import Path
# Create tools directly
read_tool = ReadFile(work_dir=Path("."))
glob_tool = Glob(work_dir=Path("."))
# Use tools
result = await read_tool.read({"path": "README.md"})
if not result.is_error:
print(result.output)
result = await glob_tool.glob({"pattern": "*.py"})
if not result.is_error:
print(result.output)
```
### Configuration from File
```python
import json
from agentlite.tools import ToolSuiteConfig
# Load config from file
with open("tool_config.json") as f:
config_dict = json.load(f)
config = ToolSuiteConfig.model_validate(config_dict)
toolset = ConfigurableToolset(config)
```
## API Reference
### Config Classes
#### `ToolSuiteConfig`
Main configuration class for all tools.
```python
class ToolSuiteConfig(BaseModel):
file_tools: FileToolsConfig
shell_tools: ShellToolsConfig
web_tools: WebToolsConfig
multiagent_tools: MultiAgentToolsConfig
misc_tools: ToolGroupConfig
```
#### `FileToolsConfig`
```python
class FileToolsConfig(ToolGroupConfig):
max_lines: int = 1000
max_line_length: int = 2000
max_bytes: int = 100 * 1024
allow_write_outside_work_dir: bool = False
max_glob_matches: int = 1000
```
#### `ShellToolsConfig`
```python
class ShellToolsConfig(ToolGroupConfig):
timeout: int = 60
max_timeout: int = 300
blocked_commands: list[str] = []
```
#### `WebToolsConfig`
```python
class WebToolsConfig(ToolGroupConfig):
timeout: int = 30
user_agent: str = "Mozilla/5.0 ..."
max_content_length: int = 1024 * 1024
```
### ConfigurableToolset
```python
class ConfigurableToolset(SimpleToolset):
def __init__(
self,
config: ToolSuiteConfig | None = None,
work_dir: str | None = None,
)
def reload(self, config: ToolSuiteConfig | None = None) -> None
```
## License
MIT License - same as AgentLite.

View File

@@ -0,0 +1,80 @@
# AgentLite Examples
This directory contains examples demonstrating various features of AgentLite.
## Setup
Before running the examples, set your OpenAI API key:
```bash
export OPENAI_API_KEY="sk-..."
```
Or create a `.env` file:
```
OPENAI_API_KEY=sk-...
```
## Examples
### 1. Single Agent (`single_agent.py`)
Basic usage of a single agent with conversation history.
```bash
python examples/single_agent.py
```
### 2. Multi-Agent (`multi_agent.py`)
Multiple specialized agents working together on a task.
```bash
python examples/multi_agent.py
```
### 3. Custom Tools (`custom_tools.py`)
Defining and using custom tools with agents.
```bash
python examples/custom_tools.py
```
### 4. MCP Tools (`mcp_tools.py`)
Using tools from MCP (Model Context Protocol) servers.
**Prerequisites:**
- Node.js installed
- MCP filesystem server: `npm install -g @modelcontextprotocol/server-filesystem`
```bash
python examples/mcp_tools.py
```
## Creating Your Own
Use these examples as templates for your own applications:
```python
import asyncio
from agentlite import Agent, OpenAIProvider
async def main():
provider = OpenAIProvider(
api_key="your-api-key",
model="gpt-4",
)
agent = Agent(
provider=provider,
system_prompt="Your system prompt here.",
)
response = await agent.run("Your question here")
print(response)
asyncio.run(main())
```

View File

@@ -0,0 +1,118 @@
"""Example: Custom Tools
This example demonstrates how to define and use custom tools with agents.
"""
import asyncio
import os
from datetime import datetime
from pydantic import BaseModel
from agentlite import Agent, OpenAIProvider, tool
from agentlite.tool import CallableTool2, ToolOk
# Define a tool using the decorator
@tool()
async def get_current_time() -> str:
    """Return the current local date and time as an ISO 8601 string."""
    now = datetime.now()
    return now.isoformat()
@tool()
async def calculate(expression: str) -> str:
    """Evaluate a mathematical expression.

    Args:
        expression: The mathematical expression to evaluate (e.g., "2 + 2").
    """
    # Expose only a small whitelist of numeric helpers; builtins are stripped
    # so names such as __import__ or open are not reachable from the expression.
    # NOTE(review): a restricted eval is still not a hard sandbox — do not feed
    # it untrusted input in production code.
    whitelist = {
        "abs": abs,
        "max": max,
        "min": min,
        "pow": pow,
        "round": round,
    }
    try:
        value = eval(expression, {"__builtins__": {}}, whitelist)
    except Exception as exc:
        return f"Error: {exc}"
    return str(value)
# Define a tool using CallableTool2 (type-safe)
class WeatherParams(BaseModel):
    """Parameters for weather tool."""
    # City name, matched case-sensitively against the mock data in GetWeather.
    city: str
    # Temperature units: "celsius" (default) or "fahrenheit".
    units: str = "celsius"
class GetWeather(CallableTool2[WeatherParams]):
    """Get weather information for a city."""

    name = "get_weather"
    description = "Get the current weather for a city."
    params = WeatherParams

    async def __call__(self, params: WeatherParams) -> ToolOk:
        # Mock lookup table; a real implementation would call a weather API.
        mock_forecast = {
            "Beijing": {"temp": 22, "condition": "Sunny"},
            "Shanghai": {"temp": 25, "condition": "Cloudy"},
            "New York": {"temp": 18, "condition": "Rainy"},
            "London": {"temp": 15, "condition": "Overcast"},
        }
        city = params.city
        entry = mock_forecast.get(city)
        if entry is None:
            return ToolOk(output=f"Weather data not available for {city}")
        temp = entry["temp"]
        if params.units == "fahrenheit":
            # NOTE(review): floor division truncates the conversion
            # (22 C -> 71 F rather than 71.6 F) — presumably intentional so the
            # mock output stays an integer; confirm if precision matters.
            temp = temp * 9 // 5 + 32
        return ToolOk(
            output=f"Weather in {city}: {entry['condition']}, {temp}°{params.units[0].upper()}"
        )
async def main():
    """Run the custom tools example."""
    # Provider reads the API key from the environment (placeholder fallback).
    provider = OpenAIProvider(
        api_key=os.getenv("OPENAI_API_KEY", "your-api-key"),
        model="gpt-4o-mini",
    )
    # Wire all three custom tools into a single agent.
    agent = Agent(
        provider=provider,
        system_prompt="You are a helpful assistant with access to tools.",
        tools=[get_current_time, calculate, GetWeather()],
    )
    print("=== Testing Tools ===\n")
    # Each question exercises a different tool.
    for question in (
        "What time is it?",
        "What is 123 * 456?",
        "What's the weather in Beijing?",
    ):
        print(f"User: {question}")
        answer = await agent.run(question)
        print(f"Agent: {answer}\n")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,124 @@
"""Example demonstrating LLMClient usage.
This example shows how to use LLMClient for simple LLM calls
without the overhead of an Agent.
"""
import asyncio
from agentlite import LLMClient
from agentlite.config import AgentConfig, ProviderConfig, ModelConfig
async def main():
    """Run LLM client examples.

    Each section prints what it demonstrates; the actual API calls are left
    commented out because they require a valid OpenAI API key.
    """
    # Example 1: Using simple function interface
    print("=== Example 1: Simple Function ===")
    print("Using llm_complete() function:")
    # Note: This requires a valid API key
    # response = await llm_complete(
    #     user_prompt="What is Python?",
    #     api_key="your-api-key",
    #     model="gpt-4",
    # )
    # print(response)
    print("(Requires API key - uncomment to run)")
    # Example 2: Using configuration-based client
    print("\n=== Example 2: Configuration-Based Client ===")
    config = AgentConfig(
        name="simple_llm",
        system_prompt="You are a helpful coding assistant.",
        providers={
            "openai": ProviderConfig(
                type="openai",
                api_key="your-api-key",  # Replace with actual key
            )
        },
        models={
            "gpt4": ModelConfig(
                provider="openai",
                model="gpt-4",
                temperature=0.7,
            ),
            "gpt35": ModelConfig(
                provider="openai",
                model="gpt-3.5-turbo",
                temperature=0.5,
            ),
        },
        default_model="gpt4",
    )
    # Bind the client to a name so the commented example below works when
    # uncommented (previously the instance was created and then discarded,
    # which would have raised NameError on `client.complete`).
    client = LLMClient(config)
    # response = await client.complete(
    #     user_prompt="Explain async/await in Python",
    # )
    # print(f"Response: {response.content}")
    # print(f"Model: {response.model}")
    # if response.usage:
    #     print(f"Tokens: {response.usage.total}")
    print("(Requires API key - uncomment to run)")
    # Example 3: Streaming
    print("\n=== Example 3: Streaming ===")
    print("Using llm_stream() function:")
    # async for chunk in llm_stream(
    #     user_prompt="Write a haiku about programming",
    #     api_key="your-api-key",
    # ):
    #     print(chunk, end="")
    print("\n(Requires API key - uncomment to run)")
    # Example 4: Direct provider usage
    print("\n=== Example 4: Direct Provider ===")
    from agentlite import OpenAIProvider  # local import: only this example needs it

    provider = OpenAIProvider(
        api_key="your-api-key",
        model="gpt-4",
        temperature=0.8,
    )
    # Bind this client too, for the same reason as in Example 2.
    client = LLMClient(provider=provider)
    # response = await client.complete(
    #     user_prompt="What are the benefits of type hints?",
    #     system_prompt="You are a Python expert.",
    # )
    # print(response.content)
    print("(Requires API key - uncomment to run)")
    # Example 5: Model switching
    print("\n=== Example 5: Model Switching ===")
    # Use default model (gpt4)
    # response1 = await client.complete(user_prompt="Hello!")
    # Switch to different model
    # client_gpt35 = LLMClient(config, model="gpt35")
    # response2 = await client_gpt35.complete(user_prompt="Hello!")
    print("(Requires API key - uncomment to run)")
    print("\n=== Examples Complete ===")
    print("To run these examples:")
    print("1. Set your OpenAI API key")
    print("2. Uncomment the example code")
    print("3. Run: python examples/llm_client_example.py")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,79 @@
"""Example: MCP Tools
This example demonstrates how to use MCP (Model Context Protocol) tools
with AgentLite agents.
Note: This example requires an MCP server to be available.
"""
import asyncio
import os
from agentlite import Agent, MCPClient, OpenAIProvider
async def main():
    """Run the MCP tools example."""
    # Provider reads the API key from the environment (placeholder fallback).
    provider = OpenAIProvider(
        api_key=os.getenv("OPENAI_API_KEY", "your-api-key"),
        model="gpt-4o-mini",
    )
    # Connect to the filesystem MCP server over stdio.
    # Install it with: npm install -g @modelcontextprotocol/server-filesystem
    print("Connecting to MCP server...")
    async with MCPClient() as mcp:
        await mcp.connect_stdio(
            command="npx",
            args=["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
        )
        # Expose every tool the server advertises to the agent.
        print("Loading MCP tools...")
        mcp_tools = await mcp.load_tools()
        print(f"Loaded {len(mcp_tools)} tools from MCP server")
        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant with access to filesystem tools.",
            tools=mcp_tools,
        )
        print("\n=== Testing MCP Tools ===\n")
        # (displayed line, actual prompt) pairs — the second pair echoes a
        # shortened form of the prompt, as in the original script.
        exchanges = (
            ("List files in /tmp", "List files in /tmp"),
            (
                "Create a file called test.txt with 'Hello from AgentLite!'",
                "Create a file called test.txt with content 'Hello from AgentLite!'",
            ),
            ("Read the test.txt file", "Read the test.txt file"),
        )
        for shown, prompt in exchanges:
            print(f"User: {shown}")
            reply = await agent.run(prompt)
            print(f"Agent: {reply}\n")


if __name__ == "__main__":
    # This example needs Node.js plus the MCP filesystem server on PATH.
    print("Note: This example requires Node.js and @modelcontextprotocol/server-filesystem")
    print("Install with: npm install -g @modelcontextprotocol/server-filesystem\n")
    try:
        asyncio.run(main())
    except Exception as exc:  # broad by design: surface setup problems with guidance
        print(f"Error: {exc}")
        print("\nMake sure you have:")
        print("1. Node.js installed")
        print("2. @modelcontextprotocol/server-filesystem installed globally")
        print("3. OPENAI_API_KEY environment variable set")

View File

@@ -0,0 +1,54 @@
"""Example: Multi-Agent Usage
This example demonstrates using multiple agents working independently.
"""
import asyncio
import os
from agentlite import Agent, OpenAIProvider
async def main():
    """Run the multi-agent example."""
    # One shared provider; each agent differs only in its system prompt.
    provider = OpenAIProvider(
        api_key=os.getenv("OPENAI_API_KEY", "your-api-key"),
        model="gpt-4o-mini",
    )

    def specialist(prompt: str) -> Agent:
        # Small factory so each persona is a one-liner below.
        return Agent(provider=provider, system_prompt=prompt)

    researcher = specialist(
        "You are a research assistant. Provide factual, well-researched information."
    )
    writer = specialist("You are a creative writer. Write engaging and clear content.")
    critic = specialist(
        "You are an editor. Review and improve content for clarity and accuracy."
    )
    topic = "artificial intelligence in healthcare"
    # Phase 1: gather facts.
    print("=== Research Phase ===")
    research = await researcher.run(f"Research {topic}. Provide key points.")
    print(f"Research:\n{research}\n")
    # Phase 2: draft a post from the research notes.
    print("=== Writing Phase ===")
    content = await writer.run(
        f"Write a blog post about {topic} using this research:\n{research}"
    )
    print(f"Draft:\n{content}\n")
    # Phase 3: editorial review of the draft.
    print("=== Review Phase ===")
    review = await critic.run(f"Review this blog post and suggest improvements:\n{content}")
    print(f"Review:\n{review}\n")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,42 @@
"""Example: Single Agent Usage
This example demonstrates basic usage of the AgentLite Agent class.
"""
import asyncio
import os
from agentlite import Agent, OpenAIProvider
async def main():
    """Run the single agent example."""
    # Provider reads the API key from the environment (placeholder fallback).
    provider = OpenAIProvider(
        api_key=os.getenv("OPENAI_API_KEY", "your-api-key"),
        model="gpt-4o-mini",
    )
    agent = Agent(
        provider=provider,
        system_prompt="You are a helpful assistant. Be concise.",
    )
    # Two turns in the same conversation; the agent keeps history between runs.
    for question in ("What is Python?", "What are its main features?"):
        print(f"User: {question}")
        reply = await agent.run(question)
        print(f"Agent: {reply}\n")
    # Dump a truncated view of the accumulated conversation history.
    print("--- Conversation History ---")
    for msg in agent.history:
        snippet = msg.extract_text()[:100]
        print(f"{msg.role}: {snippet}...")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,68 @@
---
name: code-reviewer
description: Review code for bugs, style issues, security vulnerabilities, and best practices. Use when the user asks to review, check, or audit code.
type: standard
---
# Code Reviewer
A comprehensive code review skill that checks for common issues and provides actionable feedback.
## Review Checklist
### 1. Correctness
- Check for logical errors
- Verify edge cases are handled
- Look for off-by-one errors
- Check null/None handling
- Verify error handling paths
### 2. Style & Readability
- Naming conventions (clear, descriptive names)
- Code organization and structure
- Comments where needed (not obvious code)
- Consistent formatting
- Function/class length
### 3. Performance
- Inefficient algorithms (O(n²) when O(n) possible)
- Unnecessary object creation
- Memory leaks
- Redundant operations
### 4. Security
- SQL injection vulnerabilities
- XSS vulnerabilities (for web code)
- Hardcoded secrets/passwords
- Unsafe deserialization
- Path traversal risks
### 5. Best Practices
- DRY principle (Don't Repeat Yourself)
- SOLID principles
- Proper use of language features
- Test coverage considerations
## Output Format
Provide your review in this structure:
```
## Summary
Brief overall assessment
## Critical Issues
- Issue 1: Description and fix
- Issue 2: Description and fix
## Warnings
- Warning 1: Description and suggestion
## Suggestions
- Suggestion 1: How to improve
## Positive Notes
- What's done well
```
Be constructive and specific. Include code examples for suggested fixes.

View File

@@ -0,0 +1,63 @@
---
name: release-process
description: Execute the release workflow including version checks, changelog updates, and PR creation. Use when the user wants to create a new release or version.
type: flow
---
# Release Process
Follow this structured workflow to create a new release.
## Flow
```mermaid
flowchart TD
BEGIN(( )) --> CHECK[Check for uncommitted changes]
CHECK --> CHANGES{Changes?}
CHANGES -->|Yes| COMMIT[Commit or stash changes]
CHANGES -->|No| VERSION{Version type?}
COMMIT --> VERSION
VERSION -->|Patch| UPDATE_PATCH[Update patch version]
VERSION -->|Minor| UPDATE_MINOR[Update minor version]
VERSION -->|Major| UPDATE_MAJOR[Update major version]
UPDATE_PATCH --> CHANGELOG[Update CHANGELOG.md]
UPDATE_MINOR --> CHANGELOG
UPDATE_MAJOR --> CHANGELOG
CHANGELOG --> BRANCH[Create release branch]
BRANCH --> PR[Create Pull Request]
PR --> END(( ))
```
## Node Details
### Check for uncommitted changes
Run `git status` and check if there are any uncommitted changes.
### Commit or stash changes
Ask the user whether to commit the changes or stash them for later.
### Version type
Ask the user what type of release this is:
- **Patch**: Bug fixes (0.0.X)
- **Minor**: New features, backward compatible (0.X.0)
- **Major**: Breaking changes (X.0.0)
### Update version
Update the version number in:
- `pyproject.toml` or `package.json`
- Any other version files
### Update CHANGELOG
Add a new section to CHANGELOG.md with:
- Version number and date
- List of changes
- Breaking changes (if any)
- Migration notes (if needed)
### Create release branch
Create a new branch: `release/vX.Y.Z`
### Create Pull Request
Open a PR with:
- Title: "Release vX.Y.Z"
- Description summarizing the changes

View File

@@ -0,0 +1,86 @@
"""Example demonstrating the skills system for AgentLite.
This example shows how to use skills with an Agent.
"""
import asyncio
from pathlib import Path
from agentlite.skills import discover_skills, index_skills_by_name
async def main():
    """Run skills example."""
    banner = "=" * 60
    print(banner)
    print("AgentLite Skills Example")
    print(banner)
    # Discover skills shipped in the directory next to this script.
    skills_dir = Path(__file__).parent / "skills"
    discovered = discover_skills(skills_dir)
    print(f"\nDiscovered {len(discovered)} skill(s):")
    for entry in discovered:
        print(f"  - {entry.name}: {entry.description}")
        print(f"    Type: {entry.type}")
        if entry.flow:
            print(f"    Flow nodes: {len(entry.flow.nodes)}")
    # Build a name -> skill lookup table.
    by_name = index_skills_by_name(discovered)
    print(f"\nIndexed {len(by_name)} skill(s)")
    # Show (but do not run) the agent wiring, which needs an API key.
    print("\n" + "-" * 40)
    print("To use skills with an agent:")
    print("-" * 40)
    code = """
# Create provider
provider = OpenAIProvider(api_key="your-key", model="gpt-4")
# Create agent
agent = Agent(
    provider=provider,
    system_prompt="You are a helpful assistant with access to skills.",
)
# Create skill tool
skill_tool = SkillTool(skill_index, parent_agent=agent)
# Add skill tool to agent
agent.tools.add(skill_tool)
# Now the agent can use skills!
# The agent will see available skills in its context
# Example usage:
response = await agent.run("Review this Python code: def add(a, b): return a + b")
# The agent may choose to use the code-reviewer skill
"""
    print(code)
    print("\n" + banner)
    print("Key Concepts:")
    print(banner)
    for note in (
        "1. Skills are defined in SKILL.md files",
        "2. YAML frontmatter specifies name, description, and type",
        "3. Standard skills load the markdown as a prompt",
        "4. Flow skills execute a structured flowchart",
        "5. Skills are discovered from directories",
        "6. SkillTool allows agents to execute skills",
    ):
        print(note)
    print("\nSkill Format (SKILL.md):")
    print("""---
name: skill-name
description: When to use this skill...
type: standard | flow
---
# Skill Content
Instructions for the skill...
""")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,168 @@
"""Example demonstrating subagent usage in AgentLite.
This example shows how to create a parent agent with subagents
and delegate tasks to them using the Task tool.
"""
import asyncio
from agentlite import Agent, OpenAIProvider
from agentlite.tools.multiagent.task import Task
async def main():
    """Run subagent example.

    Demonstrates manual subagent registration, task delegation, hierarchies,
    dynamic subagents, discovery, and subagent copies.

    BUG FIX: the parent agents are now created with ``allow_subagents=True``.
    ``Agent.add_subagent`` raises ``RuntimeError`` when subagent delegation is
    disabled (the default), so the previous version of this example crashed on
    its first ``add_subagent`` call.
    """
    print("=" * 60)
    print("AgentLite Subagent Example")
    print("=" * 60)

    # Note: This example requires a valid API key
    # Replace with your actual API key to run
    api_key = "your-api-key"
    if api_key == "your-api-key":
        print("\nNOTE: Set your API key to run this example")
        print("Example code is shown below:\n")
        print("-" * 40)

    # Create provider
    provider = OpenAIProvider(api_key=api_key, model="gpt-4")

    # Example 1: Create subagents manually
    print("\n=== Example 1: Manual Subagent Setup ===")

    # Create parent agent with empty labor market.
    # allow_subagents=True is required so add_subagent() below does not raise.
    parent = Agent(
        provider=provider,
        system_prompt="You are a coordinator agent that delegates tasks to specialists.",
        name="coordinator",
        allow_subagents=True,
    )

    # Create subagents
    coder = Agent(
        provider=provider,
        system_prompt="You are a coding specialist. Write clean, well-documented code.",
        name="coder",
    )
    reviewer = Agent(
        provider=provider,
        system_prompt="You are a code reviewer. Provide constructive feedback.",
        name="reviewer",
    )

    # Register subagents with parent
    parent.add_subagent("coder", coder, "Writes code", dynamic=False)
    parent.add_subagent("reviewer", reviewer, "Reviews code", dynamic=False)

    # Add Task tool to parent
    parent.tools.add(Task(labor_market=parent.labor_market))

    print("Created parent agent with subagents:")
    print("  - coder: Writes code")
    print("  - reviewer: Reviews code")

    # Example 2: Using subagents
    print("\n=== Example 2: Delegating Tasks ===")

    # Parent agent delegates to coder
    # response = await parent.run(
    #     "I need a Python function to calculate fibonacci numbers. "
    #     "Use the coder subagent to write it."
    # )
    print("(Requires API key - uncomment to run)")

    # Example 3: Nested subagents (hierarchy)
    print("\n=== Example 3: Hierarchical Structure ===")

    # Create a team lead with team members as subagents.
    # allow_subagents=True is again required for the add_subagent calls below.
    team_lead = Agent(
        provider=provider,
        system_prompt="You are a team lead. Coordinate work among your team members.",
        name="team_lead",
        allow_subagents=True,
    )

    # Create team members
    backend_dev = Agent(
        provider=provider,
        system_prompt="You are a backend developer. Focus on API design and database.",
        name="backend_dev",
    )
    frontend_dev = Agent(
        provider=provider,
        system_prompt="You are a frontend developer. Focus on UI/UX.",
        name="frontend_dev",
    )
    tester = Agent(
        provider=provider,
        system_prompt="You are a QA engineer. Write test cases and find bugs.",
        name="tester",
    )

    # Add subagents to team lead
    team_lead.add_subagent("backend", backend_dev, "Backend development")
    team_lead.add_subagent("frontend", frontend_dev, "Frontend development")
    team_lead.add_subagent("qa", tester, "Quality assurance")

    # Add Task tool
    team_lead.tools.add(Task(labor_market=team_lead.labor_market))

    print("Created team hierarchy:")
    print("  team_lead/")
    print("  ├── backend: Backend development")
    print("  ├── frontend: Frontend development")
    print("  └── qa: Quality assurance")

    # Example 4: Dynamic subagents
    print("\n=== Example 4: Dynamic Subagents ===")

    # Create subagent dynamically
    specialist = Agent(
        provider=provider,
        system_prompt="You are a specialist for a specific task.",
        name="specialist",
    )

    # Add as dynamic subagent
    team_lead.add_subagent("specialist", specialist, "Temporary specialist", dynamic=True)

    print("Added dynamic subagent 'specialist' to team_lead")

    # Example 5: Agent discovery
    print("\n=== Example 5: Agent Discovery ===")

    print(f"Team lead's subagents: {team_lead.labor_market.list_subagents()}")
    print(f"Descriptions: {team_lead.labor_market.subagent_descriptions}")

    # Check if subagent exists
    if "backend" in team_lead.labor_market:
        print("Backend subagent is available")

    # Get specific subagent
    backend = team_lead.get_subagent("backend")
    print(f"Backend agent name: {backend.name if backend else 'not found'}")

    # Example 6: Create subagent copy
    print("\n=== Example 6: Subagent Copy ===")

    # Create a copy of parent for use as subagent elsewhere
    parent_copy = parent.create_subagent_copy()
    print(f"Created copy of parent: {parent_copy.name}")
    print(f"Copy has empty labor market: {len(parent_copy.labor_market) == 0}")

    print("\n" + "=" * 60)
    print("Examples Complete")
    print("=" * 60)
    print("\nKey Concepts:")
    print("1. Parent agent holds subagents in LaborMarket")
    print("2. Task tool allows parent to delegate to subagents")
    print("3. Subagents have independent history and context")
    print("4. Fixed subagents are defined at setup")
    print("5. Dynamic subagents can be added at runtime")
# Entry point: run the async demo when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,130 @@
"""Example demonstrating the configurable tool suite for AgentLite.
This example shows how to use the tool suite with configuration
to enable/disable specific tools.
"""
import asyncio
from pathlib import Path
from agentlite.tools import (
ConfigurableToolset,
ToolSuiteConfig,
FileToolsConfig,
ShellToolsConfig,
)
async def main():
    """Demonstrate the configurable tool suite.

    Walks through seven examples: default configuration, disabling individual
    tools, disabling tool groups, custom settings, agent wiring, dynamic
    reload, and direct tool usage.

    BUG FIX: Example 5 previously constructed a ``ToolSuiteConfig`` and
    discarded it, while the commented-out agent wiring referenced an undefined
    ``safe_config`` name. The config is now bound to ``safe_config``.
    """
    # Example 1: Default configuration (all tools enabled)
    print("=== Example 1: Default Configuration ===")
    config = ToolSuiteConfig()
    toolset = ConfigurableToolset(config)
    print(f"Enabled tools: {len(toolset.tools)}")
    for tool in toolset.tools:
        print(f"  - {tool.name}")

    # Example 2: Disable specific tools
    print("\n=== Example 2: Disable WriteFile ===")
    config = ToolSuiteConfig(
        file_tools=FileToolsConfig(
            tools={"WriteFile": False}  # Disable WriteFile
        )
    )
    toolset = ConfigurableToolset(config)
    print(f"Enabled tools: {len(toolset.tools)}")
    for tool in toolset.tools:
        print(f"  - {tool.name}")

    # Example 3: Disable entire tool groups
    print("\n=== Example 3: Disable Shell Tools ===")
    config = ToolSuiteConfig(shell_tools=ShellToolsConfig(enabled=False))
    toolset = ConfigurableToolset(config)
    print(f"Enabled tools: {len(toolset.tools)}")
    for tool in toolset.tools:
        print(f"  - {tool.name}")

    # Example 4: Custom file tool settings
    print("\n=== Example 4: Custom File Tool Settings ===")
    config = ToolSuiteConfig(
        file_tools=FileToolsConfig(
            max_lines=500,
            max_bytes=50 * 1024,  # 50KB
            allow_write_outside_work_dir=True,
        )
    )
    toolset = ConfigurableToolset(config)
    print("File tool settings:")
    print(f"  Max lines: {config.file_tools.max_lines}")
    print(f"  Max bytes: {config.file_tools.max_bytes}")
    print(f"  Allow outside work dir: {config.file_tools.allow_write_outside_work_dir}")

    # Example 5: Using with an Agent
    print("\n=== Example 5: Using with Agent ===")
    # Create a safe configuration (no shell, no write outside work dir).
    # Bound to a name so the commented-out Agent wiring below can use it.
    safe_config = ToolSuiteConfig(
        file_tools=FileToolsConfig(
            allow_write_outside_work_dir=False,
        ),
        shell_tools=ShellToolsConfig(enabled=False),
    )
    # This would require an API key to actually run
    # provider = OpenAIProvider(api_key="your-api-key", model="gpt-4")
    # agent = Agent(
    #     provider=provider,
    #     system_prompt="You are a helpful assistant with file access.",
    #     tools=ConfigurableToolset(safe_config).tools,
    # )
    print("Safe configuration created:")
    print("  - Shell tools: DISABLED")
    print("  - Write outside work dir: DISABLED")
    print("  - Read file: ENABLED")
    print("  - Glob/Grep: ENABLED")

    # Example 6: Dynamic configuration reload
    print("\n=== Example 6: Dynamic Reload ===")
    config = ToolSuiteConfig()
    toolset = ConfigurableToolset(config)
    print(f"Initial tools: {len(toolset.tools)}")

    # Disable some tools and reload
    config.file_tools.disable_tool("WriteFile")
    config.shell_tools.enabled = False
    toolset.reload()
    print(f"After reload: {len(toolset.tools)}")
    for tool in toolset.tools:
        print(f"  - {tool.name}")

    # Example 7: Using individual tools directly
    print("\n=== Example 7: Direct Tool Usage ===")
    from agentlite.tools.file import ReadFile, Glob

    # Create tools directly
    read_tool = ReadFile(work_dir=Path("."))
    glob_tool = Glob(work_dir=Path("."))

    # Use ReadFile
    result = await read_tool.read({"path": "README.md"})
    if not result.is_error:
        print(f"README.md: {len(result.output)} characters")
    else:
        print(f"Could not read README.md: {result.message}")

    # Use Glob
    result = await glob_tool.glob({"pattern": "*.py"})
    if not result.is_error:
        files = result.output.split("\n") if result.output else []
        print(f"Python files found: {len(files)}")
    else:
        print(f"Glob error: {result.message}")
# Entry point: run the async demo when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,116 @@
"""AgentLite - A lightweight, async-first Agent component library.
AgentLite provides clean abstractions for building LLM-powered agents with
OpenAI-compatible APIs, supporting tools (including MCP), streaming, and
multi-agent usage.
Example:
>>> import asyncio
>>> from agentlite import Agent, OpenAIProvider
>>>
>>> async def main():
... provider = OpenAIProvider(api_key="sk-...", model="gpt-4")
... agent = Agent(provider=provider, system_prompt="You are helpful.")
... response = await agent.run("Hello!")
... print(response)
>>>
>>> asyncio.run(main())
"""
__version__ = "0.1.0"
# Core types
from agentlite.message import (
ContentPart,
Message,
Role,
TextPart,
ImageURLPart,
AudioURLPart,
ToolCall,
ToolCallPart,
)
from agentlite.tool import (
Tool,
ToolResult,
ToolOk,
ToolError,
CallableTool,
CallableTool2,
SimpleToolset,
tool,
)
from agentlite.provider import (
ChatProvider,
StreamedMessage,
TokenUsage,
ChatProviderError,
APIConnectionError,
APITimeoutError,
APIStatusError,
)
# Configuration
from agentlite.config import (
ProviderConfig,
ModelConfig,
AgentConfig,
)
# Agent
from agentlite.agent import Agent
# MCP
from agentlite.mcp import MCPClient
# OpenAI Provider
from agentlite.providers.openai import OpenAIProvider
# LLM Client
from agentlite.llm_client import LLMClient, LLMResponse, llm_complete, llm_stream
__all__ = [
# Version
"__version__",
# Message types
"ContentPart",
"Message",
"Role",
"TextPart",
"ImageURLPart",
"AudioURLPart",
"ToolCall",
"ToolCallPart",
# Tool types
"Tool",
"ToolResult",
"ToolOk",
"ToolError",
"CallableTool",
"CallableTool2",
"SimpleToolset",
"tool",
# Provider types
"ChatProvider",
"StreamedMessage",
"TokenUsage",
"ChatProviderError",
"APIConnectionError",
"APITimeoutError",
"APIStatusError",
# Configuration
"ProviderConfig",
"ModelConfig",
"AgentConfig",
# Agent
"Agent",
# MCP
"MCPClient",
# Providers
"OpenAIProvider",
# LLM Client
"LLMClient",
"LLMResponse",
"llm_complete",
"llm_stream",
]

View File

@@ -0,0 +1,452 @@
"""Main Agent class for AgentLite.
This module provides the core Agent class that orchestrates LLM interactions,
tool calling, and conversation management.
"""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Sequence
from typing import TYPE_CHECKING
from agentlite.message import (
ContentPart,
Message,
TextPart,
ToolCall,
ToolCallPart,
)
from agentlite.provider import ChatProvider
from agentlite.tool import SimpleToolset, ToolResult, ToolType
from agentlite.labor_market import LaborMarket
if TYPE_CHECKING:
pass
class Agent:
    """An LLM agent that can use tools and maintain conversation history.

    The Agent class is the main interface for interacting with LLMs. It handles:
    - Sending messages to the LLM
    - Managing tool calls and execution
    - Maintaining conversation history
    - Streaming responses

    Attributes:
        provider: The LLM provider to use.
        system_prompt: The system prompt for the agent.
        tools: The toolset containing available tools.
        history: The conversation history (copy; see the property).

    Example:
        >>> provider = OpenAIProvider(api_key="sk-...", model="gpt-4")
        >>> agent = Agent(
        ...     provider=provider,
        ...     system_prompt="You are a helpful assistant.",
        ... )
        >>> response = await agent.run("Hello!")
        >>> print(response)
    """

    def __init__(
        self,
        provider: ChatProvider,
        system_prompt: str = "You are a helpful assistant.",
        tools: Sequence[ToolType] | None = None,
        max_iterations: int = 80,
        labor_market: LaborMarket | None = None,
        name: str = "agent",
        allow_subagents: bool = False,
    ):
        """Initialize the agent.

        Args:
            provider: The LLM provider to use.
            system_prompt: The system prompt for the agent.
            tools: Optional sequence of tools to make available.
            max_iterations: Maximum number of tool call iterations per request.
            labor_market: Optional LaborMarket for managing subagents.
            name: Name of the agent (for identification in subagent hierarchies).
            allow_subagents: Whether this agent is allowed to register subagents.
        """
        self.provider = provider
        self.system_prompt = system_prompt
        self.tools = SimpleToolset(tools)
        self.max_iterations = max_iterations
        self.labor_market = labor_market or LaborMarket()
        self.name = name
        self.allow_subagents = allow_subagents
        self._history: list[Message] = []

    @property
    def history(self) -> list[Message]:
        """Get the conversation history.

        Returns:
            A copy of the conversation history, so callers cannot mutate
            the agent's internal state through it.
        """
        return self._history.copy()

    def clear_history(self) -> None:
        """Clear the conversation history."""
        self._history.clear()

    def add_message(self, message: Message) -> None:
        """Add a message to the history.

        Args:
            message: The message to add.
        """
        self._history.append(message)

    async def run(
        self,
        message: str,
        *,
        stream: bool = False,
    ) -> str | AsyncIterator[str]:
        """Run the agent with a user message.

        This method sends the message to the LLM and handles any tool calls
        that the model requests. It continues the conversation until the
        model produces a final response without tool calls.

        Args:
            message: The user message.
            stream: Whether to stream the response.

        Returns:
            If stream=False: The complete response as a string.
            If stream=True: An async iterator yielding response chunks.

        Example:
            # Non-streaming
            >>> response = await agent.run("What is 2 + 2?")
            >>> print(response)

            # Streaming
            >>> async for chunk in await agent.run("Tell me a story", stream=True):
            ...     print(chunk, end="")
        """
        # Add user message to history
        self._history.append(Message(role="user", content=message))
        if stream:
            # Return the async generator unconsumed; iterating it drives the loop.
            return self._run_streaming()
        return await self._run_non_streaming()

    async def _collect_response(self, stream) -> tuple[list[ContentPart], list[ToolCall]]:
        """Drain a provider stream into content parts and merged tool calls.

        ToolCallPart deltas are merged into the most recent ToolCall; this is
        how streaming providers deliver incremental tool-call arguments.

        Args:
            stream: The async iterator returned by ``provider.generate``.

        Returns:
            Tuple of (content parts, fully merged tool calls).
        """
        response_parts: list[ContentPart] = []
        tool_calls: list[ToolCall] = []
        async for part in stream:
            if isinstance(part, ToolCall):
                tool_calls.append(part)
            elif isinstance(part, ToolCallPart):
                # A delta with no preceding ToolCall is silently dropped.
                if tool_calls:
                    tool_calls[-1].merge_in_place(part)
            elif isinstance(part, ContentPart):
                response_parts.append(part)
        return response_parts, tool_calls

    def _record_tool_results(self, tool_results: list[_ToolResult]) -> None:
        """Append tool execution results to history as role="tool" messages."""
        for result in tool_results:
            self._history.append(
                Message(
                    role="tool",
                    content=result.output,
                    tool_call_id=result.tool_call_id,
                )
            )

    def _iteration_limit_message(self, tool_calls: list[ToolCall]) -> str:
        """Build the user-facing message emitted when max_iterations is hit.

        Args:
            tool_calls: The tool calls from the final iteration (may be empty).

        Returns:
            A diagnostic message naming the last tools called, when known.
        """
        last_tools_msg = ""
        try:
            if tool_calls:
                tool_names = [tc.function.name for tc in tool_calls if hasattr(tc, "function")]
                if tool_names:
                    last_tools_msg = f" Last tools called: {', '.join(tool_names)}."
        except Exception:
            # Best effort only: diagnostics must never mask the limit message.
            pass
        return (
            f"Maximum tool call iterations reached ({self.max_iterations})."
            f"{last_tools_msg}"
            f" Consider increasing max_iterations or breaking the task into smaller steps."
        )

    async def _run_non_streaming(self) -> str:
        """Run the agent in non-streaming mode.

        Returns:
            The complete response as a string.
        """
        iterations = 0
        tool_calls: list[ToolCall] = []
        while iterations < self.max_iterations:
            iterations += 1
            # Generate response
            stream = await self.provider.generate(
                system_prompt=self.system_prompt,
                tools=self.tools.tools,
                history=self._history,
            )
            response_parts, tool_calls = await self._collect_response(stream)
            # Extract text from response
            response_text = "".join(
                part.text for part in response_parts if isinstance(part, TextPart)
            )
            # Add assistant message to history
            self._history.append(
                Message(
                    role="assistant",
                    content=response_parts,
                    tool_calls=tool_calls if tool_calls else None,
                )
            )
            # If no tool calls, we're done
            if not tool_calls:
                return response_text
            # Execute tool calls and record results for the next iteration
            tool_results = await self._execute_tool_calls(tool_calls)
            self._record_tool_results(tool_results)
        # Max iterations reached without a final answer
        return self._iteration_limit_message(tool_calls)

    async def _run_streaming(self) -> AsyncIterator[str]:
        """Run the agent in streaming mode.

        Yields:
            Response text chunks as they arrive from the provider.
        """
        iterations = 0
        tool_calls: list[ToolCall] = []
        while iterations < self.max_iterations:
            iterations += 1
            # Generate response
            stream = await self.provider.generate(
                system_prompt=self.system_prompt,
                tools=self.tools.tools,
                history=self._history,
            )
            # Collect response parts, yielding text incrementally.
            # (_collect_response cannot be reused here: text must be yielded
            # as it arrives, not after the stream is fully drained.)
            response_parts: list[ContentPart] = []
            tool_calls = []
            async for part in stream:
                if isinstance(part, ToolCall):
                    tool_calls.append(part)
                elif isinstance(part, ToolCallPart):
                    if tool_calls:
                        tool_calls[-1].merge_in_place(part)
                elif isinstance(part, ContentPart):
                    response_parts.append(part)
                    if isinstance(part, TextPart):
                        yield part.text
            # Add assistant message to history
            self._history.append(
                Message(
                    role="assistant",
                    content=response_parts,
                    tool_calls=tool_calls if tool_calls else None,
                )
            )
            # If no tool calls, we're done
            if not tool_calls:
                return
            # Execute tool calls and record results for the next iteration
            tool_results = await self._execute_tool_calls(tool_calls)
            self._record_tool_results(tool_results)
        # Max iterations reached without a final answer
        yield self._iteration_limit_message(tool_calls)

    async def _execute_tool_calls(
        self,
        tool_calls: list[ToolCall],
    ) -> list[_ToolResult]:
        """Execute a list of tool calls sequentially.

        Tool handlers may be synchronous (returning a result directly) or
        asynchronous (returning an awaitable); both are supported.

        BUG FIX: the previous implementation used ``asyncio.isfuture``, which
        is False for coroutine objects, so an async handler's coroutine was
        never awaited and ``str(<coroutine>)`` became the tool output. It also
        invoked synchronous handlers inside a list comprehension, outside the
        try/except, letting their exceptions escape this method.

        Args:
            tool_calls: The tool calls to execute.

        Returns:
            List of tool results, one per call, in call order.
        """
        results: list[_ToolResult] = []
        # Sequential execution keeps tool side effects in call order.
        for tc in tool_calls:
            try:
                outcome = self.tools.handle(tc)
                # Await coroutines and futures; use plain values as-is.
                if asyncio.iscoroutine(outcome) or asyncio.isfuture(outcome):
                    outcome = await outcome
                results.append(
                    _ToolResult(
                        tool_call_id=tc.id,
                        output=outcome.output if isinstance(outcome, ToolResult) else str(outcome),
                        is_error=outcome.is_error if isinstance(outcome, ToolResult) else False,
                    )
                )
            except Exception as e:
                # Report the failure to the model rather than aborting the run.
                results.append(
                    _ToolResult(
                        tool_call_id=tc.id,
                        output=str(e),
                        is_error=True,
                    )
                )
        return results

    async def generate(
        self,
        message: str,
    ) -> Message:
        """Generate a single response without tool calling loop.

        This method sends a message to the LLM and returns the response
        without executing any tool calls. This is useful when you want
        to handle tool calls manually.

        Args:
            message: The user message.

        Returns:
            The assistant's response message.
        """
        # Add user message to history
        self._history.append(Message(role="user", content=message))
        # Generate response
        stream = await self.provider.generate(
            system_prompt=self.system_prompt,
            tools=self.tools.tools,
            history=self._history,
        )
        response_parts, tool_calls = await self._collect_response(stream)
        # Create response message
        response = Message(
            role="assistant",
            content=response_parts,
            tool_calls=tool_calls if tool_calls else None,
        )
        # Add to history
        self._history.append(response)
        return response

    def add_subagent(
        self,
        name: str,
        agent: Agent,
        description: str,
        dynamic: bool = False,
    ) -> None:
        """Add a subagent to this agent's labor market.

        Args:
            name: Unique name for the subagent
            agent: The Agent instance to add
            description: Description of what the subagent does
            dynamic: If True, add as dynamic subagent; otherwise fixed

        Raises:
            RuntimeError: If this agent was created with allow_subagents=False.
            ValueError: If a subagent with the same name already exists.
        """
        if not self.allow_subagents:
            raise RuntimeError("Subagent delegation is disabled for this agent runtime.")
        if dynamic:
            # Dynamic subagents carry no description in the labor market.
            self.labor_market.add_dynamic_subagent(name, agent)
        else:
            self.labor_market.add_fixed_subagent(name, agent, description)

    def get_subagent(self, name: str) -> Agent | None:
        """Get a subagent by name.

        Args:
            name: Name of the subagent

        Returns:
            The subagent Agent if found, None otherwise
        """
        return self.labor_market.get_subagent(name)

    def create_subagent_copy(self) -> Agent:
        """Create a copy of this agent for use as a subagent.

        The copy will have:
        - Same provider
        - Independent history (empty)
        - Empty labor market (subagents cannot have their own subagents by default)

        Returns:
            A new Agent instance configured as a subagent
        """
        return Agent(
            provider=self.provider,
            system_prompt=self.system_prompt,
            # NOTE: relies on SimpleToolset's internal _tools mapping.
            tools=list(self.tools._tools.values()),
            max_iterations=self.max_iterations,
            labor_market=LaborMarket(),  # Empty labor market
            allow_subagents=False,
            name=f"{self.name}_sub",
        )
class _ToolResult:
    """Internal value object pairing a tool call id with its textual outcome.

    Attributes:
        tool_call_id: Identifier of the originating tool call.
        output: Text produced by the tool (or the error message).
        is_error: True when the tool execution failed.
    """

    def __init__(self, tool_call_id: str, output: str, is_error: bool):
        # Plain attribute assignment; this class carries no behavior.
        self.tool_call_id, self.output, self.is_error = tool_call_id, output, is_error

View File

@@ -0,0 +1,201 @@
"""Configuration models for AgentLite.
This module provides Pydantic-based configuration models for providers,
models, and agent settings.
"""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel, Field, SecretStr, model_validator
ProviderType = Literal["openai", "anthropic", "google", "custom"]
ModelCapability = Literal[
"streaming",
"tool_calling",
"vision",
"json_mode",
"function_calling",
]
class ProviderConfig(BaseModel):
    """Configuration for an LLM provider.

    Attributes:
        type: The provider type (openai, anthropic, etc.)
        base_url: The API base URL
        api_key: The API key (stored securely as a Pydantic SecretStr,
            so it is masked in reprs and must be read via get_secret_value())
        headers: Additional headers to include in requests
        timeout: Request timeout in seconds

    Example:
        >>> config = ProviderConfig(
        ...     type="openai",
        ...     base_url="https://api.openai.com/v1",
        ...     api_key="sk-...",
        ... )
    """

    type: ProviderType = "openai"
    base_url: str = "https://api.openai.com/v1"
    # Required field: constructing ProviderConfig without api_key raises a
    # pydantic ValidationError.
    api_key: SecretStr
    headers: dict[str, str] = Field(default_factory=dict)
    timeout: float = 60.0

    @model_validator(mode="after")
    def validate_base_url(self) -> "ProviderConfig":
        """Validate that base_url is a valid URL.

        Only checks the scheme prefix; the rest of the URL is not parsed.

        Raises:
            ValueError: If base_url lacks an http:// or https:// prefix.
        """
        if not self.base_url.startswith(("http://", "https://")):
            raise ValueError("base_url must start with http:// or https://")
        return self
class ModelConfig(BaseModel):
    """Configuration for an LLM model.

    Attributes:
        provider: Name of the provider to use (must match a key in
            AgentConfig.providers; cross-checked by AgentConfig's validator)
        model: The model name/ID
        max_tokens: Maximum tokens to generate (None = provider default)
        temperature: Sampling temperature, constrained to [0.0, 2.0]
        top_p: Nucleus sampling parameter, constrained to [0.0, 1.0]
        capabilities: Set of model capabilities

    Example:
        >>> config = ModelConfig(
        ...     provider="openai",
        ...     model="gpt-4",
        ...     temperature=0.7,
        ... )
    """

    provider: str
    model: str
    max_tokens: Optional[int] = None
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
    capabilities: set[ModelCapability] = Field(default_factory=set)

    @model_validator(mode="after")
    def validate_provider(self) -> "ModelConfig":
        """Validate provider is not empty.

        Raises:
            ValueError: If provider is an empty string.
        """
        if not self.provider:
            raise ValueError("provider must not be empty")
        return self
class ToolConfig(BaseModel):
    """Configuration for tool usage.

    Attributes:
        max_iterations: Maximum number of tool call iterations,
            constrained to [1, 100]
        timeout: Timeout for tool execution in seconds
    """

    max_iterations: int = Field(default=80, ge=1, le=100)
    timeout: float = 60.0
class AgentConfig(BaseModel):
    """Complete configuration for an Agent.

    This combines provider, model, and behavior settings into a single
    configuration object.

    NOTE(review): constructing a bare ``AgentConfig()`` appears to fail
    validation, because default_model defaults to "default" while models
    defaults to an empty dict, and validate_default_model requires
    default_model to be a key of models — confirm whether this is intended.

    Attributes:
        name: Optional name for the agent
        system_prompt: The system prompt to use
        providers: Dictionary of provider configurations
        models: Dictionary of model configurations
        default_model: Name of the default model to use
        tools: Tool configuration
        max_history: Maximum number of messages to keep in history

    Example:
        >>> config = AgentConfig(
        ...     name="my_agent",
        ...     system_prompt="You are a helpful assistant.",
        ...     providers={
        ...         "openai": ProviderConfig(
        ...             type="openai",
        ...             api_key="sk-...",
        ...         )
        ...     },
        ...     models={
        ...         "gpt4": ModelConfig(
        ...             provider="openai",
        ...             model="gpt-4",
        ...         )
        ...     },
        ...     default_model="gpt4",
        ... )
    """

    name: str = "agent"
    system_prompt: str = "You are a helpful assistant."
    providers: dict[str, ProviderConfig] = Field(default_factory=dict)
    models: dict[str, ModelConfig] = Field(default_factory=dict)
    default_model: str = "default"
    tools: ToolConfig = Field(default_factory=ToolConfig)
    max_history: int = Field(default=100, ge=1)

    @model_validator(mode="after")
    def validate_default_model(self) -> "AgentConfig":
        """Validate that default_model exists in models.

        Skipped only when default_model is empty/falsy.

        Raises:
            ValueError: If default_model is not a key of models.
        """
        if self.default_model and self.default_model not in self.models:
            raise ValueError(f"default_model '{self.default_model}' not found in models")
        return self

    @model_validator(mode="after")
    def validate_model_providers(self) -> "AgentConfig":
        """Validate that all model providers exist.

        Raises:
            ValueError: If any model references a provider name that is not
                a key of providers.
        """
        for model_name, model_config in self.models.items():
            if model_config.provider not in self.providers:
                raise ValueError(
                    f"Model '{model_name}' references unknown provider '{model_config.provider}'"
                )
        return self

    def get_provider_config(self, model_name: Optional[str] = None) -> ProviderConfig:
        """Get the provider config for a model.

        Args:
            model_name: Name of the model. If None, uses default_model.

        Returns:
            The provider configuration for the model.

        Raises:
            ValueError: If the model or provider is not found.
        """
        model_name = model_name or self.default_model
        if model_name not in self.models:
            raise ValueError(f"Model '{model_name}' not found")
        model_config = self.models[model_name]
        # Normally guaranteed by validate_model_providers; re-checked in case
        # providers/models were mutated after construction.
        if model_config.provider not in self.providers:
            raise ValueError(f"Provider '{model_config.provider}' not found")
        return self.providers[model_config.provider]

    def get_model_config(self, model_name: Optional[str] = None) -> ModelConfig:
        """Get the configuration for a model.

        Args:
            model_name: Name of the model. If None, uses default_model.

        Returns:
            The model configuration.

        Raises:
            ValueError: If the model is not found.
        """
        model_name = model_name or self.default_model
        if model_name not in self.models:
            raise ValueError(f"Model '{model_name}' not found")
        return self.models[model_name]

View File

@@ -0,0 +1,182 @@
"""Labor Market for managing subagents in AgentLite.
This module provides the LaborMarket class for managing subagents
in a hierarchical agent architecture, similar to kimi-cli's approach.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from agentlite.agent import Agent
class LaborMarket:
    """Registry of subagents owned by a parent agent.

    A parent agent delegates work to the agents registered here. Two pools
    are kept: *fixed* subagents (declared in configuration and loaded at
    startup, each with a description) and *dynamic* subagents (created at
    runtime, e.g. via a CreateSubagent tool, without descriptions).

    Names are unique across both pools, and lookups transparently search
    both.

    Example:
        >>> market = LaborMarket()
        >>> market.add_fixed_subagent("coder", coder_agent, "Writes code")
        >>> market.add_dynamic_subagent("temp", temp_agent)
        >>> agent = market.get_subagent("coder")
    """

    def __init__(self):
        """Create a labor market with no registered subagents."""
        self._fixed_subagents: dict[str, Agent] = {}
        self._fixed_subagent_descs: dict[str, str] = {}
        self._dynamic_subagents: dict[str, Agent] = {}

    @property
    def subagents(self) -> dict[str, Agent]:
        """A fresh dict of every registered subagent, keyed by name."""
        return dict(self._fixed_subagents, **self._dynamic_subagents)

    @property
    def fixed_subagents(self) -> dict[str, Agent]:
        """A copy of the fixed (configuration-time) subagent pool."""
        return dict(self._fixed_subagents)

    @property
    def dynamic_subagents(self) -> dict[str, Agent]:
        """A copy of the dynamic (runtime-created) subagent pool."""
        return dict(self._dynamic_subagents)

    @property
    def subagent_descriptions(self) -> dict[str, str]:
        """A copy of name→description entries (fixed subagents only)."""
        return dict(self._fixed_subagent_descs)

    def _ensure_name_free(self, name: str) -> None:
        # Shared uniqueness guard for both registration paths.
        if self.has_subagent(name):
            raise ValueError(f"Subagent '{name}' already exists")

    def add_fixed_subagent(self, name: str, agent: Agent, description: str) -> None:
        """Register a fixed subagent with a human-readable description.

        Args:
            name: Unique name for the subagent.
            agent: The Agent instance.
            description: What the subagent does.

        Raises:
            ValueError: If any subagent already uses this name.
        """
        self._ensure_name_free(name)
        self._fixed_subagents[name] = agent
        self._fixed_subagent_descs[name] = description

    def add_dynamic_subagent(self, name: str, agent: Agent) -> None:
        """Register a runtime-created subagent (no description is stored).

        Args:
            name: Unique name for the subagent.
            agent: The Agent instance.

        Raises:
            ValueError: If any subagent already uses this name.
        """
        self._ensure_name_free(name)
        self._dynamic_subagents[name] = agent

    def get_subagent(self, name: str) -> Optional[Agent]:
        """Look up a subagent in either pool.

        Args:
            name: Name of the subagent.

        Returns:
            The Agent if registered, otherwise None.
        """
        if name in self._dynamic_subagents:
            return self._dynamic_subagents[name]
        return self._fixed_subagents.get(name)

    def has_subagent(self, name: str) -> bool:
        """Return True when *name* is registered in either pool."""
        return name in self._fixed_subagents or name in self._dynamic_subagents

    def remove_subagent(self, name: str) -> bool:
        """Deregister a subagent from whichever pool holds it.

        Args:
            name: Name of the subagent to remove.

        Returns:
            True when a subagent was removed, False when none existed.
        """
        for pool in (self._fixed_subagents, self._dynamic_subagents):
            if name in pool:
                del pool[name]
                # Descriptions exist only for fixed subagents; pop is a no-op
                # for dynamic ones.
                self._fixed_subagent_descs.pop(name, None)
                return True
        return False

    def list_subagents(self) -> list[str]:
        """Return the names of all registered subagents."""
        return [*self._fixed_subagents, *self._dynamic_subagents]

    def __contains__(self, name: str) -> bool:
        """Support ``name in market``."""
        return self.has_subagent(name)

    def __getitem__(self, name: str) -> Agent:
        """Support ``market[name]``; raises KeyError when absent."""
        agent = self.get_subagent(name)
        if agent is None:
            raise KeyError(f"Subagent '{name}' not found")
        return agent

    def __iter__(self):
        """Iterate over registered subagent names."""
        yield from self._fixed_subagents
        yield from self._dynamic_subagents

    def __len__(self) -> int:
        """Number of registered subagents across both pools."""
        return len(self._fixed_subagents) + len(self._dynamic_subagents)

View File

@@ -0,0 +1,360 @@
"""Simple LLM client for direct LLM calls without agent overhead.
This module provides a simple interface for making direct LLM calls,
reusing the agentlite configuration system.
Example:
>>> from agentlite import LLMClient, AgentConfig, ProviderConfig, ModelConfig
>>>
>>> # Using configuration
>>> config = AgentConfig(
... providers={"openai": ProviderConfig(api_key="sk-...")},
... models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
... default_model="gpt4",
... )
>>> client = LLMClient(config)
>>>
>>> # Simple completion
>>> response = await client.complete(
... system_prompt="You are a helpful assistant.", user_prompt="What is Python?"
... )
>>> print(response)
>>> # Streaming
>>> async for chunk in client.stream(
... system_prompt="You are a helpful assistant.", user_prompt="Tell me a story"
... ):
... print(chunk, end="")
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Optional
from agentlite.config import AgentConfig
from agentlite.message import Message, TextPart
from agentlite.provider import ChatProvider, TokenUsage
from agentlite.providers.openai import OpenAIProvider
class LLMResponse:
    """Container for the result of a direct LLM call.

    Attributes:
        content: The complete response text.
        usage: Token usage statistics, if the provider reported any.
        model: The model name used for the call.
    """

    def __init__(self, content: str, usage: TokenUsage | None = None, model: str = ""):
        self.content = content
        self.usage = usage
        self.model = model

    def __str__(self) -> str:
        # Printing the response yields the raw text.
        return self.content

    def __repr__(self) -> str:
        # Truncate long content to a 50-character preview.
        preview = self.content[:50]
        return "LLMResponse(content={}..., model={})".format(preview, self.model)
class LLMClient:
    """Simple client for direct LLM calls.

    This client provides a simple interface for calling LLMs without the
    overhead of an Agent. It reuses the agentlite configuration system.

    Example:
        >>> # Using AgentConfig
        >>> config = AgentConfig(...)
        >>> client = LLMClient(config)
        >>>
        >>> # Using provider directly
        >>> provider = OpenAIProvider(api_key="sk-...", model="gpt-4")
        >>> client = LLMClient(provider=provider)
        >>>
        >>> # Make a call
        >>> response = await client.complete(system_prompt="You are helpful.", user_prompt="Hello!")
    """

    def __init__(
        self,
        config: Optional[AgentConfig] = None,
        provider: Optional[ChatProvider] = None,
        model: Optional[str] = None,
    ):
        """Initialize the LLM client.

        Args:
            config: AgentConfig to use for provider/model configuration
            provider: Direct provider instance (alternative to config)
            model: Model name to use (when using config)

        Raises:
            ValueError: If neither config nor provider is provided
        """
        if provider is not None:
            # Direct-provider mode: `_config` and `_model_name` are
            # deliberately left unset; paths that need them guard with
            # hasattr(self, "_config").
            self._provider = provider
            self._model_config = None
        elif config is not None:
            # Config mode: resolve the model name, build the provider, and
            # keep the model config for defaults (temperature, etc.).
            self._config = config
            self._model_name = model or config.default_model
            self._provider = self._create_provider()
            self._model_config = config.get_model_config(self._model_name)
        else:
            raise ValueError("Either config or provider must be provided")

    def _create_provider(self) -> ChatProvider:
        """Create a provider instance from config.

        Raises:
            RuntimeError: If called in direct-provider mode (no config).
            ValueError: If the configured provider type is unsupported.
        """
        if not hasattr(self, "_config"):
            raise RuntimeError("No config available")
        provider_config = self._config.get_provider_config(self._model_name)
        model_config = self._config.get_model_config(self._model_name)
        # Create appropriate provider based on type; only "openai" is
        # supported for now.
        if provider_config.type == "openai":
            return OpenAIProvider(
                api_key=provider_config.api_key.get_secret_value(),
                model=model_config.model,
                base_url=provider_config.base_url,
                timeout=provider_config.timeout,
            )
        else:
            raise ValueError(f"Unsupported provider type: {provider_config.type}")

    async def complete(
        self,
        user_prompt: str,
        system_prompt: str = "You are a helpful assistant.",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> LLMResponse:
        """Make a non-streaming LLM call.

        Args:
            user_prompt: The user message/prompt
            system_prompt: The system prompt (default: "You are a helpful assistant.")
            temperature: Sampling temperature (overrides config if provided)
            max_tokens: Maximum tokens to generate (overrides config if provided)

        Returns:
            LLMResponse containing the complete response text and metadata

        Example:
            >>> response = await client.complete(user_prompt="What is the capital of France?")
            >>> print(response.content)
            "The capital of France is Paris."
        """
        # Build messages: a single user turn; the system prompt is passed
        # separately to the provider.
        messages = [Message(role="user", content=user_prompt)]
        # Create a temporary provider with overridden parameters if needed.
        # NOTE(review): in direct-provider mode (no config) these overrides
        # are silently ignored — _create_provider_with_params falls back to
        # the existing provider.
        provider = self._provider
        if temperature is not None or max_tokens is not None:
            provider = self._create_provider_with_params(temperature, max_tokens)
        # Generate response
        stream = await provider.generate(
            system_prompt=system_prompt,
            tools=[],  # No tools for simple LLM calls
            history=messages,
        )
        # Collect response text parts; non-text parts are ignored here.
        content_parts = []
        usage = None
        async for part in stream:
            if isinstance(part, TextPart):
                content_parts.append(part.text)
        # Try to get usage from stream; the StreamedMessage protocol marks
        # usage as optional, so this is best-effort.
        try:
            if usage is None and hasattr(stream, "usage") and stream.usage:
                usage = stream.usage
        except Exception:
            pass
        content = "".join(content_parts)
        # Prefer the provider's reported model; fall back to the configured
        # model name, then "unknown" in direct-provider mode.
        model_name = getattr(
            provider, "model_name", self._model_config.model if self._model_config else "unknown"
        )
        return LLMResponse(
            content=content,
            usage=usage,
            model=model_name,
        )

    async def stream(
        self,
        user_prompt: str,
        system_prompt: str = "You are a helpful assistant.",
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Make a streaming LLM call.

        Args:
            user_prompt: The user message/prompt
            system_prompt: The system prompt (default: "You are a helpful assistant.")
            temperature: Sampling temperature (overrides config if provided)
            max_tokens: Maximum tokens to generate (overrides config if provided)

        Yields:
            Response text chunks as they arrive

        Example:
            >>> async for chunk in client.stream(user_prompt="Write a poem about AI"):
            ...     print(chunk, end="")
        """
        # Build messages (single user turn, same as complete()).
        messages = [Message(role="user", content=user_prompt)]
        # Create a temporary provider with overridden parameters if needed.
        # NOTE(review): as in complete(), overrides only apply in config mode.
        provider = self._provider
        if temperature is not None or max_tokens is not None:
            provider = self._create_provider_with_params(temperature, max_tokens)
        # Generate response
        stream = await provider.generate(
            system_prompt=system_prompt,
            tools=[],  # No tools for simple LLM calls
            history=messages,
        )
        # Yield only text chunks; tool-call parts are dropped.
        async for part in stream:
            if isinstance(part, TextPart):
                yield part.text

    def _create_provider_with_params(
        self,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> ChatProvider:
        """Create a provider with overridden parameters.

        Args:
            temperature: Override for the configured sampling temperature.
            max_tokens: Override for the configured max token count.

        Returns:
            A new provider when running in config mode; otherwise the
            existing provider unchanged (overrides cannot be applied).
        """
        if not hasattr(self, "_config"):
            # Can't override params without config
            return self._provider
        provider_config = self._config.get_provider_config(self._model_name)
        model_config = self._config.get_model_config(self._model_name)
        # Override parameters, falling back to configured defaults.
        temp = temperature if temperature is not None else model_config.temperature
        max_tok = max_tokens if max_tokens is not None else model_config.max_tokens
        if provider_config.type == "openai":
            # NOTE(review): OpenAIProvider.__init__ forwards unknown keyword
            # arguments to the AsyncOpenAI client constructor; confirm it
            # actually accepts temperature/max_tokens, which are normally
            # per-request options rather than client options.
            return OpenAIProvider(
                api_key=provider_config.api_key.get_secret_value(),
                model=model_config.model,
                base_url=provider_config.base_url,
                timeout=provider_config.timeout,
                temperature=temp,
                max_tokens=max_tok,
            )
        else:
            raise ValueError(f"Unsupported provider type: {provider_config.type}")
# Convenience functions for simple use cases
async def llm_complete(
    user_prompt: str,
    system_prompt: str = "You are a helpful assistant.",
    api_key: Optional[str] = None,
    model: str = "gpt-4",
    base_url: str = "https://api.openai.com/v1",
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> str:
    """One-off LLM completion without keeping a client around.

    Builds a throwaway provider and client, performs a single
    non-streaming call, and returns the response text.

    Args:
        user_prompt: The user message/prompt.
        system_prompt: The system prompt.
        api_key: API key (if not provided, must be set in env).
        model: Model name (default: gpt-4).
        base_url: API base URL.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.

    Returns:
        The response text.

    Example:
        >>> response = await llm_complete(
        ...     user_prompt="What is 2+2?",
        ...     api_key="sk-...",
        ...     model="gpt-4",
        ... )
        >>> print(response)
        "2+2 equals 4."
    """
    one_shot = LLMClient(
        provider=OpenAIProvider(api_key=api_key, model=model, base_url=base_url)
    )
    result = await one_shot.complete(
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return result.content
async def llm_stream(
    user_prompt: str,
    system_prompt: str = "You are a helpful assistant.",
    api_key: Optional[str] = None,
    model: str = "gpt-4",
    base_url: str = "https://api.openai.com/v1",
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> AsyncIterator[str]:
    """One-off streaming LLM completion without keeping a client around.

    Builds a throwaway provider and client, then relays the streamed
    text chunks of a single call.

    Args:
        user_prompt: The user message/prompt.
        system_prompt: The system prompt.
        api_key: API key (if not provided, must be set in env).
        model: Model name (default: gpt-4).
        base_url: API base URL.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.

    Yields:
        Response text chunks.

    Example:
        >>> async for chunk in llm_stream(
        ...     user_prompt="Write a haiku",
        ...     api_key="sk-...",
        ... ):
        ...     print(chunk, end="")
    """
    one_shot = LLMClient(
        provider=OpenAIProvider(api_key=api_key, model=model, base_url=base_url)
    )
    async for piece in one_shot.stream(
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    ):
        yield piece

View File

@@ -0,0 +1,212 @@
"""MCP (Model Context Protocol) integration for AgentLite.
This module provides integration with MCP servers, allowing agents to use
tools from external MCP-compatible servers.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from agentlite.tool import CallableTool, ToolOk, ToolResult, ToolError
if TYPE_CHECKING:
pass
class MCPClient:
    """Client for connecting to MCP servers.

    This client allows you to connect to MCP servers and load their tools
    into AgentLite agents.

    Example:
        >>> client = MCPClient()
        >>> await client.connect_stdio(
        ...     "npx", ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
        ... )
        >>> tools = await client.load_tools()
        >>> agent = Agent(provider=provider, tools=tools)
    """

    def __init__(self):
        """Initialize the MCP client."""
        # The underlying fastmcp Client; typed Any so fastmcp remains an
        # optional dependency.
        self._client: Any | None = None
        self._connected = False

    def _check_fastmcp(self) -> None:
        """Check if fastmcp is installed.

        Raises:
            ImportError: If the optional 'fastmcp' dependency is missing.
        """
        try:
            import fastmcp  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "MCP support requires 'fastmcp' package. Install with: pip install agentlite[mcp]"
            ) from e

    async def connect_stdio(
        self,
        command: str,
        args: list[str] | None = None,
        env: dict[str, str] | None = None,
    ) -> None:
        """Connect to an MCP server via stdio.

        Args:
            command: The command to run.
            args: Optional arguments for the command.
            env: Optional environment variables.

        Raises:
            RuntimeError: If already connected.
            ConnectionError: If the connection fails.
        """
        if self._connected:
            raise RuntimeError("Already connected to an MCP server")
        try:
            from fastmcp import Client
            from fastmcp.client.transports import PythonStdioTransport

            # NOTE(review): PythonStdioTransport targets Python scripts; for
            # arbitrary commands such as "npx" (as in the class docstring),
            # fastmcp's generic StdioTransport may be required — confirm
            # against the fastmcp API.
            transport = PythonStdioTransport(
                command_or_script=command,
                args=args or [],
                env=env,
            )
            # The actual connection is established lazily when the client is
            # entered as an async context manager (see load_tools).
            self._client = Client(transport)
            self._connected = True
        except Exception as e:
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def connect_sse(
        self,
        url: str,
        headers: dict[str, str] | None = None,
    ) -> None:
        """Connect to an MCP server via Server-Sent Events (SSE).

        Args:
            url: The SSE endpoint URL.
            headers: Optional headers to include in requests.

        Raises:
            RuntimeError: If already connected.
            ConnectionError: If the connection fails.
        """
        if self._connected:
            raise RuntimeError("Already connected to an MCP server")
        try:
            from fastmcp import Client
            from fastmcp.client.transports import SSETransport

            transport = SSETransport(url=url, headers=headers)
            self._client = Client(transport)
            self._connected = True
        except Exception as e:
            raise ConnectionError(f"Failed to connect to MCP server: {e}") from e

    async def load_tools(self) -> list[CallableTool]:
        """Load tools from the connected MCP server.

        Returns:
            A list of CallableTool instances wrapping the MCP tools.

        Raises:
            RuntimeError: If not connected to an MCP server.
        """
        if not self._connected or self._client is None:
            raise RuntimeError("Not connected to an MCP server")
        tools: list[CallableTool] = []
        try:
            # Enter the client context for this operation; each wrapper tool
            # re-enters the same client on invocation.
            async with self._client as client:
                mcp_tools = await client.list_tools()
                for mcp_tool in mcp_tools:
                    tool = _MCPTool(
                        client=self._client,
                        name=mcp_tool.name,
                        description=mcp_tool.description or "No description provided",
                        parameters=mcp_tool.inputSchema,
                    )
                    tools.append(tool)
        except Exception as e:
            raise RuntimeError(f"Failed to load MCP tools: {e}") from e
        return tools

    async def close(self) -> None:
        """Close the connection to the MCP server.

        Best-effort: close failures are swallowed, but the client reference
        and connected flag are always reset.
        """
        if self._client is not None:
            try:
                await self._client.close()
            except Exception:
                pass
            finally:
                self._client = None
                self._connected = False

    async def __aenter__(self) -> MCPClient:
        """Async context manager entry."""
        return self

    async def __aexit__(self, *args: Any) -> None:
        """Async context manager exit; closes the connection."""
        await self.close()
class _MCPTool(CallableTool):
    """Wrapper exposing a single MCP server tool as a CallableTool."""

    def __init__(
        self,
        client: Any,
        name: str,
        description: str,
        parameters: dict[str, Any],
    ):
        """Initialize the MCP tool wrapper.

        Args:
            client: The MCP client.
            name: The tool name.
            description: The tool description.
            parameters: The JSON schema for tool parameters.
        """
        self._client = client
        super().__init__(
            name=name,
            description=description,
            parameters=parameters,
        )

    async def __call__(self, **kwargs: Any) -> ToolResult:
        """Execute the MCP tool.

        Args:
            **kwargs: The tool arguments.

        Returns:
            ToolOk with the concatenated text output, or ToolError on failure.
        """
        try:
            # Re-enter the client context per invocation; the session is
            # opened/closed around this block.
            async with self._client as client:
                result = await client.call_tool(self.name, kwargs)
                # Convert MCP result content into a single newline-joined
                # string; non-text parts fall back to str().
                content_parts = []
                for content in result.content:
                    if hasattr(content, "text"):
                        content_parts.append(content.text)
                    else:
                        content_parts.append(str(content))
                output = "\n".join(content_parts)
                if result.isError:
                    return ToolError(message=output or "Tool execution failed")
                return ToolOk(output=output)
        except Exception as e:
            # Deliberate catch-all: tool failures surface as ToolError rather
            # than crashing the agent loop.
            return ToolError(message=f"MCP tool execution failed: {e}")

View File

@@ -0,0 +1,292 @@
"""Core message types for AgentLite.
This module defines the message and content part types used throughout
AgentLite for communication with LLM providers.
"""
from __future__ import annotations
from abc import ABC
from typing import Any, ClassVar, Literal, Optional, Union, cast
from pydantic import BaseModel, GetCoreSchemaHandler, field_validator
from pydantic_core import core_schema
Role = Literal["system", "user", "assistant", "tool"]
class MergeableMixin:
    """Mixin giving content parts an optional in-place streaming merge hook."""

    def merge_in_place(self, other: Any) -> bool:
        """Attempt to fold *other* into this part.

        The base implementation merges nothing and always reports failure;
        subclasses override it for part types that can be concatenated
        during streaming.

        Args:
            other: Candidate part to merge into this one.

        Returns:
            True when the merge happened, False otherwise.
        """
        return False
class ContentPart(BaseModel, ABC, MergeableMixin):
    """Base class for message content parts.

    ContentPart uses a registry pattern to allow polymorphic validation
    of content part subclasses based on the 'type' field.

    Example:
        >>> text = TextPart(text="Hello")
        >>> print(text.model_dump())
        {'type': 'text', 'text': 'Hello'}
    """

    # Maps each subclass's 'type' default (e.g. "text") to the subclass
    # itself; populated by __init_subclass__ below.
    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}

    type: str

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Register every subclass under its 'type' default value."""
        super().__init_subclass__(**kwargs)
        # NOTE(review): this reads the class-level default of 'type' while the
        # subclass is being created; confirm pydantic's metaclass still
        # exposes it as a plain attribute at this point.
        type_value = getattr(cls, "type", None)
        if type_value is None or not isinstance(type_value, str):
            raise ValueError(
                f"ContentPart subclass {cls.__name__} must have a 'type' field of type str"
            )
        cls.__content_part_registry[type_value] = cls

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """Custom schema for polymorphic ContentPart validation."""
        # Only the abstract base dispatches on 'type'; concrete subclasses
        # keep pydantic's default generated schema (see the tail return).
        if cls.__name__ == "ContentPart":

            def validate_content_part(value: Any) -> Any:
                """Validate a value as a ContentPart subclass."""
                # Already an instance of (a subclass of) ContentPart.
                if hasattr(value, "__class__") and issubclass(value.__class__, cls):
                    return value
                # Dict with type field - dispatch to the registered subclass.
                if isinstance(value, dict) and "type" in value:
                    type_value = cast(dict[str, Any], value).get("type")
                    if not isinstance(type_value, str):
                        raise ValueError(f"Cannot validate {value} as ContentPart")
                    target_class = cls.__content_part_registry.get(type_value)
                    if target_class is None:
                        raise ValueError(f"Unknown content part type: {type_value}")
                    return target_class.model_validate(value)
                raise ValueError(f"Cannot validate {value} as ContentPart")

            return core_schema.no_info_plain_validator_function(validate_content_part)
        # For subclasses, use default schema
        return handler(source_type)
class TextPart(ContentPart):
    """Plain-text content part.

    Attributes:
        text: The text content.

    Example:
        >>> part = TextPart(text="Hello, world!")
        >>> part.model_dump()
        {'type': 'text', 'text': 'Hello, world!'}
    """

    type: str = "text"
    text: str

    def merge_in_place(self, other: Any) -> bool:
        """Append another TextPart's text to this one; reject other types."""
        if isinstance(other, TextPart):
            self.text += other.text
            return True
        return False
class ImageURLPart(ContentPart):
    """Image URL content part.

    Attributes:
        image_url: The image URL configuration.

    Example:
        >>> part = ImageURLPart(
        ...     image_url=ImageURLPart.ImageURL(url="https://example.com/image.png")
        ... )
    """

    class ImageURL(BaseModel):
        """Image URL configuration."""

        url: str
        """The URL of the image. Can be a data URI like 'data:image/png;base64,...'."""
        detail: Optional[str] = None
        """The detail level: 'low', 'high', or 'auto'."""

    # The 'type' default doubles as the dispatch key registered on ContentPart.
    type: str = "image_url"
    image_url: ImageURL
class AudioURLPart(ContentPart):
    """Audio URL content part.

    Attributes:
        audio_url: The audio URL configuration.
    """

    class AudioURL(BaseModel):
        """Audio URL configuration."""

        url: str
        """The URL of the audio. Can be a data URI like 'data:audio/mp3;base64,...'."""

    # The 'type' default doubles as the dispatch key registered on ContentPart.
    type: str = "audio_url"
    audio_url: AudioURL
class ToolCall(BaseModel, MergeableMixin):
    """A tool invocation requested by the assistant.

    Attributes:
        id: Unique identifier for the tool call.
        function: The function to call.

    Example:
        >>> call = ToolCall(
        ...     id="call_123",
        ...     function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}'),
        ... )
    """

    class FunctionBody(BaseModel):
        """Name and serialized arguments of the requested function."""

        name: str
        """The name of the tool to call."""
        arguments: str
        """The arguments as a JSON string."""

    type: Literal["function"] = "function"
    id: str
    function: FunctionBody

    def merge_in_place(self, other: Any) -> bool:
        """Append a streamed ToolCallPart's argument chunk to this call."""
        if not isinstance(other, ToolCallPart):
            return False
        chunk = other.arguments_part
        if chunk:
            self.function.arguments += chunk
        return True
class ToolCallPart(BaseModel, MergeableMixin):
    """A streamed fragment of a tool call.

    Represents one chunk of a tool call's arguments as it arrives from a
    streaming provider.

    Attributes:
        arguments_part: A chunk of the arguments JSON, if any.
    """

    arguments_part: Optional[str] = None

    def merge_in_place(self, other: Any) -> bool:
        """Concatenate another fragment's argument chunk into this one."""
        if not isinstance(other, ToolCallPart):
            return False
        incoming = other.arguments_part
        if incoming:
            self.arguments_part = (self.arguments_part or "") + incoming
        return True
class Message(BaseModel):
    """A message in a conversation.

    Attributes:
        role: The role of the message sender.
        content: The content parts of the message.
        tool_calls: Tool calls requested by the assistant (only for assistant role).
        tool_call_id: The ID of the tool call being responded to (only for tool role).
        name: Optional name for the sender.

    Example:
        >>> msg = Message(role="user", content="Hello!")
        >>> print(msg.extract_text())
        Hello!
    """

    role: Role
    content: list[ContentPart]
    tool_calls: Optional[list[ToolCall]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None

    @field_validator("content", mode="before")
    @classmethod
    def _coerce_content(cls, value: Any) -> Any:
        """Coerce string content to TextPart.

        Covers validation paths that bypass __init__ (e.g. model_validate);
        __init__ below performs the same coercion for direct construction.
        """
        if isinstance(value, str):
            return [TextPart(text=value)]
        return value

    def __init__(
        self,
        *,
        role: Role,
        content: Union[list[ContentPart], ContentPart, str],
        tool_calls: Optional[list[ToolCall]] = None,
        tool_call_id: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        """Initialize a message.

        Args:
            role: The role of the message sender.
            content: The content, can be a string, single ContentPart, or list.
            tool_calls: Tool calls for assistant messages.
            tool_call_id: ID of the tool call being responded to.
            name: Optional name for the sender.
        """
        # Normalize the three accepted content shapes to list[ContentPart]
        # before handing off to pydantic's validation.
        if isinstance(content, str):
            content = [TextPart(text=content)]
        elif isinstance(content, ContentPart):
            content = [content]
        super().__init__(
            role=role,
            content=content,
            tool_calls=tool_calls,
            tool_call_id=tool_call_id,
            name=name,
        )

    def extract_text(self, sep: str = "") -> str:
        """Extract all text from the message content.

        Args:
            sep: Separator to use between text parts.

        Returns:
            Concatenated text from all TextPart instances; non-text parts
            are skipped.
        """
        return sep.join(part.text for part in self.content if isinstance(part, TextPart))

    def has_tool_calls(self) -> bool:
        """Check if this message contains tool calls.

        Returns:
            True if the message has at least one tool call.
        """
        return self.tool_calls is not None and len(self.tool_calls) > 0

View File

@@ -0,0 +1,161 @@
"""Chat provider protocol and implementations for AgentLite.
This module defines the ChatProvider protocol that abstracts LLM providers
and provides the base types for streaming responses.
"""
from __future__ import annotations
from collections.abc import AsyncIterator, Sequence
from typing import Protocol, Union, runtime_checkable
from pydantic import BaseModel
from agentlite.message import ContentPart, Message, ToolCall, ToolCallPart
from agentlite.tool import Tool
class TokenUsage(BaseModel):
    """Token accounting for a single generation.

    Attributes:
        input_tokens: Tokens consumed by the prompt.
        output_tokens: Tokens produced by the completion.
        cached_tokens: Prompt tokens served from cache, when reported.

    Example:
        >>> usage = TokenUsage(input_tokens=100, output_tokens=50)
        >>> print(usage.total)
        150
    """

    input_tokens: int
    """Number of input tokens used."""
    output_tokens: int
    """Number of output tokens generated."""
    cached_tokens: int = 0
    """Number of cached input tokens (if applicable)."""

    @property
    def total(self) -> int:
        """Combined input and output token count."""
        return self.output_tokens + self.input_tokens
# Everything a provider stream may yield: finished content parts, complete
# tool calls, or partial tool-call argument chunks.
StreamedPart = Union[ContentPart, ToolCall, ToolCallPart]


@runtime_checkable
class StreamedMessage(Protocol):
    """Protocol for streamed message responses.

    This protocol defines the interface for streaming responses from LLM
    providers. Implementations should yield content parts as they arrive.

    Example:
        >>> stream = await provider.generate(system_prompt, tools, history)
        >>> async for part in stream:
        ...     print(part)
    """

    def __aiter__(self) -> AsyncIterator[StreamedPart]:
        """Return an async iterator over the streamed parts."""
        ...

    @property
    def id(self) -> str | None:
        """The unique identifier of the message, if available."""
        ...

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage statistics, if available (may only be populated
        after the stream has been fully consumed)."""
        ...
class ChatProviderError(Exception):
    """Root of the chat-provider exception hierarchy."""

    def __init__(self, message: str):
        super().__init__(message)
        # Kept as an attribute so handlers can read the text without str().
        self.message = message
class APIConnectionError(ChatProviderError):
    """Raised when the client cannot reach the provider's API."""

    pass
class APITimeoutError(ChatProviderError):
    """Raised when an API request exceeds its timeout."""

    pass
class APIStatusError(ChatProviderError):
    """The API answered with a non-success HTTP status.

    Attributes:
        status_code: The HTTP status code returned by the API.
    """

    def __init__(self, status_code: int, message: str):
        # Preserve the textual message via the base class, then record
        # the status code for handlers that branch on it.
        super().__init__(message)
        self.status_code = status_code
class APIEmptyResponseError(ChatProviderError):
    """Raised when the API returns an empty response."""

    pass
@runtime_checkable
class ChatProvider(Protocol):
    """Protocol for LLM chat providers.

    This protocol defines the interface that all LLM providers must
    implement. Conformance is structural (typing.Protocol), so providers
    need not inherit from this class.

    Example:
        >>> provider = OpenAIProvider(api_key="sk-...", model="gpt-4")
        >>> stream = await provider.generate(
        ...     system_prompt="You are helpful.",
        ...     tools=[],
        ...     history=[Message(role="user", content="Hello!")],
        ... )
        >>> async for part in stream:
        ...     print(part)
    """

    @property
    def model_name(self) -> str:
        """The name of the model being used."""
        ...

    async def generate(
        self,
        system_prompt: str,
        tools: Sequence[Tool],
        history: Sequence[Message],
    ) -> StreamedMessage:
        """Generate a response from the LLM.

        Args:
            system_prompt: The system prompt to use.
            tools: Available tools for the model to call.
            history: The conversation history.

        Returns:
            A streamed message that yields content parts.

        Raises:
            APIConnectionError: If the connection fails.
            APITimeoutError: If the request times out.
            APIStatusError: If the API returns an error status.
            APIEmptyResponseError: If the response is empty.
        """
        ...

View File

@@ -0,0 +1,5 @@
"""Providers package for AgentLite."""
from agentlite.providers.openai import OpenAIProvider
__all__ = ["OpenAIProvider"]

View File

@@ -0,0 +1,305 @@
"""OpenAI provider implementation for AgentLite.
This module provides an OpenAI-compatible chat provider that works with
the OpenAI API and any OpenAI-compatible API (e.g., Moonshot, Together, etc.).
"""
from __future__ import annotations
import uuid
from collections.abc import AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any
import httpx
from openai import AsyncOpenAI, OpenAIError
from openai.types.chat import (
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from agentlite.message import (
Message,
TextPart,
ToolCall,
ToolCallPart,
)
from agentlite.provider import (
APIConnectionError,
APIStatusError,
APITimeoutError,
ChatProviderError,
StreamedMessage,
TokenUsage,
)
from agentlite.tool import Tool
if TYPE_CHECKING:
pass
def _convert_tool_to_openai(tool: Tool) -> ChatCompletionToolParam:
    """Render an agentlite Tool as an OpenAI function-tool definition.

    Args:
        tool: The tool to convert.

    Returns:
        The OpenAI tool format.
    """
    function_spec = {
        "name": tool.name,
        "description": tool.description,
        "parameters": tool.parameters,
    }
    return {"type": "function", "function": function_spec}
def _convert_message_to_openai(message: Message) -> ChatCompletionMessageParam:
    """Render an agentlite Message in OpenAI chat-completion format.

    Args:
        message: The message to convert.

    Returns:
        The OpenAI message format.
    """
    converted: dict[str, Any] = {"role": message.role}
    if message.role == "tool":
        # Tool results carry plain text plus the id of the call they answer.
        converted["content"] = message.extract_text()
        converted["tool_call_id"] = message.tool_call_id
        return converted  # type: ignore[return-value]
    if message.has_tool_calls():
        # Assistant turn requesting tool invocations; empty text becomes None.
        converted["content"] = message.extract_text() or None
        converted["tool_calls"] = [
            {
                "id": tc.id,
                "type": "function",
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in (message.tool_calls or [])
        ]
        return converted  # type: ignore[return-value]
    # Plain turn: join text parts; non-text parts (images/audio) are dropped.
    texts = [part.text for part in message.content if isinstance(part, TextPart)]
    converted["content"] = "\n".join(texts) if texts else None
    return converted  # type: ignore[return-value]
class OpenAIStreamedMessage:
    """Streamed message implementation for OpenAI.

    This class wraps the OpenAI streaming response and converts chunks
    into AgentLite content parts.
    """

    def __init__(self, response: AsyncIterator[ChatCompletionChunk]):
        """Initialize the streamed message.

        Args:
            response: The OpenAI streaming response.
        """
        self._response = response
        self._id: str | None = None
        # Zeroed until the API reports usage; with stream_options
        # include_usage that arrives in the final chunk.
        self._usage = TokenUsage(input_tokens=0, output_tokens=0)

    def __aiter__(self) -> AsyncIterator[Any]:
        """Return an async iterator over the streamed parts."""
        return self._iter_chunks()

    async def _iter_chunks(self) -> AsyncIterator[Any]:
        """Iterate over response chunks and yield content parts.

        Yields:
            TextPart for text deltas, ToolCall when a new tool call starts,
            and ToolCallPart for continuation chunks of a call's arguments.

        Raises:
            ChatProviderError: Translated from OpenAI/httpx errors.
        """
        try:
            async for chunk in self._response:
                # Track message ID (shared by all chunks of one completion).
                if chunk.id:
                    self._id = chunk.id
                # Track usage if available.
                if chunk.usage:
                    self._usage = TokenUsage(
                        input_tokens=chunk.usage.prompt_tokens,
                        output_tokens=chunk.usage.completion_tokens,
                    )
                # Skip chunks without choices (e.g. the usage-only final chunk).
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                # Yield text content
                if delta.content:
                    yield TextPart(text=delta.content)
                # Yield tool calls
                if delta.tool_calls:
                    for tc in delta.tool_calls:
                        if tc.function:
                            # A present function name marks the start of a new
                            # call; later chunks stream only argument pieces.
                            if tc.function.name:
                                yield ToolCall(
                                    id=tc.id or str(uuid.uuid4()),
                                    function=ToolCall.FunctionBody(
                                        name=tc.function.name,
                                        arguments=tc.function.arguments or "",
                                    ),
                                )
                            elif tc.function.arguments:
                                # Continuation of the current call's arguments.
                                yield ToolCallPart(arguments_part=tc.function.arguments)
        except (OpenAIError, httpx.HTTPError) as e:
            raise _convert_error(e) from e

    @property
    def id(self) -> str | None:
        """The unique identifier of the message."""
        return self._id

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage statistics (zeroes until the stream reports them)."""
        return self._usage
class OpenAIProvider:
    """OpenAI-compatible chat provider.

    This provider works with the OpenAI API and any OpenAI-compatible API
    such as Moonshot, Together, Fireworks, etc.

    Attributes:
        model: The model name to use.
        client: The underlying AsyncOpenAI client.

    Example:
        >>> provider = OpenAIProvider(
        ...     api_key="sk-...",
        ...     model="gpt-4",
        ... )
        >>> stream = await provider.generate(
        ...     system_prompt="You are helpful.",
        ...     tools=[],
        ...     history=[Message(role="user", content="Hello!")],
        ... )
    """

    def __init__(
        self,
        *,
        api_key: str,
        model: str,
        base_url: str | None = None,
        timeout: float = 60.0,
        **client_kwargs: Any,
    ):
        """Initialize the OpenAI provider.

        Args:
            api_key: The API key for authentication.
            model: The model name to use (e.g., "gpt-4", "gpt-3.5-turbo").
            base_url: Optional custom base URL for OpenAI-compatible APIs.
            timeout: Request timeout in seconds.
            **client_kwargs: Additional arguments passed to AsyncOpenAI.
        """
        self.model = model
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=timeout,
            **client_kwargs,
        )

    @property
    def model_name(self) -> str:
        """The name of the model being used."""
        return self.model

    async def generate(
        self,
        system_prompt: str,
        tools: Sequence[Tool],
        history: Sequence[Message],
    ) -> StreamedMessage:
        """Generate a streaming response from the OpenAI API.

        Args:
            system_prompt: The system prompt to use.
            tools: Available tools for the model to call.
            history: The conversation history.

        Returns:
            A streamed message that yields content parts.

        Raises:
            APIConnectionError: If the connection fails.
            APITimeoutError: If the request times out.
            APIStatusError: If the API returns an error status.
            APIEmptyResponseError: If the response is empty.
        """
        # Build the OpenAI-format message list; the system prompt goes first.
        messages: list[ChatCompletionMessageParam] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for msg in history:
            messages.append(_convert_message_to_openai(msg))
        # Bug fix: previously `tools=None` was passed when no tools were
        # configured, which serializes as an explicit JSON null and is
        # rejected by the API. Omit the parameter entirely instead.
        extra_params: dict[str, Any] = {}
        if tools:
            extra_params["tools"] = [_convert_tool_to_openai(t) for t in tools]
        try:
            # include_usage makes the final chunk carry token usage, which
            # OpenAIStreamedMessage records as it iterates.
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                stream=True,
                stream_options={"include_usage": True},
                **extra_params,
            )
            return OpenAIStreamedMessage(response)  # type: ignore[arg-type]
        except (OpenAIError, httpx.HTTPError) as e:
            raise _convert_error(e) from e
def _convert_error(error: OpenAIError | httpx.HTTPError) -> ChatProviderError:
    """Map SDK/transport exceptions onto the provider-error hierarchy.

    Args:
        error: The error raised by the OpenAI SDK or httpx.

    Returns:
        The appropriate ChatProviderError subclass.
    """
    # Bug fix: the SDK's exception classes live at the top level of the
    # ``openai`` package (openai.APIConnectionError, ...), not as attributes
    # of OpenAIError — the previous attribute lookups raised AttributeError.
    import openai

    if isinstance(error, OpenAIError):
        # openai.APITimeoutError subclasses openai.APIConnectionError, so the
        # more specific check must come first.
        if isinstance(error, openai.APITimeoutError):
            return APITimeoutError(str(error))
        if isinstance(error, openai.APIConnectionError):
            return APIConnectionError(str(error))
        if isinstance(error, openai.APIStatusError):
            return APIStatusError(error.status_code, str(error))
    if isinstance(error, httpx.TimeoutException):
        return APITimeoutError(str(error))
    if isinstance(error, httpx.NetworkError):
        return APIConnectionError(str(error))
    if isinstance(error, httpx.HTTPStatusError):
        return APIStatusError(error.response.status_code, str(error))
    return ChatProviderError(str(error))

View File

@@ -0,0 +1,72 @@
"""Skills system for AgentLite.
This module provides a comprehensive skill system similar to kimi-cli,
allowing agents to use modular, reusable skills defined in SKILL.md files.
Skills can be:
- **Standard**: Text-based instructions loaded as prompts
- **Flow**: Structured flowcharts (Mermaid/D2) for deterministic execution
Example:
>>> from pathlib import Path
>>> from agentlite.skills import discover_skills, SkillTool
>>> # Discover skills
>>> skills = discover_skills(Path("./skills"))
>>> skill_index = {s.name.lower(): s for s in skills}
>>> # Create skill tool
>>> skill_tool = SkillTool(skill_index, parent_agent=agent)
"""
from agentlite.skills.discovery import (
discover_skills,
discover_skills_from_roots,
get_default_skills_dirs,
index_skills_by_name,
parse_frontmatter,
parse_skill_text,
)
from agentlite.skills.flow_parser import (
FlowParseError,
parse_d2_flowchart,
parse_mermaid_flowchart,
)
from agentlite.skills.flow_runner import FlowExecutionError, FlowRunner
from agentlite.skills.models import (
Flow,
FlowEdge,
FlowNode,
FlowNodeKind,
Skill,
SkillType,
index_skills,
normalize_skill_name,
)
from agentlite.skills.skill_tool import SkillTool
__all__ = [
# Models
"Skill",
"Flow",
"FlowNode",
"FlowEdge",
"SkillType",
"FlowNodeKind",
# Discovery
"discover_skills",
"discover_skills_from_roots",
"get_default_skills_dirs",
"index_skills",
"index_skills_by_name",
"normalize_skill_name",
"parse_skill_text",
"parse_frontmatter",
# Flow parsing
"parse_mermaid_flowchart",
"parse_d2_flowchart",
"FlowParseError",
# Flow execution
"FlowRunner",
"FlowExecutionError",
# Tool
"SkillTool",
]

View File

@@ -0,0 +1,307 @@
"""Skill discovery and loading utilities for AgentLite.
This module provides functions for discovering and loading skills from
directory structures, similar to kimi-cli's skill system.
"""
from __future__ import annotations
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Optional
import yaml
if TYPE_CHECKING:
from agentlite.skills.models import Flow, Skill
def parse_frontmatter(content: str) -> Optional[Dict]:
    """Parse YAML frontmatter from markdown content.

    Args:
        content: The file content that may contain frontmatter.

    Returns:
        Dictionary of frontmatter data, or None if no frontmatter found
        (or if the frontmatter is malformed).

    Example:
        >>> content = '''---
        ... name: my-skill
        ... description: Does something useful
        ... ---
        ... # Skill Content
        ... '''
        >>> parse_frontmatter(content)
        {'name': 'my-skill', 'description': 'Does something useful'}
    """
    # Frontmatter must open the file with a "---" fence.
    if not content.startswith("---"):
        return None
    try:
        # Locate the closing fence; without one there is no frontmatter.
        closing = content.find("\n---", 3)
        if closing == -1:
            return None
        raw_yaml = content[3:closing].strip()
        parsed = yaml.safe_load(raw_yaml)
        # Empty frontmatter parses to None; normalize to an empty dict.
        return parsed or {}
    except Exception:
        # Malformed YAML is treated the same as missing frontmatter.
        return None
def parse_flow_from_skill(content: str) -> "Flow":
    """Parse a flowchart from skill content.

    Looks for mermaid or d2 code blocks and parses them into Flow objects.
    The first block that parses successfully wins.

    Args:
        content: The SKILL.md content containing a flowchart.

    Returns:
        Parsed Flow object.

    Raises:
        ValueError: If no valid flowchart found.
    """
    from agentlite.skills.flow_parser import (
        FlowParseError,
        parse_d2_flowchart,
        parse_mermaid_flowchart,
    )

    # Dispatch table keyed by the code fence's info string.
    parsers = {
        "mermaid": parse_mermaid_flowchart,
        "d2": parse_d2_flowchart,
    }

    for lang, code in _extract_code_blocks(content):
        parser = parsers.get(lang)
        if parser is None:
            continue
        try:
            return parser(code)
        except FlowParseError:
            # An unparseable block may be followed by a valid one.
            continue

    raise ValueError("No valid mermaid or d2 flowchart found in skill content")
def _extract_code_blocks(content: str) -> list[tuple[str, str]]:
"""Extract fenced code blocks from markdown content.
Args:
content: Markdown content
Returns:
List of (language, code) tuples
"""
blocks = []
in_block = False
current_lang = ""
current_code = []
fence_char = ""
fence_len = 0
for line in content.split("\n"):
stripped = line.lstrip()
if not in_block:
# Check for fence start
if stripped.startswith("```") or stripped.startswith("~~~"):
fence_char = stripped[0]
fence_len = len(stripped) - len(stripped.lstrip(fence_char))
if fence_len >= 3:
# Extract language
info = stripped[fence_len:].strip()
current_lang = info.split()[0] if info else ""
in_block = True
current_code = []
else:
# Check for fence end
if stripped.startswith(fence_char * fence_len):
blocks.append((current_lang, "\n".join(current_code)))
in_block = False
current_lang = ""
current_code = []
else:
current_code.append(line)
return blocks
def parse_skill_text(content: str, dir_path: Path) -> "Skill":
    """Parse skill content into a Skill object.

    Args:
        content: The SKILL.md content.
        dir_path: Path to the skill directory.

    Returns:
        Parsed Skill object.

    Raises:
        ValueError: If the skill content is invalid.
    """
    from agentlite.skills.flow_parser import FlowParseError
    from agentlite.skills.models import Skill

    meta = parse_frontmatter(content) or {}

    # Fall back to sensible defaults for missing frontmatter fields.
    skill_name = meta.get("name") or dir_path.name
    skill_description = meta.get("description") or "No description provided."
    skill_type = meta.get("type") or "standard"
    if skill_type not in ("standard", "flow"):
        raise ValueError(f'Invalid skill type "{skill_type}"')

    flow = None
    if skill_type == "flow":
        try:
            flow = parse_flow_from_skill(content)
        except (ValueError, FlowParseError) as e:
            # A broken flowchart downgrades the skill to standard rather
            # than failing discovery of the whole skill.
            import logging

            logging.warning(
                f"Failed to parse flow skill '{skill_name}': {e}. Treating as standard skill."
            )
            skill_type = "standard"
            flow = None

    return Skill(
        name=skill_name,
        description=skill_description,
        type=skill_type,
        dir=dir_path,
        flow=flow,
    )
def discover_skills(skills_dir: Path) -> list["Skill"]:
    """Discover all skills in a directory.

    Scans the directory for subdirectories containing SKILL.md files
    and parses them into Skill objects. Entries that fail to parse are
    logged and skipped.

    Args:
        skills_dir: Directory to scan for skills.

    Returns:
        List of discovered Skill objects, sorted by name.

    Example:
        >>> skills = discover_skills(Path("./skills"))
        >>> for skill in skills:
        ...     print(f"{skill.name}: {skill.description}")
    """
    if not skills_dir.is_dir():
        return []

    found: list["Skill"] = []
    for entry in skills_dir.iterdir():
        manifest = entry / "SKILL.md"
        # Only subdirectories with a SKILL.md manifest count as skills.
        if not entry.is_dir() or not manifest.is_file():
            continue
        try:
            found.append(parse_skill_text(manifest.read_text(encoding="utf-8"), entry))
        except Exception as e:
            import logging

            logging.warning(f"Failed to parse skill at {manifest}: {e}")

    return sorted(found, key=lambda s: s.name)
def discover_skills_from_roots(skills_dirs: Iterable[Path]) -> list["Skill"]:
    """Discover skills from multiple directory roots.

    Skills from later directories will override skills with the same
    (normalized) name from earlier directories.

    Args:
        skills_dirs: Iterable of directories to scan.

    Returns:
        List of unique Skill objects, sorted by name.

    Example:
        >>> roots = [Path("./builtin"), Path("~/.config/skills").expanduser()]
        >>> skills = discover_skills_from_roots(roots)
    """
    from agentlite.skills.models import normalize_skill_name

    merged: dict[str, "Skill"] = {}
    for root in skills_dirs:
        # dict.update keeps the last occurrence, so later roots win.
        merged.update(
            (normalize_skill_name(skill.name), skill) for skill in discover_skills(root)
        )
    return sorted(merged.values(), key=lambda s: s.name)
def get_default_skills_dirs(work_dir: Path | None = None) -> list[Path]:
    """Get the default skill directory search paths.

    Returns directories in priority order:

    1. User-level: ``~/.config/agents/skills/`` (or alternatives)
    2. Project-level: ``./.agents/skills/`` (or alternatives)

    Only the first existing candidate from each group is used.

    Args:
        work_dir: Working directory for project-level search (default: current dir).

    Returns:
        List of existing skill directories.
    """
    base = Path.cwd() if work_dir is None else work_dir
    home = Path.home()

    user_candidates = (
        home / ".config" / "agents" / "skills",
        home / ".agents" / "skills",
        home / ".kimi" / "skills",
    )
    project_candidates = (
        base / ".agents" / "skills",
        base / ".kimi" / "skills",
    )

    found: list[Path] = []
    for group in (user_candidates, project_candidates):
        # Keep only the first existing directory from each candidate group.
        existing = next((c for c in group if c.is_dir()), None)
        if existing is not None:
            found.append(existing)
    return found
def index_skills_by_name(skills: Iterable["Skill"]) -> dict[str, "Skill"]:
    """Build a lookup table for skills by normalized name.

    Args:
        skills: Iterable of Skill objects.

    Returns:
        Dictionary mapping normalized (casefolded) names to Skill objects.
    """
    from agentlite.skills.models import normalize_skill_name

    index: dict[str, "Skill"] = {}
    for skill in skills:
        index[normalize_skill_name(skill.name)] = skill
    return index

View File

@@ -0,0 +1,252 @@
"""Flowchart parsers for flow-type skills.
This module provides parsers for Mermaid and D2 flowchart syntax
to convert them into Flow objects that can be executed.
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from agentlite.skills.models import Flow, FlowEdge, FlowNode
class FlowParseError(ValueError):
    """Raised when a Mermaid or D2 flowchart cannot be parsed."""
def parse_mermaid_flowchart(content: str) -> "Flow":
    """Parse a Mermaid flowchart into a Flow object.

    Supports basic Mermaid flowchart syntax:

    - Node definitions: ``id[label]``, ``id(label)``, ``id{label}``, ``id((label))``
    - Edges: ``-->``, ``---``, ``-.->``, optionally labeled ``-->|label|``
    - Inline chains such as ``A[Do] --> B{Ok?}`` define nodes and edges together
    - Special nodes: ``BEGIN(( ))``, ``END(( ))``

    Args:
        content: Mermaid flowchart definition.

    Returns:
        Flow object representing the flowchart.

    Raises:
        FlowParseError: If parsing fails (e.g. not exactly one BEGIN/END node).

    Example:
        >>> mermaid = '''
        ... flowchart TD
        ...     BEGIN(( )) --> CHECK[Check input]
        ...     CHECK --> VALID{Is valid?}
        ...     VALID -->|Yes| PROCESS[Process]
        ...     VALID -->|No| ERROR[Show error]
        ...     PROCESS --> END(( ))
        ...     ERROR --> END
        ... '''
        >>> flow = parse_mermaid_flowchart(mermaid)
    """
    from agentlite.skills.models import Flow, FlowEdge, FlowNode

    # Endpoint: identifier plus an optional shape decoration.  ``((...))``
    # must be tried before ``(...)`` so circles don't parse as round nodes.
    endpoint_re = re.compile(
        r"\s*(\w+)\s*(?:\(\((.*?)\)\)|\[(.*?)\]|\((.*?)\)|\{(.*?)\})?"
    )
    # Arrow between endpoints with an optional |label|.  The previous
    # implementation required the |label| part, silently dropping every
    # unlabeled edge.
    arrow_re = re.compile(r"\s*(?:-->|---|-\.->)\s*(?:\|([^|]*)\|)?")

    nodes: dict[str, FlowNode] = {}
    edges: list[FlowEdge] = []

    def classify(node_id: str, label: str, *, circle: bool, diamond: bool) -> str:
        # Terminal ids win over shape so ``END(( ))`` is an end node.
        upper = node_id.upper()
        if upper in ("END", "STOP", "FINISH"):
            return "end"
        if upper in ("BEGIN", "START"):
            return "begin"
        if circle and not label.strip():
            # An unlabeled circle is a start marker by convention.
            return "begin"
        if diamond:
            return "decision"
        return "task"

    def register(match: "re.Match[str]") -> str:
        node_id = match.group(1)
        shape_labels = match.groups()[1:]
        label = next((g for g in shape_labels if g is not None), None)
        if label is None:
            # Bare reference (e.g. the ``ERROR`` in ``ERROR --> END``);
            # never overwrite an earlier full definition.
            if node_id not in nodes:
                kind = classify(node_id, "", circle=False, diamond=False)
                nodes[node_id] = FlowNode(id=node_id, label=node_id, kind=kind)
            return node_id
        kind = classify(
            node_id,
            label,
            circle=match.group(2) is not None,
            diamond=match.group(5) is not None,
        )
        nodes[node_id] = FlowNode(id=node_id, label=label.strip() or node_id, kind=kind)
        return node_id

    for raw_line in content.strip().split("\n"):
        line = raw_line.strip().rstrip(";")
        # Skip blanks, the chart header, and Mermaid comments.
        if not line or line.startswith(("flowchart", "graph", "%%")):
            continue
        match = endpoint_re.match(line)
        if not match:
            continue
        src = register(match)
        pos = match.end()
        # Walk arrow/endpoint pairs; a chain defines several edges per line.
        while pos < len(line):
            arrow = arrow_re.match(line, pos)
            if not arrow:
                break
            nxt = endpoint_re.match(line, arrow.end())
            if not nxt:
                break
            dst = register(nxt)
            raw_label = arrow.group(1)
            edges.append(
                FlowEdge(src=src, dst=dst, label=raw_label.strip() if raw_label else None)
            )
            src = dst
            pos = nxt.end()

    # Index outgoing edges by source node.
    outgoing: dict[str, list[FlowEdge]] = {}
    for edge in edges:
        outgoing.setdefault(edge.src, []).append(edge)

    begin_ids = [n.id for n in nodes.values() if n.kind == "begin"]
    end_ids = [n.id for n in nodes.values() if n.kind == "end"]
    if not begin_ids:
        # Use first declared node if no explicit begin.
        begin_ids = [next(iter(nodes))] if nodes else []
    if not end_ids:
        # Use last declared node if no explicit end.
        end_ids = [list(nodes)[-1]] if nodes else []
    if len(begin_ids) != 1:
        raise FlowParseError(f"Expected exactly one BEGIN node, found {len(begin_ids)}")
    if len(end_ids) != 1:
        raise FlowParseError(f"Expected exactly one END node, found {len(end_ids)}")
    return Flow(nodes=nodes, outgoing=outgoing, begin_id=begin_ids[0], end_id=end_ids[0])
def parse_d2_flowchart(content: str) -> "Flow":
    """Parse a D2 flowchart into a Flow object.

    Supports basic D2 syntax:

    - Node definitions: ``id: label``
    - Edges: ``id1 -> id2`` or ``id1 -> id2: label``
    - Special shapes: ``id: {shape: circle}``, ``id: label {shape: diamond}``

    Node kinds: ids END/STOP/FINISH are end nodes; ids BEGIN/START (and
    unlabeled circles) are begin nodes; diamond-shaped nodes are decisions;
    everything else is a task.

    Args:
        content: D2 flowchart definition.

    Returns:
        Flow object representing the flowchart.

    Raises:
        FlowParseError: If parsing fails (e.g. not exactly one BEGIN/END node).

    Example:
        >>> d2 = '''
        ... BEGIN: {shape: circle}
        ... CHECK: Check input
        ... VALID: Is valid? {shape: diamond}
        ... PROCESS: Process
        ... ERROR: Show error
        ... END: {shape: circle}
        ...
        ... BEGIN -> CHECK
        ... CHECK -> VALID
        ... VALID -> PROCESS: Yes
        ... VALID -> ERROR: No
        ... PROCESS -> END
        ... ERROR -> END
        ... '''
        >>> flow = parse_d2_flowchart(d2)
    """
    from agentlite.skills.models import Flow, FlowEdge, FlowNode

    nodes: dict[str, FlowNode] = {}
    edges: list[FlowEdge] = []

    # Node pattern: id: label or id: {shape: ...}
    node_pattern = re.compile(r"^(\w+)\s*:\s*(.+)$")
    # Edge pattern: src -> dst or src -> dst: label
    edge_pattern = re.compile(r"^(\w+)\s*->\s*(\w+)(?:\s*:\s*(.+))?$")

    def classify(node_id: str, label: str, shape: "str | None") -> str:
        # Terminal ids are checked FIRST so ``END: {shape: circle}`` is an
        # end node; the previous ordering classified it as a second begin
        # node and made the docstring example unparseable.
        upper = node_id.upper()
        if upper in ("END", "STOP", "FINISH"):
            return "end"
        if upper in ("BEGIN", "START"):
            return "begin"
        if shape == "circle":
            # An unlabeled circle is a start marker by convention.
            return "begin" if label == node_id else "task"
        if shape == "diamond":
            return "decision"
        return "task"

    for line in content.strip().split("\n"):
        line = line.strip()
        # Skip blanks and D2 comments.
        if not line or line.startswith("#"):
            continue

        # Try edge first.
        edge_match = edge_pattern.match(line)
        if edge_match:
            src, dst, label = edge_match.groups()
            edges.append(
                FlowEdge(src=src.strip(), dst=dst.strip(), label=label.strip() if label else None)
            )
            continue

        # Try node definition.
        node_match = node_pattern.match(line)
        if node_match:
            node_id, rest = node_match.groups()
            rest = rest.strip()

            # Optional shape attribute, e.g. ``{shape: diamond}``.
            shape_match = re.search(r"\{shape:\s*(\w+)\}", rest)
            shape = shape_match.group(1) if shape_match else None

            # The label is whatever remains once attribute blocks are removed.
            label = re.sub(r"\{[^}]*\}", "", rest).strip()
            if not label:
                label = node_id

            nodes[node_id] = FlowNode(
                id=node_id, label=label, kind=classify(node_id, label, shape)
            )

    # Build outgoing edge map.
    outgoing: dict[str, list[FlowEdge]] = {}
    for edge in edges:
        outgoing.setdefault(edge.src, []).append(edge)

    # Find begin and end nodes, falling back to first/last declared node.
    begin_ids = [n.id for n in nodes.values() if n.kind == "begin"]
    end_ids = [n.id for n in nodes.values() if n.kind == "end"]
    if not begin_ids:
        begin_ids = [next(iter(nodes))] if nodes else []
    if not end_ids:
        end_ids = [list(nodes)[-1]] if nodes else []
    if len(begin_ids) != 1:
        raise FlowParseError(f"Expected exactly one BEGIN node, found {len(begin_ids)}")
    if len(end_ids) != 1:
        raise FlowParseError(f"Expected exactly one END node, found {len(end_ids)}")
    return Flow(nodes=nodes, outgoing=outgoing, begin_id=begin_ids[0], end_id=end_ids[0])

View File

@@ -0,0 +1,200 @@
"""Flow runner for executing flow-type skills.
This module provides FlowRunner for executing flowchart-based skills
node by node, similar to kimi-cli's implementation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from agentlite.agent import Agent
from agentlite.skills.models import Flow, FlowEdge, FlowNode
class FlowExecutionError(Exception):
    """Raised when flow execution fails."""


class FlowRunner:
    """Executes flowchart-based skills.

    FlowRunner executes a flowchart node by node, handling task nodes
    and decision nodes appropriately.

    For task nodes: Executes the node's label as a prompt
    For decision nodes: Presents options and waits for user/agent choice

    Example:
        >>> from agentlite.skills.models import Flow, FlowNode, FlowEdge
        >>> # Define a simple flow
        >>> flow = Flow(
        ...     nodes={
        ...         "start": FlowNode(id="start", label="Start", kind="begin"),
        ...         "task": FlowNode(id="task", label="Analyze code", kind="task"),
        ...         "end": FlowNode(id="end", label="End", kind="end"),
        ...     },
        ...     outgoing={
        ...         "start": [FlowEdge(src="start", dst="task")],
        ...         "task": [FlowEdge(src="task", dst="end")],
        ...     },
        ...     begin_id="start",
        ...     end_id="end",
        ... )
        >>> runner = FlowRunner(flow, "my-flow")
        >>> output = await runner.run(agent, "Additional context")
    """

    def __init__(self, flow: "Flow", name: str = "flow"):
        """Initialize the flow runner.

        Args:
            flow: The flowchart to execute
            name: Name of the flow (for logging/debugging)
        """
        self._flow = flow
        self._name = name

    async def run(self, agent: "Agent", args: str = "") -> str:
        """Execute the flow.

        Args:
            agent: The agent to use for executing task nodes
            args: Additional arguments/context for the flow

        Returns:
            The combined output from all executed nodes

        Raises:
            FlowExecutionError: If execution fails or the step limit is hit
        """
        current_id = self._flow.begin_id
        outputs: list[str] = []
        max_steps = 100  # Prevent infinite loops on cyclic charts

        # Return directly when the END node is reached instead of breaking
        # and re-checking the counter: the old break-then-check pattern
        # raised a spurious "exceeded maximum steps" error when the flow
        # completed exactly on the final allowed step.
        for _ in range(max_steps):
            node = self._flow.nodes.get(current_id)
            if node is None:
                raise FlowExecutionError(f"Node '{current_id}' not found in flow")

            edges = self._flow.outgoing.get(current_id, [])

            if node.kind == "end":
                # Flow complete.
                return "\n\n".join(outputs)
            elif node.kind == "begin":
                # Just move to next node.
                if not edges:
                    raise FlowExecutionError("BEGIN node has no outgoing edges")
                current_id = edges[0].dst
            elif node.kind == "task":
                output = await self._execute_task_node(agent, node, args)
                if output:
                    outputs.append(output)
                if not edges:
                    raise FlowExecutionError(f"Task node '{current_id}' has no outgoing edges")
                current_id = edges[0].dst
            elif node.kind == "decision":
                choice = await self._execute_decision_node(agent, node, edges, args)
                # Follow the edge whose label matches the chosen option.
                next_id = None
                for edge in edges:
                    if edge.label and edge.label.lower() == choice.lower():
                        next_id = edge.dst
                        break
                if next_id is None:
                    raise FlowExecutionError(
                        f"Invalid choice '{choice}' for decision node '{current_id}'"
                    )
                current_id = next_id
            else:
                raise FlowExecutionError(f"Unknown node kind: {node.kind}")

        raise FlowExecutionError("Flow exceeded maximum steps (possible infinite loop)")

    async def _execute_task_node(self, agent: "Agent", node: "FlowNode", args: str) -> str:
        """Execute a task node.

        Args:
            agent: The agent to use
            node: The task node
            args: Additional arguments

        Returns:
            The task output
        """
        # The node label is the prompt; extra context is appended when given.
        prompt = node.label
        if args.strip():
            prompt = f"{prompt}\n\nContext: {args.strip()}"

        response = await agent.run(prompt)
        return response

    async def _execute_decision_node(
        self, agent: "Agent", node: "FlowNode", edges: list["FlowEdge"], args: str
    ) -> str:
        """Execute a decision node.

        Args:
            agent: The agent to use
            node: The decision node
            edges: Available outgoing edges (choices)
            args: Additional arguments

        Returns:
            The chosen option
        """
        # Build a prompt listing the labeled outgoing edges as choices.
        choices = [edge.label for edge in edges if edge.label]
        prompt_lines = [
            node.label,
            "",
            "Available options:",
            *[f"- {choice}" for choice in choices],
            "",
            "Reply with one of the options above.",
        ]
        if args.strip():
            prompt_lines.extend(["", f"Context: {args.strip()}"])
        prompt = "\n".join(prompt_lines)

        response = await agent.run(prompt)

        # Substring match in either direction tolerates verbose replies.
        response_clean = response.strip().lower()
        for choice in choices:
            if choice.lower() in response_clean or response_clean in choice.lower():
                return choice

        # If nothing matches, fall back to the first choice as a default.
        return choices[0] if choices else ""

View File

@@ -0,0 +1,154 @@
"""Skill system for AgentLite.
This module provides a skill system similar to kimi-cli, allowing agents
to use modular, reusable skills defined in SKILL.md files.
Skills can be:
- Standard: Text-based instructions loaded as prompts
- Flow: Structured flowcharts (Mermaid/D2) for deterministic execution
Example:
>>> from agentlite.skills import Skill, discover_skills
>>> skills = discover_skills(Path("./skills"))
>>> for skill in skills:
... print(f"{skill.name}: {skill.description}")
"""
from __future__ import annotations
from collections.abc import Iterable
from pathlib import Path
from typing import Literal, Optional
from pydantic import BaseModel, Field
# Discriminator for how a skill is executed: plain prompt text vs. flowchart.
SkillType = Literal["standard", "flow"]
# Role of a node within a flowchart.
FlowNodeKind = Literal["begin", "end", "task", "decision"]
class FlowNode(BaseModel):
    """A node in a flowchart.

    Attributes:
        id: Unique identifier for the node.
        label: Display text or content for the node; task-node labels are
            used as prompts when the flow is executed.
        kind: Type of node ("begin", "end", "task", or "decision").
    """

    # "id" shadows the builtin only inside the class namespace; it is the
    # natural field name for a graph node identifier.
    id: str = Field(description="Unique node identifier")
    label: str = Field(description="Node display text")
    kind: FlowNodeKind = Field(description="Node type")
class FlowEdge(BaseModel):
    """An edge connecting two nodes in a flowchart.

    Attributes:
        src: Source node ID.
        dst: Destination node ID.
        label: Optional label for the edge; decision nodes select their
            outgoing edge by matching this label against the chosen option.
    """

    src: str = Field(description="Source node ID")
    dst: str = Field(description="Destination node ID")
    label: Optional[str] = Field(default=None, description="Edge label for decisions")
class Flow(BaseModel):
    """A flowchart defining a structured workflow.

    Flow skills use flowcharts to define deterministic, step-by-step
    workflows that the agent executes node by node.

    Attributes:
        nodes: Dictionary mapping node IDs to FlowNode objects.
        outgoing: Dictionary mapping node IDs to their outgoing edges;
            nodes with no outgoing edges may be absent from this mapping.
        begin_id: ID of the start node (execution entry point).
        end_id: ID of the end node (execution stops here).
    """

    nodes: dict[str, FlowNode] = Field(description="Node ID to node mapping")
    outgoing: dict[str, list[FlowEdge]] = Field(description="Node outgoing edges")
    begin_id: str = Field(description="Start node ID")
    end_id: str = Field(description="End node ID")
class Skill(BaseModel):
    """A skill definition for AgentLite.

    Skills are modular, reusable capabilities defined in SKILL.md files.
    They can be standard (text-based) or flow-based (structured workflows).

    Attributes:
        name: Unique skill name.
        description: When and what the skill does (used for triggering).
        type: Skill type - "standard" or "flow".
        dir: Directory containing the skill files.
        flow: Flow definition (only for flow-type skills).

    Example SKILL.md::

        ---
        name: code-reviewer
        description: Review code for bugs, style issues, and best practices
        type: standard
        ---
        # Code Reviewer
        When reviewing code:
        1. Check for syntax errors
        2. Verify style guidelines
        3. Suggest improvements
    """

    name: str = Field(description="Unique skill name")
    description: str = Field(description="Skill description and triggering criteria")
    type: SkillType = Field(default="standard", description="Skill type")
    dir: Path = Field(description="Skill directory path")
    flow: Optional[Flow] = Field(default=None, description="Flow definition for flow-type skills")

    @property
    def skill_md_file(self) -> Path:
        """Path to the SKILL.md manifest inside the skill directory."""
        return self.dir / "SKILL.md"

    def read_content(self) -> str:
        """Read the full SKILL.md content.

        Returns:
            The content of the SKILL.md file, stripped of leading and
            trailing whitespace.

        Raises:
            FileNotFoundError: If SKILL.md doesn't exist.
        """
        return self.skill_md_file.read_text(encoding="utf-8").strip()
def normalize_skill_name(name: str) -> str:
    """Normalize a skill name for case-insensitive lookup.

    Args:
        name: The skill name to normalize.

    Returns:
        A casefolded version of the name, suitable as a lookup key.
    """
    # str.casefold() performs more aggressive Unicode folding than lower().
    return name.casefold()


def index_skills(skills: Iterable[Skill]) -> dict[str, Skill]:
    """Build a lookup table for skills by normalized name.

    Later skills with the same normalized name replace earlier ones.

    Args:
        skills: Iterable of Skill objects.

    Returns:
        Dictionary mapping normalized names to Skill objects.

    Example:
        >>> skills = [Skill(name="CodeReview", ...), Skill(name="TestWriter", ...)]
        >>> index = index_skills(skills)
        >>> index["codereview"].name
        "CodeReview"
    """
    index: dict[str, Skill] = {}
    for skill in skills:
        index[normalize_skill_name(skill.name)] = skill
    return index

View File

@@ -0,0 +1,177 @@
"""Skill tool for AgentLite.
This module provides a tool for executing skills within an agent.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
if TYPE_CHECKING:
from agentlite.agent import Agent
from agentlite.skills.models import Skill
class SkillParams(BaseModel):
    """Parameters for executing a skill via the Skill tool."""

    # Matched case-insensitively against the skill registry keys.
    skill_name: str = Field(description="Name of the skill to execute")
    args: str = Field(default="", description="Additional arguments or context for the skill")
class SkillTool(CallableTool2[SkillParams]):
    """Tool for executing skills.

    This tool allows an agent to execute skills from its skill registry.
    Skills can be standard (text-based) or flow-based (structured workflows).

    Example:
        >>> from agentlite.skills.discovery import discover_skills
        >>> from agentlite.skills.models import index_skills
        >>> # Discover and index skills
        >>> skills = discover_skills(Path("./skills"))
        >>> skill_index = index_skills(skills)
        >>> # Create skill tool
        >>> skill_tool = SkillTool(skill_index, parent_agent=agent)
        >>> # Execute a skill
        >>> result = await skill_tool(
        ...     {"skill_name": "code-review", "args": "Review this Python function..."}
        ... )
    """

    name: str = "Skill"
    description: str = (
        "Execute a predefined skill. "
        "Skills provide specialized workflows and domain knowledge. "
        "Available skills are shown in the system context."
    )
    params: type[SkillParams] = SkillParams

    def __init__(
        self,
        skills: dict[str, "Skill"],
        parent_agent: "Agent" | None = None,
    ):
        """Initialize the skill tool.

        Args:
            skills: Dictionary mapping normalized skill names to Skill objects
                (keys are expected to be casefolded; see ``normalize_skill_name``).
            parent_agent: The parent agent (used for executing skills). When
                None, standard skills degrade to returning their content and
                flow skills refuse to run.
        """
        super().__init__()
        # NOTE(review): assigning plain attributes assumes CallableTool2
        # permits non-field instance attributes — confirm in agentlite.tool.
        self._skills = skills
        self._parent_agent = parent_agent

    async def __call__(self, params: SkillParams) -> ToolResult:
        """Execute a skill.

        Args:
            params: Skill execution parameters.

        Returns:
            ToolResult with the skill output or error.
        """
        from agentlite.skills.models import normalize_skill_name

        if not params.skill_name:
            return ToolError(message="Skill name cannot be empty")

        # Lookup is case-insensitive: registry keys are normalized names.
        normalized_name = normalize_skill_name(params.skill_name)
        skill = self._skills.get(normalized_name)
        if skill is None:
            available = ", ".join(sorted(self._skills.keys()))
            return ToolError(
                message=f"Skill '{params.skill_name}' not found. Available: {available or 'none'}"
            )

        try:
            # Execute based on skill type; a flow skill without a parsed
            # flow falls back to standard (text) execution.
            if skill.type == "flow" and skill.flow is not None:
                return await self._execute_flow_skill(skill, params.args)
            else:
                return await self._execute_standard_skill(skill, params.args)
        except Exception as e:
            return ToolError(message=f"Skill execution failed: {e}")

    async def _execute_standard_skill(self, skill: "Skill", args: str) -> ToolResult:
        """Execute a standard (text-based) skill.

        Loads the SKILL.md content and uses it as a prompt for the agent.

        Args:
            skill: The skill to execute.
            args: Additional arguments from the user.

        Returns:
            ToolResult with the skill output.
        """
        # Read skill content
        content = skill.read_content()

        # Parse frontmatter to get just the body
        from agentlite.skills.discovery import parse_frontmatter

        frontmatter = parse_frontmatter(content)

        # Extract body (remove frontmatter if present): skip past the
        # closing "\n---" delimiter, which is 4 characters long.
        if frontmatter and content.startswith("---"):
            end_idx = content.find("\n---", 3)
            if end_idx != -1:
                body = content[end_idx + 4 :].strip()
            else:
                body = content
        else:
            body = content

        # Append user arguments if provided
        if args.strip():
            body = f"{body}\n\nUser request: {args.strip()}"

        # Execute using parent agent if available
        if self._parent_agent is not None:
            # NOTE(review): this re-enters the parent agent's run loop from
            # inside a tool call; presumably recursion is bounded upstream
            # (e.g. by max iterations) — verify.
            response = await self._parent_agent.run(body)
            return ToolOk(output=response, message=f"Skill '{skill.name}' executed successfully")
        else:
            # Return the skill content for the LLM to use
            return ToolOk(
                output=body, message=f"Skill '{skill.name}' loaded (no parent agent to execute)"
            )

    async def _execute_flow_skill(self, skill: "Skill", args: str) -> ToolResult:
        """Execute a flow-based skill.

        Executes the flowchart node by node via FlowRunner.

        Args:
            skill: The flow skill to execute.
            args: Additional arguments from the user.

        Returns:
            ToolResult with the flow output.
        """
        from agentlite.skills.flow_runner import FlowRunner

        if skill.flow is None:
            return ToolError(message=f"Flow skill '{skill.name}' has no flow definition")
        if self._parent_agent is None:
            return ToolError(message="Flow skills require a parent agent to execute")

        # Create flow runner and execute
        runner = FlowRunner(skill.flow, skill.name)
        try:
            output = await runner.run(self._parent_agent, args)
            return ToolOk(
                output=output, message=f"Flow skill '{skill.name}' completed successfully"
            )
        except Exception as e:
            return ToolError(message=f"Flow execution failed: {e}")

View File

@@ -0,0 +1,111 @@
"""Subagent configuration models for AgentLite.
This module provides configuration models for defining subagents
in a hierarchical agent architecture.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field, model_validator
class SubagentConfig(BaseModel):
    """Configuration for a subagent.

    Subagents are child agents that can be called by a parent agent
    using the Task tool. Each subagent has its own system prompt
    and can optionally have its own tools.

    Attributes:
        name: Unique name for the subagent.
        description: Description of what the subagent does.
        system_prompt: System prompt for the subagent.
        system_prompt_path: Path to a file containing the system prompt.
        tools: List of tool paths to load (inherits from parent if not specified).
        exclude_tools: Tools to exclude from parent inheritance.
        subagents: Nested subagents (for hierarchical structure).
        max_iterations: Maximum tool call iterations for this subagent.

    Note:
        At least one of ``system_prompt`` / ``system_prompt_path`` must be
        set (enforced by the model validator); an inline prompt takes
        precedence when both are given.

    Example:
        >>> config = SubagentConfig(
        ...     name="coder",
        ...     description="Good at writing code",
        ...     system_prompt="You are a coding assistant.",
        ...     exclude_tools=["Task", "CreateSubagent"],
        ... )
    """

    name: str = Field(description="Unique name for the subagent")
    description: str = Field(description="Description of what the subagent does")
    system_prompt: Optional[str] = Field(default=None, description="System prompt for the subagent")
    system_prompt_path: Optional[Path] = Field(
        default=None, description="Path to a file containing the system prompt"
    )
    tools: Optional[list[str]] = Field(
        default=None,
        description="List of tool import paths (e.g., 'agentlite.tools.file:ReadFile')",
    )
    exclude_tools: list[str] = Field(
        default_factory=list, description="Tool names to exclude from parent inheritance"
    )
    # Self-referencing field type; resolved lazily thanks to
    # ``from __future__ import annotations`` at the top of the module.
    subagents: list[SubagentConfig] = Field(
        default_factory=list, description="Nested subagents (hierarchical structure)"
    )
    max_iterations: int = Field(
        default=80, description="Maximum tool call iterations", ge=1, le=100
    )

    @model_validator(mode="after")
    def validate_system_prompt(self) -> SubagentConfig:
        """Validate that either system_prompt or system_prompt_path is provided."""
        if self.system_prompt is None and self.system_prompt_path is None:
            raise ValueError("Either system_prompt or system_prompt_path must be provided")
        return self

    def get_system_prompt(self) -> str:
        """Get the system prompt text.

        An inline ``system_prompt`` takes precedence over
        ``system_prompt_path``.

        Returns:
            The system prompt string.

        Raises:
            FileNotFoundError: If system_prompt_path is specified but file doesn't exist.
            ValueError: If neither prompt source is set (normally prevented
                by the model validator).
        """
        if self.system_prompt is not None:
            return self.system_prompt
        if self.system_prompt_path is not None:
            # Strip so trailing newlines in prompt files don't leak into chats.
            return Path(self.system_prompt_path).read_text(encoding="utf-8").strip()
        raise ValueError("No system prompt available")
class SubagentSpec(BaseModel):
    """Specification for loading a subagent from a file.

    This is used when subagents are defined in separate YAML files,
    similar to kimi-cli's approach.

    Attributes:
        path: Path to the subagent configuration file.
        description: Description of the subagent.
    """

    path: Path = Field(description="Path to subagent config file")
    description: str = Field(description="Description of the subagent")

    def load(self) -> SubagentConfig:
        """Load the subagent configuration from the file.

        Returns:
            The loaded SubagentConfig.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            pydantic.ValidationError: If the YAML content does not satisfy
                SubagentConfig's schema.
        """
        import yaml

        # safe_load is used deliberately: spec files are data, not code.
        with open(self.path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return SubagentConfig(**data)

View File

@@ -0,0 +1,532 @@
"""Tool system for AgentLite.
This module provides the tool abstraction layer for defining and executing
tools that can be called by LLM agents.
"""
from __future__ import annotations
import asyncio
import inspect
import json
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Protocol,
TypeVar,
Union,
Generic,
get_type_hints,
)
import jsonschema
from pydantic import BaseModel, ValidationError
from agentlite.message import ToolCall
if TYPE_CHECKING:
pass
class ToolResult(BaseModel):
    """The result of a tool execution.

    Prefer constructing the ToolOk / ToolError subclasses, which fill in
    ``is_error`` and default the ``message`` sensibly.

    Attributes:
        output: The output of the tool (string or structured data).
        is_error: Whether the tool execution resulted in an error.
        message: A message describing the result (for model consumption).

    Example:
        >>> result = ToolOk(output="42")
        >>> print(result.output)
        42
    """

    output: str
    """The output of the tool execution."""
    is_error: bool = False
    """Whether the execution resulted in an error."""
    message: str = ""
    """A message describing the result (for model consumption)."""
class ToolOk(ToolResult):
    """Successful tool execution result.

    Example:
        >>> return ToolOk(output="File created successfully")
    """

    def __init__(self, output: str, message: str = ""):
        # An empty message falls back to the output so the model always
        # sees a non-empty summary.
        summary = message if message else output
        super().__init__(output=output, is_error=False, message=summary)
class ToolError(ToolResult):
    """Failed tool execution result.

    Example:
        >>> return ToolError(message="File not found")
    """

    def __init__(self, message: str, output: str = ""):
        # When no explicit output is given, reuse the error message as output.
        body = output if output else message
        super().__init__(output=body, is_error=True, message=message)
class Tool(BaseModel):
    """Definition of a tool that can be called by the model.

    Attributes:
        name: The name of the tool.
        description: A description of what the tool does.
        parameters: JSON Schema describing the tool's arguments.

    Example:
        >>> tool = Tool(
        ...     name="add",
        ...     description="Add two numbers",
        ...     parameters={
        ...         "type": "object",
        ...         "properties": {
        ...             "a": {"type": "number"},
        ...             "b": {"type": "number"},
        ...         },
        ...         "required": ["a", "b"],
        ...     },
        ... )
    """

    name: str  # tool identifier exposed to the model
    description: str  # human/model readable summary
    parameters: dict[str, Any]  # JSON Schema for the arguments

    def __init__(self, **data: Any):
        super().__init__(**data)
        # Reject malformed schemas early: the declared parameter schema must
        # itself validate against the Draft 2020-12 meta-schema.
        try:
            jsonschema.validate(
                self.parameters, jsonschema.Draft202012Validator.META_SCHEMA
            )
        except jsonschema.ValidationError as e:
            raise ValueError(f"Invalid JSON schema for tool {self.name}: {e}") from e

    @property
    def base(self) -> "Tool":
        """Get the base Tool definition (returns self for Tool instances)."""
        return self
class CallableTool(Tool, ABC):
    """Abstract base class for callable tools.

    Subclasses implement ``__call__`` to define the tool's behavior; the
    :meth:`call` wrapper validates arguments against the JSON schema first
    and converts failures into :class:`ToolError` results.

    Example:
        >>> class AddTool(CallableTool):
        ...     name = "add"
        ...     description = "Add two numbers"
        ...     parameters = {
        ...         "type": "object",
        ...         "properties": {
        ...             "a": {"type": "number"},
        ...             "b": {"type": "number"},
        ...         },
        ...         "required": ["a", "b"],
        ...     }
        ...
        ...     async def __call__(self, a: float, b: float) -> ToolResult:
        ...         return ToolOk(output=str(a + b))
    """

    @abstractmethod
    async def __call__(self, *args: Any, **kwargs: Any) -> ToolResult:
        """Execute the tool.

        Args:
            *args: Positional arguments.
            **kwargs: Keyword arguments.

        Returns:
            The result of the tool execution.
        """
        ...

    @property
    def base(self) -> "Tool":
        """Get the plain Tool definition for this callable tool."""
        return Tool(
            name=self.name,
            description=self.description,
            parameters=self.parameters,
        )

    async def call(self, arguments: dict[str, Any]) -> ToolResult:
        """Validate ``arguments`` against the schema, then execute the tool.

        Args:
            arguments: The arguments to pass to the tool.

        Returns:
            The result of the tool execution; schema violations and raised
            exceptions are reported as ToolError rather than propagated.
        """
        try:
            jsonschema.validate(arguments, self.parameters)
        except jsonschema.ValidationError as e:
            return ToolError(message=f"Invalid arguments: {e}")

        try:
            # Dispatch by keyword for mappings, positionally for lists, and
            # as a single argument otherwise.
            if isinstance(arguments, dict):
                outcome = await self.__call__(**arguments)
            elif isinstance(arguments, list):
                outcome = await self.__call__(*arguments)
            else:
                outcome = await self.__call__(arguments)
        except Exception as e:
            return ToolError(message=f"Tool execution failed: {e}")
        if not isinstance(outcome, ToolResult):
            return ToolError(message=f"Tool returned invalid type: {type(outcome)}")
        return outcome
Params = TypeVar("Params", bound=BaseModel)
class CallableTool2(ABC, Generic[Params]):
    """Type-safe callable tool using a Pydantic model for parameters.

    This is the preferred way to define tools: the parameter model gives
    full type safety and the JSON schema is generated automatically.

    Example:
        >>> class AddParams(BaseModel):
        ...     a: float
        ...     b: float
        >>> class AddTool(CallableTool2[AddParams]):
        ...     name = "add"
        ...     description = "Add two numbers"
        ...     params = AddParams
        ...
        ...     async def __call__(self, params: AddParams) -> ToolResult:
        ...         return ToolOk(output=str(params.a + params.b))
    """

    name: str  # tool identifier exposed to the model
    description: str  # what the tool does
    params: type[Params]  # Pydantic model class describing the arguments

    def __init__(
        self,
        name: str | None = None,
        description: str | None = None,
        params: type[Params] | None = None,
    ):
        # Constructor arguments take precedence over class-level attributes;
        # one of the two sources must supply each piece.
        cls = type(self)
        self.name = name or getattr(cls, "name", "")
        if not self.name:
            raise ValueError("Tool name must be provided")
        self.description = description or getattr(cls, "description", "")
        if not self.description:
            raise ValueError("Tool description must be provided")
        self.params = params if params is not None else getattr(cls, "params", None)
        if self.params is None:
            raise ValueError("Tool params must be provided")
        # Derive the JSON schema once from the Pydantic model.
        self._schema = self.params.model_json_schema()

    @property
    def base(self) -> Tool:
        """Get the base Tool definition."""
        return Tool(
            name=self.name,
            description=self.description,
            parameters=self._schema,
        )

    @abstractmethod
    async def __call__(self, params: Params) -> ToolResult:
        """Execute the tool.

        Args:
            params: The validated parameters.

        Returns:
            The result of the tool execution.
        """
        ...

    async def call(self, arguments: dict[str, Any]) -> ToolResult:
        """Validate ``arguments`` with the Pydantic model, then execute.

        Args:
            arguments: Raw argument mapping to validate.

        Returns:
            The result of the tool execution; validation failures and
            raised exceptions are reported as ToolError.
        """
        try:
            parsed = self.params.model_validate(arguments)
        except ValidationError as e:
            return ToolError(message=f"Invalid arguments: {e}")
        try:
            outcome = await self.__call__(parsed)
        except Exception as e:
            return ToolError(message=f"Tool execution failed: {e}")
        if not isinstance(outcome, ToolResult):
            return ToolError(message=f"Tool returned invalid type: {type(outcome)}")
        return outcome
class Toolset(Protocol):
    """Protocol for tool collections.

    A Toolset manages a collection of tools and handles tool calls.
    Implementations may resolve a call synchronously (returning a
    ToolResult) or asynchronously (returning a future).
    """

    @property
    def tools(self) -> list[Tool]:
        """Get all tool definitions."""
        ...

    def handle(self, tool_call: ToolCall) -> "ToolResult | asyncio.Future[ToolResult]":
        """Handle a tool call.

        Args:
            tool_call: The tool call to handle.

        Returns:
            The tool result or a future that resolves to the result.
        """
        ...
ToolType = Union[CallableTool, CallableTool2[Any]]
class SimpleToolset:
    """A simple in-memory toolset.

    The default :class:`Toolset` implementation: tools are stored in a
    dict keyed by name and executed as asyncio tasks.

    Example:
        >>> toolset = SimpleToolset()
        >>> toolset.add(MyTool())
        >>> result = await toolset.handle(tool_call)
    """

    def __init__(self, tools: Iterable[ToolType] | None = None):
        """Initialize the toolset.

        Args:
            tools: Optional initial tools to add.
        """
        self._tools: dict[str, ToolType] = {}
        if tools:
            for tool in tools:
                self.add(tool)

    def add(self, tool: ToolType) -> "SimpleToolset":
        """Add a tool to the toolset.

        Args:
            tool: The tool to add.

        Returns:
            Self for chaining.

        Raises:
            ValueError: If a tool with the same name already exists.
        """
        if tool.name in self._tools:
            raise ValueError(f"Tool '{tool.name}' already exists")
        self._tools[tool.name] = tool
        return self

    def remove(self, name: str) -> "SimpleToolset":
        """Remove a tool from the toolset.

        Args:
            name: The name of the tool to remove.

        Returns:
            Self for chaining.

        Raises:
            KeyError: If the tool doesn't exist.
        """
        if name not in self._tools:
            raise KeyError(f"Tool '{name}' not found")
        del self._tools[name]
        return self

    def get(self, name: str) -> ToolType | None:
        """Get a tool by name.

        Args:
            name: The name of the tool.

        Returns:
            The tool if found, None otherwise.
        """
        return self._tools.get(name)

    def __contains__(self, name: str) -> bool:
        """Check if a tool exists in the toolset."""
        return name in self._tools

    def __len__(self) -> int:
        """Get the number of tools in the toolset."""
        return len(self._tools)

    @property
    def tools(self) -> list[Tool]:
        """Get all tool definitions.

        Both CallableTool and CallableTool2 expose a ``base`` property
        yielding the plain Tool definition, so no type dispatch is needed.
        """
        return [tool.base for tool in self._tools.values()]

    def handle(self, tool_call: ToolCall) -> "asyncio.Future[ToolResult]":
        """Handle a tool call.

        Must be invoked from within a running event loop: the happy path
        schedules execution with ``asyncio.create_task``.

        Args:
            tool_call: The tool call to handle.

        Returns:
            A future that resolves to the tool result.
        """

        def _failed(message: str) -> "asyncio.Future[ToolResult]":
            # Wrap an immediate error in an already-resolved future so callers
            # can await uniformly. get_running_loop() replaces the deprecated
            # get_event_loop() and matches the create_task requirement below.
            future: asyncio.Future[ToolResult] = asyncio.get_running_loop().create_future()
            future.set_result(ToolError(message=message))
            return future

        tool = self._tools.get(tool_call.function.name)
        if tool is None:
            return _failed(f"Tool '{tool_call.function.name}' not found")

        # Parse the JSON-encoded arguments; an empty string means no args.
        try:
            arguments = json.loads(tool_call.function.arguments or "{}")
        except json.JSONDecodeError as e:
            return _failed(f"Invalid JSON arguments: {e}")

        async def _execute() -> ToolResult:
            try:
                return await tool.call(arguments)
            except Exception as e:
                return ToolError(message=f"Tool execution failed: {e}")

        return asyncio.create_task(_execute())
def tool(
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> Callable[[Callable[..., Any]], CallableTool]:
    """Decorator to convert a function into a tool.

    The JSON schema is generated from the function's type hints and the
    description defaults to the docstring. Both coroutine functions and
    plain functions are supported (plain functions are invoked directly
    inside the async wrapper).

    Args:
        name: Optional tool name (defaults to function name).
        description: Optional description (defaults to function docstring).

    Returns:
        A decorator that converts the function into a CallableTool.

    Example:
        >>> @tool()
        ... async def add(a: float, b: float) -> float:
        ...     '''Add two numbers.'''
        ...     return a + b
        >>> agent = Agent(tools=[add])
    """

    def decorator(func: Callable[..., Any]) -> CallableTool:
        sig = inspect.signature(func)
        try:
            type_hints = get_type_hints(func)
        except Exception:
            # Unresolvable forward references: fall back to raw annotations.
            type_hints = {}

        # Build a JSON-schema property per parameter; parameters without a
        # default value are required.
        properties = {}
        required = []
        for param_name, param in sig.parameters.items():
            if param.default is inspect.Parameter.empty:
                required.append(param_name)
            param_type = type_hints.get(param_name, param.annotation)
            if param_type is inspect.Parameter.empty or param_type is None:
                param_type = str
            # Map Python types to JSON schema types; anything unrecognized
            # degrades to "string".
            if param_type in (str,):
                properties[param_name] = {"type": "string"}
            elif param_type in (int,):
                properties[param_name] = {"type": "integer"}
            elif param_type in (float,):
                properties[param_name] = {"type": "number"}
            elif param_type in (bool,):
                properties[param_name] = {"type": "boolean"}
            else:
                properties[param_name] = {"type": "string"}

        parameters = {
            "type": "object",
            "properties": properties,
        }
        if required:
            parameters["required"] = required

        # Create tool class
        tool_name = name or func.__name__
        tool_description = description or (func.__doc__ or "No description provided")
        tool_parameters = parameters
        # BUG FIX: previously `await func(...)` was used unconditionally, so
        # decorating a plain (non-async) function always produced a ToolError
        # at call time. Detect coroutine functions once, up front.
        is_coroutine = inspect.iscoroutinefunction(func)

        class FunctionTool(CallableTool):
            name: str = tool_name
            description: str = tool_description
            parameters: dict[str, Any] = tool_parameters

            async def __call__(self, *args: Any, **kwargs: Any) -> ToolResult:
                try:
                    if is_coroutine:
                        result = await func(*args, **kwargs)
                    else:
                        result = func(*args, **kwargs)
                    return ToolOk(output=str(result))
                except Exception as e:
                    return ToolError(message=str(e))

        return FunctionTool()

    return decorator

View File

@@ -0,0 +1,208 @@
"""Tool suite for AgentLite - A collection of tools inspired by kimi-cli.
This module provides a comprehensive set of tools for file operations,
shell execution, web access, and more, with configuration support
for enabling/disabling individual tools.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
from agentlite.tool import SimpleToolset
from agentlite.tools.config import (
ToolSuiteConfig,
FileToolsConfig,
ShellToolsConfig,
WebToolsConfig,
MultiAgentToolsConfig,
ToolGroupConfig,
)
# Import tool implementations
from agentlite.tools.file.read import ReadFile
from agentlite.tools.file.write import WriteFile
from agentlite.tools.file.replace import StrReplaceFile
from agentlite.tools.file.glob import Glob
from agentlite.tools.file.grep import Grep
from agentlite.tools.file.read_media import ReadMediaFile
from agentlite.tools.shell.shell import Shell
from agentlite.tools.web.fetch import FetchURL
from agentlite.tools.misc.todo import SetTodoList
from agentlite.tools.misc.think import Think
class ConfigurableToolset(SimpleToolset):
    """A toolset that supports configuration-based tool enabling/disabling.

    Tools are instantiated from a :class:`ToolSuiteConfig`: only the tools
    the configuration reports as enabled are registered.

    Example:
        >>> config = ToolSuiteConfig(
        ...     file_tools=FileToolsConfig(
        ...         tools={"WriteFile": False}  # Disable WriteFile
        ...     )
        ... )
        >>> toolset = ConfigurableToolset(config)
        >>> "ReadFile" in toolset  # True
        True
        >>> "WriteFile" in toolset  # False
        False
    """

    def __init__(self, config: ToolSuiteConfig | None = None, work_dir: Optional[str] = None):
        """Initialize the configurable toolset.

        Args:
            config: Tool suite configuration. If None, uses default config (all enabled).
            work_dir: Working directory for file operations. Defaults to current directory.
        """
        super().__init__()
        self.config = config if config is not None else ToolSuiteConfig()
        self.work_dir = Path(work_dir) if work_dir else Path.cwd()
        self._load_tools()

    def _load_tools(self) -> None:
        """Instantiate every enabled tool group, in a fixed order."""
        enabled = self.config.get_enabled_tools()
        loaders = {
            "file": self._load_file_tools,
            "shell": self._load_shell_tools,
            "web": self._load_web_tools,
            "multiagent": self._load_multiagent_tools,
            "misc": self._load_misc_tools,
        }
        for group, loader in loaders.items():
            names = enabled.get(group)
            if names is not None:
                loader(names)

    def _load_file_tools(self, tool_names: list[str]) -> None:
        """Register the file operation tools listed in ``tool_names``."""
        cfg = self.config.file_tools
        wd = self.work_dir
        if "ReadFile" in tool_names:
            self.add(
                ReadFile(
                    work_dir=wd,
                    max_lines=cfg.max_lines,
                    max_line_length=cfg.max_line_length,
                    max_bytes=cfg.max_bytes,
                )
            )
        if "WriteFile" in tool_names:
            self.add(
                WriteFile(work_dir=wd, allow_outside_work_dir=cfg.allow_write_outside_work_dir)
            )
        if "StrReplaceFile" in tool_names:
            self.add(
                StrReplaceFile(work_dir=wd, allow_outside_work_dir=cfg.allow_write_outside_work_dir)
            )
        if "Glob" in tool_names:
            self.add(Glob(work_dir=wd, max_matches=cfg.max_glob_matches))
        if "Grep" in tool_names:
            self.add(Grep(work_dir=wd))
        if "ReadMediaFile" in tool_names:
            self.add(ReadMediaFile(work_dir=wd))

    def _load_shell_tools(self, tool_names: list[str]) -> None:
        """Register shell execution tools."""
        if "Shell" not in tool_names:
            return
        cfg = self.config.shell_tools
        self.add(
            Shell(
                timeout=cfg.timeout,
                max_timeout=cfg.max_timeout,
                blocked_commands=cfg.blocked_commands,
            )
        )

    def _load_web_tools(self, tool_names: list[str]) -> None:
        """Register web-related tools."""
        if "FetchURL" not in tool_names:
            return
        cfg = self.config.web_tools
        self.add(
            FetchURL(
                timeout=cfg.timeout,
                user_agent=cfg.user_agent,
                max_content_length=cfg.max_content_length,
            )
        )

    def _load_multiagent_tools(self, tool_names: list[str]) -> None:
        """Register multi-agent tools.

        Intentionally a no-op: nested subagents are not supported in the
        subagent runtime, so multi-agent tools are never registered here.
        """
        return

    def _load_misc_tools(self, tool_names: list[str]) -> None:
        """Register miscellaneous tools."""
        if "SetTodoList" in tool_names:
            self.add(SetTodoList())
        if "Think" in tool_names:
            self.add(Think())

    def reload(self, config: ToolSuiteConfig | None = None) -> None:
        """Reload tools with a new configuration.

        Args:
            config: New configuration. If None, reloads with current config.
        """
        if config:
            self.config = config
        self._tools.clear()
        self._load_tools()
# Convenience exports
__all__ = [
# Toolset
"ConfigurableToolset",
# Config classes
"ToolSuiteConfig",
"FileToolsConfig",
"ShellToolsConfig",
"WebToolsConfig",
"MultiAgentToolsConfig",
"ToolGroupConfig",
# Tools
"ReadFile",
"WriteFile",
"StrReplaceFile",
"Glob",
"Grep",
"ReadMediaFile",
"Shell",
"FetchURL",
"SetTodoList",
"Think",
]

View File

@@ -0,0 +1,242 @@
"""Tool group configuration system for AgentLite.
This module provides configuration management for tool groups,
allowing users to enable/disable specific tools.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
class ToolGroupConfig(BaseModel):
    """Configuration for a group of tools.

    Lets users enable or disable specific tools within the group; all
    tools are enabled by default.

    Example:
        >>> config = ToolGroupConfig(
        ...     enabled=True,
        ...     tools={
        ...         "ReadFile": True,
        ...         "WriteFile": False,  # Disabled
        ...     },
        ... )
    """

    enabled: bool = Field(default=True, description="Whether the entire tool group is enabled")
    tools: dict[str, bool] = Field(
        default_factory=dict,
        description="Individual tool enable/disable settings. True=enabled, False=disabled. "
        "Tools not listed here follow the default behavior (enabled).",
    )
    default_tool_enabled: bool = Field(
        default=True, description="Default state for tools not explicitly listed in 'tools' dict"
    )

    def is_tool_enabled(self, tool_name: str) -> bool:
        """Check if a specific tool is enabled.

        Args:
            tool_name: The name of the tool to check

        Returns:
            True if the tool is enabled, False otherwise
        """
        # A disabled group overrides any per-tool setting.
        if not self.enabled:
            return False
        # Explicit per-tool override first, then the group default.
        return self.tools.get(tool_name, self.default_tool_enabled)

    def enable_tool(self, tool_name: str) -> None:
        """Enable a specific tool.

        Args:
            tool_name: The name of the tool to enable
        """
        self.set_tool_state(tool_name, True)

    def disable_tool(self, tool_name: str) -> None:
        """Disable a specific tool.

        Args:
            tool_name: The name of the tool to disable
        """
        self.set_tool_state(tool_name, False)

    def set_tool_state(self, tool_name: str, enabled: bool) -> None:
        """Set the enabled state of a specific tool.

        Args:
            tool_name: The name of the tool
            enabled: True to enable, False to disable
        """
        self.tools[tool_name] = enabled
class FileToolsConfig(ToolGroupConfig):
    """Configuration for file operation tools.

    Adds read-size limits, a glob match cap, and the write-location
    policy on top of the generic group enable/disable settings.
    """

    # Read limits consumed by the ReadFile tool.
    max_lines: int = Field(
        default=1000, description="Maximum number of lines to read from a file", ge=1, le=10000
    )
    max_line_length: int = Field(
        default=2000, description="Maximum length of a single line", ge=100, le=10000
    )
    max_bytes: int = Field(
        default=100 * 1024,  # 100KB
        description="Maximum bytes to read from a file",
        ge=1024,
        le=10 * 1024 * 1024,  # 10MB
    )
    # Write policy shared by WriteFile and StrReplaceFile.
    allow_write_outside_work_dir: bool = Field(
        default=False, description="Allow writing files outside the working directory"
    )
    # Cap consumed by the Glob tool.
    max_glob_matches: int = Field(
        default=1000, description="Maximum number of glob matches to return", ge=1, le=10000
    )
class ShellToolsConfig(ToolGroupConfig):
    """Configuration for shell execution tools.

    Consumed by the Shell tool: command timeouts plus a blocklist of
    command patterns.
    """

    timeout: int = Field(
        default=60, description="Default timeout for shell commands in seconds", ge=1, le=3600
    )
    max_timeout: int = Field(
        default=300, description="Maximum allowed timeout for shell commands", ge=1, le=3600
    )
    blocked_commands: list[str] = Field(
        default_factory=list, description="List of command patterns to block"
    )
class WebToolsConfig(ToolGroupConfig):
    """Configuration for web-related tools.

    Consumed by the FetchURL tool: request timeout, User-Agent header,
    and a cap on fetched content size.
    """

    timeout: int = Field(
        default=30, description="Timeout for HTTP requests in seconds", ge=1, le=300
    )
    user_agent: str = Field(
        default="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        description="User-Agent string for HTTP requests",
    )
    max_content_length: int = Field(
        default=1024 * 1024,  # 1MB
        description="Maximum content length to fetch",
        ge=1024,
        le=10 * 1024 * 1024,  # 10MB
    )
class MultiAgentToolsConfig(ToolGroupConfig):
    """Configuration for multi-agent tools.

    Unlike the other groups, ``enabled`` defaults to False here: the
    subagent runtime does not support nested subagents.
    """

    # Overrides the base-class default of True.
    enabled: bool = Field(
        default=False, description="Whether multi-agent tools are enabled. Disabled by default for subagent mode."
    )
    max_steps: int = Field(
        default=50, description="Maximum steps for subagent execution", ge=1, le=1000
    )
    inherit_context: bool = Field(
        default=False, description="Whether subagents inherit parent context"
    )
class ToolSuiteConfig(BaseModel):
    """Complete configuration for all tool groups.

    Aggregates the per-group configs; :meth:`get_enabled_tools` flattens
    them into the group -> enabled-tool-names mapping consumed by the
    toolset loader.

    Example:
        >>> config = ToolSuiteConfig(
        ...     file_tools=FileToolsConfig(tools={"WriteFile": False}),
        ...     shell_tools=ShellToolsConfig(
        ...         enabled=False  # Disable all shell tools
        ...     ),
        ... )
    """

    file_tools: FileToolsConfig = Field(
        default_factory=FileToolsConfig, description="File operation tools configuration"
    )
    shell_tools: ShellToolsConfig = Field(
        default_factory=ShellToolsConfig, description="Shell execution tools configuration"
    )
    web_tools: WebToolsConfig = Field(
        default_factory=WebToolsConfig, description="Web-related tools configuration"
    )
    multiagent_tools: MultiAgentToolsConfig = Field(
        default_factory=MultiAgentToolsConfig, description="Multi-agent tools configuration"
    )
    misc_tools: ToolGroupConfig = Field(
        default_factory=ToolGroupConfig,
        description="Miscellaneous tools (todo, think, etc.) configuration",
    )

    def get_enabled_tools(self) -> dict[str, list[str]]:
        """Get a mapping of tool group names to their enabled tools.

        Returns:
            Dictionary mapping tool group names to lists of enabled tool names
        """
        # (group key, group config, candidate tool names) in loader order.
        groups = (
            (
                "file",
                self.file_tools,
                ["ReadFile", "WriteFile", "StrReplaceFile", "Glob", "Grep", "ReadMediaFile"],
            ),
            ("shell", self.shell_tools, ["Shell"]),
            ("web", self.web_tools, ["FetchURL"]),
            ("multiagent", self.multiagent_tools, ["Task", "CreateSubagent"]),
            ("misc", self.misc_tools, ["SetTodoList", "Think"]),
        )
        result: dict[str, list[str]] = {}
        for key, cfg, candidates in groups:
            if cfg.enabled:
                result[key] = [t for t in candidates if cfg.is_tool_enabled(t)]
        return result

View File

@@ -0,0 +1,20 @@
"""File operation tools for AgentLite.
This module provides tools for reading, writing, and manipulating files.
"""
from agentlite.tools.file.read import ReadFile
from agentlite.tools.file.write import WriteFile
from agentlite.tools.file.replace import StrReplaceFile
from agentlite.tools.file.glob import Glob
from agentlite.tools.file.grep import Grep
from agentlite.tools.file.read_media import ReadMediaFile
__all__ = [
"ReadFile",
"WriteFile",
"StrReplaceFile",
"Glob",
"Grep",
"ReadMediaFile",
]

View File

@@ -0,0 +1,154 @@
"""Glob tool for AgentLite.
This module provides a tool for searching files using glob patterns.
"""
from __future__ import annotations
from typing import Optional
from pathlib import Path
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the Glob tool.

    Attributes:
        pattern: Glob pattern; ``**`` enables recursive matching.
        directory: Optional absolute directory to search; when omitted the
            tool's working directory is used.
        include_dirs: When False, directories are filtered out of results.
    """

    pattern: str = Field(
        description="Glob pattern to match files/directories (e.g., '*.py', '**/*.txt')"
    )
    directory: Optional[str] = Field(
        description=(
            "Absolute path to the directory to search in (defaults to working directory)."
        ),
        default=None,
    )
    include_dirs: bool = Field(
        description="Whether to include directories in results.",
        default=True,
    )
class Glob(CallableTool2[Params]):
    """Tool for searching files using glob patterns.

    This tool finds files and directories matching a glob pattern.
    Supports recursive patterns with **. Searches are confined to the
    working directory.

    Example:
        >>> tool = Glob(work_dir=Path("/tmp"))
        >>> result = await tool({"pattern": "*.py"})
    """

    name: str = "Glob"
    description: str = (
        "Search for files and directories matching a glob pattern. "
        "Supports recursive patterns with **. "
        "Returns paths relative to the search directory."
    )
    params: type[Params] = Params

    def __init__(
        self,
        work_dir: Path,
        max_matches: int = 1000,
    ):
        """Initialize the Glob tool.

        Args:
            work_dir: The working directory for relative paths
            max_matches: Maximum number of matches to return
        """
        super().__init__()
        self._work_dir = work_dir
        self._max_matches = max_matches

    def _is_within_work_dir(self, path: Path) -> bool:
        """Check if a path is within the working directory."""
        try:
            path.relative_to(self._work_dir.resolve())
            return True
        except ValueError:
            return False

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the glob search.

        Args:
            params: The search parameters

        Returns:
            ToolResult with matching paths (relative to the search
            directory) or a ToolError describing the failure.
        """
        try:
            # Determine search directory
            if params.directory:
                search_dir = Path(params.directory).expanduser().resolve()
                if not search_dir.is_absolute():
                    return ToolError(
                        message=f"Directory must be an absolute path: {params.directory}",
                    )
                # Security check: never search outside the working directory.
                if not self._is_within_work_dir(search_dir):
                    return ToolError(
                        message=(
                            f"Directory `{params.directory}` is outside the working directory. "
                            "You can only search within the working directory."
                        ),
                    )
            else:
                search_dir = self._work_dir
            # Check directory exists
            if not search_dir.exists():
                return ToolError(
                    message=f"Directory `{search_dir}` does not exist.",
                )
            if not search_dir.is_dir():
                return ToolError(
                    message=f"`{search_dir}` is not a directory.",
                )
            # Guard against unbounded recursive searches from the root.
            if params.pattern.startswith("**") and not params.directory:
                return ToolError(
                    message=(
                        f"Pattern `{params.pattern}` starts with '**' which is not allowed "
                        "without specifying a directory. This would recursively search all "
                        "directories and may include large directories like `node_modules`. "
                        "Use a more specific pattern or provide a directory."
                    ),
                )
            # Perform glob search
            matches = list(search_dir.glob(params.pattern))
            # Filter directories if not requested
            if not params.include_dirs:
                matches = [p for p in matches if p.is_file()]
            # Sort for consistent output
            matches.sort()
            # BUG FIX: capture the total before truncating so the message
            # reports how many matches actually exist rather than the cap.
            total = len(matches)
            truncated = total > self._max_matches
            if truncated:
                matches = matches[: self._max_matches]
            # Format output (relative to search directory)
            output = "\n".join(str(p.relative_to(search_dir)) for p in matches)
            # Build message
            message = f"Found {total} matches for pattern `{params.pattern}`."
            if truncated:
                message += f" Only the first {self._max_matches} matches are returned."
            return ToolOk(output=output, message=message)
        except Exception as e:
            return ToolError(
                message=f"Failed to search for pattern `{params.pattern}`. Error: {e}",
            )

View File

@@ -0,0 +1,303 @@
"""Grep tool for AgentLite.
This module provides a tool for searching file contents using regex patterns.
"""
from __future__ import annotations
from typing import Optional
import re
from pathlib import Path
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the Grep tool.

    Mirrors a subset of grep's interface: `-B`/`-A`/`-C` context,
    `-n` line numbers, and `-i` case-insensitivity. The context and
    line-number options only take effect when `output_mode` is `content`.
    """

    pattern: str = Field(
        description="The regular expression pattern to search for in file contents"
    )
    path: str = Field(
        description=(
            "File or directory to search in. Defaults to current working directory. "
            "If specified, it must be an absolute path."
        ),
        default=".",
    )
    glob: Optional[str] = Field(
        description=(
            "Glob pattern to filter files (e.g. `*.py`, `*.{ts,tsx}`). No filter by default."
        ),
        default=None,
    )
    output_mode: str = Field(
        description=(
            "`content`: Show matching lines (supports `-B`, `-A`, `-C`, `-n`); "
            "`files_with_matches`: Show file paths; "
            "`count_matches`: Show total number of matches. "
            "Defaults to `files_with_matches`."
        ),
        default="files_with_matches",
    )
    before_context: Optional[int] = Field(
        description=(
            "Number of lines to show before each match (the `-B` option). "
            "Requires `output_mode` to be `content`."
        ),
        default=None,
    )
    after_context: Optional[int] = Field(
        description=(
            "Number of lines to show after each match (the `-A` option). "
            "Requires `output_mode` to be `content`."
        ),
        default=None,
    )
    context: Optional[int] = Field(
        description=(
            "Number of lines to show before and after each match (the `-C` option). "
            "Requires `output_mode` to be `content`."
        ),
        default=None,
    )
    line_number: bool = Field(
        description=(
            "Show line numbers in output (the `-n` option). Requires `output_mode` to be `content`."
        ),
        default=False,
    )
    ignore_case: bool = Field(
        description="Case insensitive search (the `-i` option).",
        default=False,
    )
class Grep(CallableTool2[Params]):
"""Tool for searching file contents using regex patterns.
This tool searches file contents for matches to a regex pattern.
Supports various output modes and context options.
Example:
>>> tool = Grep(work_dir=Path("/tmp"))
>>> result = await tool({"pattern": "def ", "glob": "*.py"})
"""
name: str = "Grep"
description: str = (
"Search file contents using regular expressions. "
"Supports various output modes and context options. "
"Can search individual files or entire directories."
)
params: type[Params] = Params
    def __init__(
        self,
        work_dir: Path,
    ):
        """Initialize the Grep tool.

        Args:
            work_dir: The working directory; searches are confined to it
                via _is_within_work_dir.
        """
        super().__init__()
        self._work_dir = work_dir
def _is_within_work_dir(self, path: Path) -> bool:
"""Check if a path is within the working directory."""
try:
path.relative_to(self._work_dir.resolve())
return True
except ValueError:
return False
def _search_file(
self,
file_path: Path,
pattern: re.Pattern,
params: Params,
) -> list[tuple[int, str]]:
"""Search a single file for matches.
Args:
file_path: Path to the file
pattern: Compiled regex pattern
params: Search parameters
Returns:
List of (line_number, line_content) tuples
"""
try:
content = file_path.read_text(encoding="utf-8", errors="replace")
except Exception:
return []
lines = content.split("\n")
matches = []
for i, line in enumerate(lines, 1):
if pattern.search(line):
matches.append((i, line))
return matches
def _format_matches(
self,
matches: dict[Path, list[tuple[int, str]]],
params: Params,
) -> str:
"""Format matches according to output mode.
Args:
matches: Dict of file_path -> list of (line_num, line) tuples
params: Output parameters
Returns:
Formatted output string
"""
if params.output_mode == "files_with_matches":
return "\n".join(str(p) for p in sorted(matches.keys()))
if params.output_mode == "count_matches":
total = sum(len(m) for m in matches.values())
return f"Total matches: {total}"
# content mode
output_lines = []
for file_path in sorted(matches.keys()):
file_matches = matches[file_path]
# Read file for context
try:
content = file_path.read_text(encoding="utf-8", errors="replace")
lines = content.split("\n")
except Exception:
continue
# Determine context lines
before = params.context if params.context else params.before_context or 0
after = params.context if params.context else params.after_context or 0
# Track which lines to include (to avoid duplicates)
included_lines = set()
for match_line_num, _ in file_matches:
start = max(1, match_line_num - before)
end = min(len(lines), match_line_num + after)
for i in range(start, end + 1):
included_lines.add(i)
# Build output for this file
if output_lines:
output_lines.append("")
output_lines.append(f"File: {file_path}")
prev_line = 0
for line_num in sorted(included_lines):
# Add separator if there's a gap
if prev_line and line_num > prev_line + 1:
output_lines.append("--")
line = lines[line_num - 1]
prefix = f"{line_num}:" if params.line_number else ""
output_lines.append(f"{prefix}{line}")
prev_line = line_num
return "\n".join(output_lines)
async def __call__(self, params: Params) -> ToolResult:
    """Execute the grep search.

    Validates the search path against the working directory, compiles the
    regex, enumerates candidate files, searches them line by line via
    `_search_file`, and formats the result via `_format_matches`.

    Args:
        params: The search parameters

    Returns:
        ToolResult with search results or error
    """
    try:
        # Resolve path: "." is shorthand for the configured working directory.
        if params.path == ".":
            search_path = self._work_dir
        else:
            search_path = Path(params.path).expanduser().resolve()
            # NOTE(review): Path.resolve() always returns an absolute path,
            # so this guard appears unreachable — confirm intent (it may have
            # been meant to run before resolve()).
            if not search_path.is_absolute():
                return ToolError(
                    message=f"Path must be an absolute path: {params.path}",
                )
        # Security check: searching is confined to the working directory.
        if not self._is_within_work_dir(search_path):
            return ToolError(
                message=(
                    f"Path `{params.path}` is outside the working directory. "
                    "You can only search within the working directory."
                ),
            )
        # Check path exists
        if not search_path.exists():
            return ToolError(
                message=f"Path `{params.path}` does not exist.",
            )
        # Compile pattern once; user-supplied regexes may be invalid.
        flags = re.IGNORECASE if params.ignore_case else 0
        try:
            pattern = re.compile(params.pattern, flags)
        except re.error as e:
            return ToolError(
                message=f"Invalid regex pattern: {e}",
            )
        # Find files to search
        if search_path.is_file():
            files = [search_path]
        else:
            if params.glob:
                files = list(search_path.glob(params.glob))
            else:
                # Default: search all files recursively (with some exclusions);
                # any dotted path component or common vendor/cache directory
                # anywhere in the path excludes the file.
                files = [
                    p
                    for p in search_path.rglob("*")
                    if p.is_file()
                    and not any(
                        part.startswith(".") or part in ("node_modules", "__pycache__", ".git")
                        for part in p.parts
                    )
                ]
        # Filter to files only (glob results may include directories).
        files = [p for p in files if p.is_file()]
        # Search files; only files with at least one hit are kept.
        all_matches: dict[Path, list[tuple[int, str]]] = {}
        for file_path in files:
            matches = self._search_file(file_path, pattern, params)
            if matches:
                all_matches[file_path] = matches
        # Format output
        output = self._format_matches(all_matches, params)
        # Build a human-readable summary matching the output mode.
        total_files = len(all_matches)
        total_matches = sum(len(m) for m in all_matches.values())
        if params.output_mode == "files_with_matches":
            message = f"Found matches in {total_files} file(s)."
        elif params.output_mode == "count_matches":
            message = f"Found {total_matches} total match(es)."
        else:
            message = f"Found {total_matches} match(es) in {total_files} file(s)."
        return ToolOk(output=output, message=message)
    except Exception as e:
        # Catch-all boundary: tool failures are reported, never raised.
        return ToolError(
            message=f"Failed to search. Error: {e}",
        )

View File

@@ -0,0 +1,207 @@
"""ReadFile tool for AgentLite.
This module provides a tool for reading text files with line numbers.
"""
from __future__ import annotations
from pathlib import Path
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the ReadFile tool."""

    # Target file; relative paths are resolved against the tool's work_dir.
    path: str = Field(
        description=(
            "The path to the file to read. Absolute paths are required when reading files "
            "outside the working directory."
        )
    )
    # 1-based first line to return (pagination cursor for large files).
    line_offset: int = Field(
        description=(
            "The line number to start reading from. "
            "By default read from the beginning of the file. "
            "Set this when the file is too large to read at once."
        ),
        default=1,
        ge=1,
    )
    # Number of lines to return; the tool additionally enforces its own caps.
    n_lines: int = Field(
        description=(
            "The number of lines to read. "
            "By default read up to max_lines lines. "
            "Set this value when the file is too large to read at once."
        ),
        default=1000,
        ge=1,
    )
class ReadFile(CallableTool2[Params]):
    """Tool for reading text files with line numbers.

    This tool reads a text file and returns its contents with line numbers.
    It supports pagination for large files via ``line_offset``/``n_lines``,
    and enforces per-call caps on line count, line length and total bytes.

    Example:
        >>> tool = ReadFile(work_dir=Path("/tmp"))
        >>> result = await tool({"path": "/tmp/test.txt"})
    """

    name: str = "ReadFile"
    description: str = (
        "Read a text file from the local filesystem. "
        "Returns the file content with line numbers. "
        "Supports reading specific line ranges for large files."
    )
    params: type[Params] = Params

    def __init__(
        self,
        work_dir: Path,
        max_lines: int = 1000,
        max_line_length: int = 2000,
        max_bytes: int = 100 * 1024,
    ):
        """Initialize the ReadFile tool.

        Args:
            work_dir: The working directory for relative paths
            max_lines: Maximum number of lines to read in one call
            max_line_length: Maximum length of a single line before truncation
            max_bytes: Maximum total bytes returned from a file
        """
        super().__init__()
        self._work_dir = work_dir
        self._max_lines = max_lines
        self._max_line_length = max_line_length
        self._max_bytes = max_bytes

    def _is_within_work_dir(self, path: Path) -> bool:
        """Check if a path is within the working directory."""
        try:
            path.relative_to(self._work_dir.resolve())
            return True
        except ValueError:
            return False

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the read file operation.

        Args:
            params: The read parameters

        Returns:
            ToolResult with the file content (line-numbered) or error
        """
        if not params.path:
            return ToolError(
                message="File path cannot be empty.",
            )
        try:
            # Resolve path: relative paths are anchored at the work dir.
            path = Path(params.path).expanduser()
            if not path.is_absolute():
                path = self._work_dir / path
            path = path.resolve()
            # Security check: reading outside work_dir requires an explicit
            # absolute path from the caller.
            if not self._is_within_work_dir(path) and not Path(params.path).is_absolute():
                return ToolError(
                    message=(
                        f"`{params.path}` is not an absolute path. "
                        "You must provide an absolute path to read a file "
                        "outside the working directory."
                    ),
                )
            # Check file exists
            if not path.exists():
                return ToolError(
                    message=f"`{params.path}` does not exist.",
                )
            if not path.is_file():
                return ToolError(
                    message=f"`{params.path}` is not a file.",
                )
            # errors="replace" maps undecodable bytes to U+FFFD, so decoding
            # never raises; binary files come back as replacement characters.
            # (A previous `except UnicodeDecodeError` branch here was
            # unreachable for exactly this reason and has been removed.)
            content = path.read_text(encoding="utf-8", errors="replace")
            # Split into lines
            lines = content.split("\n")
            # Apply line offset (1-based -> 0-based index).
            start_idx = params.line_offset - 1
            if start_idx >= len(lines):
                return ToolOk(
                    output="",
                    message=f"Line offset {params.line_offset} exceeds file length ({len(lines)} lines).",
                )
            # Window end: bounded by the caller's n_lines, the tool's
            # max_lines cap, and the end of the file.
            end_idx = min(start_idx + min(params.n_lines, self._max_lines), len(lines))
            # Extract lines
            selected_lines = lines[start_idx:end_idx]
            # Truncate long lines and enforce the total-bytes budget.
            truncated_lines = []
            truncated_line_numbers = []
            total_bytes = 0
            max_bytes_reached = False
            for i, line in enumerate(selected_lines):
                line_num = start_idx + i + 1
                # Truncate over-long lines before costing them against the budget.
                was_truncated = len(line) > self._max_line_length
                if was_truncated:
                    line = line[: self._max_line_length]
                # Stop before exceeding the byte budget.
                line_bytes = len(line.encode("utf-8"))
                if total_bytes + line_bytes > self._max_bytes:
                    max_bytes_reached = True
                    break
                total_bytes += line_bytes
                truncated_lines.append(line)
                # Record truncation only for lines actually included, so the
                # message never reports a line that was dropped by the byte cap.
                if was_truncated:
                    truncated_line_numbers.append(line_num)
            # Format with right-aligned line numbers, tab-separated.
            lines_with_no = []
            for line_num, line in enumerate(truncated_lines, start=start_idx + 1):
                lines_with_no.append(f"{line_num:6d}\t{line}")
            # Build result
            output = "\n".join(lines_with_no)
            message = (
                f"{len(truncated_lines)} lines read from file starting from line {start_idx + 1}."
            )
            if max_bytes_reached:
                message += f" Max {self._max_bytes} bytes reached."
            elif end_idx < len(lines):
                message += f" File has {len(lines)} lines total."
            if truncated_line_numbers:
                message += f" Lines {truncated_line_numbers} were truncated."
            return ToolOk(output=output, message=message)
        except Exception as e:
            # Catch-all boundary: tool failures are reported, never raised.
            return ToolError(
                message=f"Failed to read {params.path}. Error: {e}",
            )

View File

@@ -0,0 +1,183 @@
"""ReadMediaFile tool for AgentLite.
This module provides a tool for reading image and video files.
"""
from __future__ import annotations
from typing import Optional
import base64
from pathlib import Path
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the ReadMediaFile tool."""

    # Target media file; relative paths are resolved against the tool's work_dir.
    path: str = Field(
        description=(
            "The path to the media file to read. "
            "Absolute paths are required when reading files outside the working directory."
        )
    )
class ReadMediaFile(CallableTool2[Params]):
    """Expose image/video files to the model as base64 ``data:`` URLs.

    The file bytes are base64-encoded and wrapped in a data URL so the
    result can be embedded directly in message content. Only extensions in
    the allowlists below are accepted, and files over the size cap are
    rejected.

    Example:
        >>> tool = ReadMediaFile(work_dir=Path("/tmp"))
        >>> result = await tool({"path": "image.png"})
    """

    name: str = "ReadMediaFile"
    description: str = (
        "Read an image or video file and return it as a base64-encoded data URL. "
        "Supported formats: PNG, JPEG, GIF, WebP, MP4, WebM, and others. "
        "Maximum file size: 100MB."
    )
    params: type[Params] = Params

    # Extension allowlists for the two supported media families.
    IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".svg"}
    VIDEO_EXTENSIONS = {".mp4", ".webm", ".mov", ".avi", ".mkv"}

    # Extension -> MIME type table used to build the data URL.
    MIME_TYPES = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".svg": "image/svg+xml",
        ".mp4": "video/mp4",
        ".webm": "video/webm",
        ".mov": "video/quicktime",
        ".avi": "video/x-msvideo",
        ".mkv": "video/x-matroska",
    }

    def __init__(
        self,
        work_dir: Path,
        max_size_mb: int = 100,
    ):
        """Initialize the ReadMediaFile tool.

        Args:
            work_dir: The working directory for relative paths
            max_size_mb: Maximum file size in MB
        """
        super().__init__()
        self._work_dir = work_dir
        self._max_size = max_size_mb * 1024 * 1024

    def _is_within_work_dir(self, path: Path) -> bool:
        """Return True when *path* lives inside the working directory."""
        try:
            path.relative_to(self._work_dir.resolve())
        except ValueError:
            return False
        return True

    def _get_mime_type(self, path: Path) -> Optional[str]:
        """Look up the MIME type for *path* by extension (None when unknown)."""
        return self.MIME_TYPES.get(path.suffix.lower())

    def _is_media_file(self, path: Path) -> bool:
        """Check whether *path* carries a supported image or video extension."""
        return path.suffix.lower() in (self.IMAGE_EXTENSIONS | self.VIDEO_EXTENSIONS)

    async def __call__(self, params: Params) -> ToolResult:
        """Read the media file and return it as a base64 data URL.

        Args:
            params: The read parameters

        Returns:
            ToolResult with the data URL on success, ToolError otherwise
        """
        if not params.path:
            return ToolError(message="File path cannot be empty.")
        try:
            # Anchor relative paths at the working directory, then normalize.
            resolved = Path(params.path).expanduser()
            if not resolved.is_absolute():
                resolved = self._work_dir / resolved
            resolved = resolved.resolve()
            # Reading outside the work dir demands an explicit absolute path.
            if not self._is_within_work_dir(resolved) and not Path(params.path).is_absolute():
                return ToolError(
                    message=(
                        f"`{params.path}` is not an absolute path. "
                        "You must provide an absolute path to read a file "
                        "outside the working directory."
                    ),
                )
            # Existence / type guards.
            if not resolved.exists():
                return ToolError(message=f"`{params.path}` does not exist.")
            if not resolved.is_file():
                return ToolError(message=f"`{params.path}` is not a file.")
            # Reject anything outside the extension allowlist.
            if not self._is_media_file(resolved):
                supported = ", ".join(sorted(self.IMAGE_EXTENSIONS | self.VIDEO_EXTENSIONS))
                return ToolError(
                    message=(
                        f"`{params.path}` is not a supported media file. "
                        f"Supported extensions: {supported}"
                    ),
                )
            # Enforce the size cap before loading anything into memory.
            n_bytes = resolved.stat().st_size
            if n_bytes > self._max_size:
                return ToolError(
                    message=(
                        f"`{params.path}` is too large ({n_bytes / 1024 / 1024:.1f}MB). "
                        f"Maximum size is {self._max_size / 1024 / 1024:.0f}MB."
                    ),
                )
            mime = self._get_mime_type(resolved)
            if not mime:
                return ToolError(
                    message=f"Could not determine MIME type for `{params.path}`.",
                )
            # Encode the whole file and wrap it as a data URL.
            payload = base64.b64encode(resolved.read_bytes()).decode("ascii")
            return ToolOk(
                output=f"data:{mime};base64,{payload}",
                message=(
                    f"Loaded {mime.split('/')[0]} file `{params.path}` ({n_bytes} bytes)."
                ),
            )
        except Exception as e:
            # Catch-all boundary: tool failures are reported, never raised.
            return ToolError(message=f"Failed to read {params.path}. Error: {e}")

View File

@@ -0,0 +1,189 @@
"""StrReplaceFile tool for AgentLite.
This module provides a tool for editing files using string replacement.
"""
from __future__ import annotations
from pathlib import Path
from typing import Union
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Edit(BaseModel):
    """A single edit operation."""

    # Exact text to find; must match the file content including whitespace.
    old: str = Field(description="The old string to replace. Can be multi-line.")
    # Replacement text.
    new: str = Field(description="The new string to replace with. Can be multi-line.")
    # When False (default), only the first occurrence is replaced.
    replace_all: bool = Field(
        description="Whether to replace all occurrences.",
        default=False,
    )
class Params(BaseModel):
    """Parameters for the StrReplaceFile tool."""

    # Target file; relative paths are resolved against the tool's work_dir.
    path: str = Field(
        description=(
            "The path to the file to edit. Absolute paths are required when editing files "
            "outside the working directory."
        )
    )
    # One edit or an ordered list of edits, applied sequentially.
    edit: Union[Edit, list[Edit]] = Field(
        description=(
            "The edit(s) to apply to the file. "
            "You can provide a single edit or a list of edits here."
        ),
    )
class StrReplaceFile(CallableTool2[Params]):
    """Tool for editing files using string replacement.

    This tool replaces strings in a file. It can perform single or multiple
    replacements, and optionally replace all occurrences. Edits are applied
    sequentially, so a later edit sees the output of earlier ones.

    Example:
        >>> tool = StrReplaceFile(work_dir=Path("/tmp"))
        >>> result = await tool({"path": "test.txt", "edit": {"old": "Hello", "new": "Hi"}})
    """

    name: str = "StrReplaceFile"
    description: str = (
        "Edit a file by replacing strings. "
        "Supports single or multiple edits, and can replace all occurrences. "
        "The old string must match exactly (including whitespace)."
    )
    params: type[Params] = Params

    def __init__(
        self,
        work_dir: Path,
        allow_outside_work_dir: bool = False,
    ):
        """Initialize the StrReplaceFile tool.

        Args:
            work_dir: The working directory for relative paths
            allow_outside_work_dir: Whether to allow editing outside the working directory
        """
        super().__init__()
        self._work_dir = work_dir
        self._allow_outside_work_dir = allow_outside_work_dir

    def _is_within_work_dir(self, path: Path) -> bool:
        """Check if a path is within the working directory."""
        try:
            path.relative_to(self._work_dir.resolve())
            return True
        except ValueError:
            return False

    def _apply_edit(self, content: str, edit: Edit) -> tuple[str, int]:
        """Apply a single edit to the content.

        Args:
            content: The original content
            edit: The edit to apply

        Returns:
            Tuple of (new_content, replacements_count)
        """
        if edit.replace_all:
            # Count occurrences first so the caller can report the total.
            count = content.count(edit.old)
            new_content = content.replace(edit.old, edit.new)
            return new_content, count
        else:
            # Single mode: only the first occurrence is replaced; no error is
            # raised here when the target appears more than once.
            if edit.old in content:
                new_content = content.replace(edit.old, edit.new, 1)
                return new_content, 1
            return content, 0

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the string replacement operation.

        Args:
            params: The edit parameters

        Returns:
            ToolResult with success message or error
        """
        if not params.path:
            return ToolError(
                message="File path cannot be empty.",
            )
        try:
            # Resolve path: relative paths are anchored at the work dir.
            path = Path(params.path).expanduser()
            if not path.is_absolute():
                path = self._work_dir / path
            path = path.resolve()
            # Security check: editing outside the work dir requires both an
            # explicit absolute path and the allow_outside_work_dir flag.
            if not self._is_within_work_dir(path):
                if not Path(params.path).is_absolute():
                    return ToolError(
                        message=(
                            f"`{params.path}` is not an absolute path. "
                            "You must provide an absolute path to edit a file "
                            "outside the working directory."
                        ),
                    )
                if not self._allow_outside_work_dir:
                    return ToolError(
                        message=(
                            f"Editing outside the working directory is not allowed. "
                            f"Path: {params.path}"
                        ),
                    )
            # Check file exists
            if not path.exists():
                return ToolError(
                    message=f"`{params.path}` does not exist.",
                )
            if not path.is_file():
                return ToolError(
                    message=f"`{params.path}` is not a file.",
                )
            # Read file content
            content = path.read_text(encoding="utf-8", errors="replace")
            original_content = content
            # Normalize edits to list
            edits = [params.edit] if isinstance(params.edit, Edit) else params.edit
            # Apply edits in order; each edit sees the previous edit's output.
            total_replacements = 0
            for edit in edits:
                content, count = self._apply_edit(content, edit)
                total_replacements += count
            # NOTE(review): a sequence of edits that cancels out (a->b then
            # b->a) also lands here and is reported as "not found" — confirm
            # that is the intended behavior.
            if content == original_content:
                return ToolError(
                    message="No replacements were made. The old string was not found in the file.",
                )
            # Write back
            path.write_text(content, encoding="utf-8")
            return ToolOk(
                output="",
                message=(
                    f"File successfully edited. "
                    f"Applied {len(edits)} edit(s) with {total_replacements} total replacement(s)."
                ),
            )
        except Exception as e:
            # Catch-all boundary: tool failures are reported, never raised.
            return ToolError(
                message=f"Failed to edit {params.path}. Error: {e}",
            )

View File

@@ -0,0 +1,157 @@
"""WriteFile tool for AgentLite.
This module provides a tool for writing files to the filesystem.
"""
from __future__ import annotations
from typing import Literal
from pathlib import Path
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the WriteFile tool."""

    # Target file; relative paths are resolved against the tool's work_dir.
    path: str = Field(
        description=(
            "The path to the file to write. Absolute paths are required when writing files "
            "outside the working directory."
        )
    )
    # Text written verbatim (UTF-8).
    content: str = Field(description="The content to write to the file")
    # Write strategy; validated by the Literal type.
    mode: Literal["overwrite", "append"] = Field(
        description=(
            "The mode to use to write to the file. "
            "Two modes are supported: `overwrite` for overwriting the whole file and "
            "`append` for appending to the end of an existing file."
        ),
        default="overwrite",
    )
class WriteFile(CallableTool2[Params]):
    """Tool for writing files to the filesystem.

    This tool writes content to a file, either overwriting or appending.
    The parent directory must already exist; the file itself is created if
    missing.

    Example:
        >>> tool = WriteFile(work_dir=Path("/tmp"))
        >>> result = await tool({"path": "test.txt", "content": "Hello World"})
    """

    name: str = "WriteFile"
    description: str = (
        "Write content to a file on the local filesystem. "
        "Can create new files or overwrite/append to existing files."
    )
    params: type[Params] = Params

    def __init__(
        self,
        work_dir: Path,
        allow_outside_work_dir: bool = False,
    ):
        """Initialize the WriteFile tool.

        Args:
            work_dir: The working directory for relative paths
            allow_outside_work_dir: Whether to allow writing outside the working directory
        """
        super().__init__()
        self._work_dir = work_dir
        self._allow_outside_work_dir = allow_outside_work_dir

    def _is_within_work_dir(self, path: Path) -> bool:
        """Check if a path is within the working directory."""
        try:
            path.relative_to(self._work_dir.resolve())
            return True
        except ValueError:
            return False

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the write file operation.

        Args:
            params: The write parameters

        Returns:
            ToolResult with success message or error
        """
        if not params.path:
            return ToolError(
                message="File path cannot be empty.",
            )
        try:
            # Resolve path: relative paths are anchored at the work dir.
            path = Path(params.path).expanduser()
            if not path.is_absolute():
                path = self._work_dir / path
            path = path.resolve()
            # Security check: writing outside the work dir requires both an
            # explicit absolute path and the allow_outside_work_dir flag.
            if not self._is_within_work_dir(path):
                if not Path(params.path).is_absolute():
                    return ToolError(
                        message=(
                            f"`{params.path}` is not an absolute path. "
                            "You must provide an absolute path to write a file "
                            "outside the working directory."
                        ),
                    )
                if not self._allow_outside_work_dir:
                    return ToolError(
                        message=(
                            f"Writing outside the working directory is not allowed. "
                            f"Path: {params.path}"
                        ),
                    )
            # Check parent directory exists (the tool never mkdirs).
            if not path.parent.exists():
                return ToolError(
                    message=f"Parent directory `{path.parent}` does not exist.",
                )
            # Note: params.mode is constrained to "overwrite"/"append" by the
            # Literal type on Params, so no extra runtime validation is needed
            # (a previous re-check here was unreachable).
            file_existed = path.exists()
            # Only append mode needs the existing content; overwrite never
            # reads the old file (the previous implementation read it and
            # threw the result away).
            if params.mode == "append" and file_existed and path.is_file():
                new_content = path.read_text(encoding="utf-8", errors="replace") + params.content
            else:
                new_content = params.content
            # Write file
            path.write_text(new_content, encoding="utf-8")
            # Build success message reflecting what actually happened.
            action = (
                "overwritten"
                if params.mode == "overwrite" and file_existed
                else ("appended to" if params.mode == "append" and file_existed else "created")
            )
            file_size = path.stat().st_size
            return ToolOk(
                output="",
                message=f"File `{params.path}` successfully {action}. Size: {file_size} bytes.",
            )
        except Exception as e:
            # Catch-all boundary: tool failures are reported, never raised.
            return ToolError(
                message=f"Failed to write to {params.path}. Error: {e}",
            )

View File

@@ -0,0 +1,9 @@
"""Miscellaneous tools for AgentLite.
This module provides utility tools like todo lists and thinking.
"""
from agentlite.tools.misc.todo import SetTodoList
from agentlite.tools.misc.think import Think
__all__ = ["SetTodoList", "Think"]

View File

@@ -0,0 +1,69 @@
"""Think tool for AgentLite.
This module provides a tool for recording thoughts.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the Think tool."""

    # Free-form reasoning text to record.
    thought: str = Field(description="A thought to record")
class Think(CallableTool2[Params]):
    """Scratch-pad tool that lets the agent log intermediate reasoning.

    Thoughts accumulate in memory on the tool instance and are never shown
    to the user; they are useful for debugging and for understanding the
    agent's decision process after the fact.

    Example:
        >>> tool = Think()
        >>> result = await tool({"thought": "I should first check if the file exists..."})
    """

    name: str = "Think"
    description: str = (
        "Record a thought or reasoning step. "
        "Use this to think through problems before taking action. "
        "The thought will be logged but not returned to the user."
    )
    params: type[Params] = Params

    def __init__(self):
        """Create the tool with an empty thought log."""
        super().__init__()
        self._thoughts: list[str] = []

    async def __call__(self, params: Params) -> ToolResult:
        """Append the thought to the log and acknowledge it.

        Args:
            params: The thought parameters

        Returns:
            ToolResult with success message
        """
        log = self._thoughts
        log.append(params.thought)
        return ToolOk(
            output="",
            message=f"Thought recorded ({len(log)} total thoughts)",
        )

    def get_thoughts(self) -> list[str]:
        """Return a snapshot copy of every recorded thought, in order."""
        return list(self._thoughts)

    def clear_thoughts(self) -> None:
        """Discard all recorded thoughts."""
        self._thoughts.clear()

View File

@@ -0,0 +1,101 @@
"""SetTodoList tool for AgentLite.
This module provides a tool for managing todo lists.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolOk, ToolResult
class Todo(BaseModel):
    """A single todo item."""

    # Human-readable task title; must be non-empty.
    title: str = Field(description="The title of the todo", min_length=1)
    # Lifecycle state of the task.
    status: Literal["pending", "in_progress", "done"] = Field(description="The status of the todo")
class Params(BaseModel):
    """Parameters for the SetTodoList tool."""

    # Full replacement todo list (not a delta against the current list).
    todos: list[Todo] = Field(description="The todo list to set")
class SetTodoList(CallableTool2[Params]):
    """Tool for managing todo lists.

    This tool allows the agent to create and update a todo list. Each call
    replaces the stored list wholesale. The todo list can be used to track
    tasks and progress.

    Example:
        >>> tool = SetTodoList()
        >>> result = await tool(
        ...     {
        ...         "todos": [
        ...             {"title": "Read docs", "status": "done"},
        ...             {"title": "Write code", "status": "in_progress"},
        ...         ]
        ...     }
        ... )
    """

    name: str = "SetTodoList"
    description: str = (
        "Set or update the todo list. "
        "Use this to track tasks and show progress. "
        "Each todo has a title and status (pending/in_progress/done)."
    )
    params: type[Params] = Params

    def __init__(self):
        """Initialize the SetTodoList tool."""
        super().__init__()
        # Current list; replaced wholesale on every call.
        self._todos: list[Todo] = []

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the todo list update.

        Args:
            params: The todo list parameters

        Returns:
            ToolResult with the rendered list and a status-count summary
        """
        self._todos = params.todos
        # Render one line per todo, prefixed with a status marker.
        lines = []
        for todo in self._todos:
            # NOTE(review): the "pending" and "done" markers are empty strings
            # here — possibly glyphs lost in transit; confirm intended symbols.
            status_emoji = {
                "pending": "",
                "in_progress": "🔨",
                "done": "",
            }.get(todo.status, "")
            lines.append(f"{status_emoji} {todo.title}")
        output = "\n".join(lines) if lines else "No todos."
        # Count by status for the summary message.
        counts = {"pending": 0, "in_progress": 0, "done": 0}
        for todo in self._todos:
            if todo.status in counts:
                counts[todo.status] += 1
        message = (
            f"Todo list updated: {len(self._todos)} items "
            f"({counts['done']} done, {counts['in_progress']} in progress, "
            f"{counts['pending']} pending)"
        )
        return ToolOk(output=output, message=message)

    def get_todos(self) -> list[Todo]:
        """Get the current todo list.

        Returns:
            A shallow copy of the current list of todos
        """
        return self._todos.copy()

View File

@@ -0,0 +1,6 @@
"""Multi-agent tools for AgentLite.
This module provides tools for creating and managing subagents.
"""
__all__ = []

View File

@@ -0,0 +1,59 @@
"""CreateSubagent tool for AgentLite.
This module provides a tool for dynamically creating subagents.
In this rdev subagent integration, nested subagents are intentionally
disabled. The tool is kept for API compatibility but it intentionally
returns an explicit disabled error.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolResult
class Params(BaseModel):
    """Parameters for the CreateSubagent tool."""

    # Identifier the subagent would be registered under.
    name: str = Field(description="The name of the subagent to create")
    # System prompt that would define the subagent's behavior.
    prompt: str = Field(
        description=(
            "The system prompt for the subagent. "
            "This defines the subagent's personality and capabilities."
        ),
    )
class CreateSubagent(CallableTool2[Params]):
    """Stub tool that rejects dynamic subagent creation.

    Nested subagents are deliberately disabled in this rdev subagent
    runtime; the tool is kept only for API compatibility, and every
    invocation returns an explicit disabled error instead of creating
    anything.
    """

    name: str = "CreateSubagent"
    description: str = (
        "Create a new subagent with a custom system prompt. "
        "The subagent can be used to perform specialized tasks. "
        "Use the Task tool to run tasks with created subagents."
    )
    params: type[Params] = Params

    def __init__(self):
        """No state to set up; defer to the base tool initializer."""
        super().__init__()

    async def __call__(self, params: Params) -> ToolResult:
        """Always fail: nested subagent creation is disabled here."""
        refusal = (
            "CreateSubagent tool is disabled in this subagent runtime. "
            f"Dynamic subagent creation is not allowed (requested '{params.name}')."
        )
        return ToolError(message=refusal)

View File

@@ -0,0 +1,99 @@
"""Task tool for AgentLite.
This module provides a tool for delegating tasks to subagents.
In this rdev subagent integration, nested subagents are intentionally
disabled. The tool is kept for API compatibility but no longer executes
delegation.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolResult
if TYPE_CHECKING:
from agentlite.agent import Agent
from agentlite.labor_market import LaborMarket
class Params(BaseModel):
    """Parameters for the Task tool."""

    # Registered name of the subagent to delegate to.
    subagent_name: str = Field(description="The name of the subagent to call (must be registered)")
    # Full task instructions; the subagent sees only this context.
    prompt: str = Field(
        description=(
            "The task for the subagent to perform. "
            "Provide detailed instructions with all necessary context."
        ),
    )
    # Short label used purely for logging.
    description: str = Field(
        default="",
        description="A short (3-5 word) description of the task (for logging)",
    )
class Task(CallableTool2[Params]):
    """Delegation tool stub; nested delegation is disabled in this runtime.

    The constructor still resolves and stores the labor market so existing
    wiring keeps working, but invoking the tool always returns an explicit
    disabled error instead of running a subagent.

    Example:
        >>> # Parent agent has a "coder" subagent
        >>> tool = Task(parent_agent)
        >>> result = await tool(
        ...     {
        ...         "subagent_name": "coder",
        ...         "prompt": "Write a Python function to sort a list",
        ...         "description": "Write sorting function",
        ...     }
        ... )
    """

    name: str = "Task"
    description: str = (
        "Delegate a task to a specialized subagent. "
        "The subagent must be registered in the parent agent's labor market. "
        "The subagent will execute independently and return its findings."
    )
    params: type[Params] = Params

    def __init__(
        self,
        labor_market: LaborMarket | None = None,
        parent_agent: Agent | None = None,
        max_iterations: int = 80,
    ):
        """Initialize the Task tool.

        Args:
            labor_market: The LaborMarket containing subagents
            parent_agent: Alternative: the parent agent (uses its labor_market)
            max_iterations: Maximum iterations for subagent execution

        Raises:
            ValueError: If neither labor_market nor parent_agent is provided.
        """
        super().__init__()
        if labor_market is None and parent_agent is None:
            raise ValueError("Either labor_market or parent_agent must be provided")
        # Prefer an explicitly passed market; otherwise borrow the parent's.
        self._labor_market = (
            labor_market if labor_market is not None else parent_agent.labor_market
        )
        self._max_iterations = max_iterations

    async def __call__(self, params: Params) -> ToolResult:
        """Always fail: nested subagent delegation is disabled here."""
        refusal = (
            "Task tool is disabled in this subagent runtime. "
            f"Nested subagent delegation is not allowed (requested '{params.subagent_name}')."
        )
        return ToolError(message=refusal)

View File

@@ -0,0 +1,8 @@
"""Shell tools for AgentLite.
This module provides tools for executing shell commands.
"""
from agentlite.tools.shell.shell import Shell
__all__ = ["Shell"]

View File

@@ -0,0 +1,164 @@
"""Shell tool for AgentLite.
This module provides a tool for executing shell commands.
"""
from __future__ import annotations
from typing import Optional
import asyncio
import platform
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the Shell tool."""

    # Command line passed to bash -c (or PowerShell -Command on Windows).
    command: str = Field(description="The shell command to execute.")
    # Per-call timeout; further capped by the tool's max_timeout.
    timeout: int = Field(
        description=(
            "The timeout in seconds for the command to execute. "
            "If the command takes longer than this, it will be killed."
        ),
        default=60,
        ge=1,
        le=3600,
    )
class Shell(CallableTool2[Params]):
    """Tool for executing shell commands.

    This tool executes shell commands and returns their output.
    Supports configurable timeout and command blocking for security.
    Commands run via `bash -c` on Unix-likes and `powershell -Command`
    on Windows, with the invoking user's permissions.

    Example:
        >>> tool = Shell()
        >>> result = await tool({"command": "ls -la"})
    """

    name: str = "Shell"
    description: str = (
        "Execute a shell command and return its output. "
        "Supports bash on Unix/Linux/macOS and PowerShell on Windows. "
        "Use with caution - commands are executed with user permissions."
    )
    params: type[Params] = Params

    def __init__(
        self,
        timeout: int = 60,
        max_timeout: int = 300,
        blocked_commands: Optional[list[str]] = None,
    ):
        """Initialize the Shell tool.

        Args:
            timeout: Default timeout in seconds
            max_timeout: Maximum allowed timeout
            blocked_commands: List of command patterns to block
        """
        super().__init__()
        # NOTE(review): _default_timeout is stored but never read in __call__
        # (the effective timeout comes from params.timeout, capped by
        # max_timeout) — confirm whether it should seed the Params default.
        self._default_timeout = timeout
        self._max_timeout = max_timeout
        self._blocked_commands = blocked_commands or []
        # Detected once at construction; decides which shell is used.
        self._is_windows = platform.system() == "Windows"

    def _is_blocked(self, command: str) -> Optional[str]:
        """Check if a command is blocked.

        Matching is a case-insensitive substring test against each configured
        pattern, so broad patterns may over-block.

        Args:
            command: The command to check

        Returns:
            Block reason if blocked, None otherwise
        """
        cmd_lower = command.lower().strip()
        for blocked in self._blocked_commands:
            if blocked.lower() in cmd_lower:
                return f"Command contains blocked pattern: {blocked}"
        return None

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the shell command.

        Args:
            params: The command parameters

        Returns:
            ToolResult with command output or error
        """
        if not params.command:
            return ToolError(
                message="Command cannot be empty.",
            )
        # Check if blocked
        if block_reason := self._is_blocked(params.command):
            return ToolError(
                message=f"Command blocked: {block_reason}",
            )
        # Cap the caller-requested timeout at the tool-wide maximum.
        timeout = min(params.timeout, self._max_timeout)
        try:
            # Determine shell
            if self._is_windows:
                # Use PowerShell on Windows
                shell_cmd = ["powershell", "-Command", params.command]
            else:
                # Use bash on Unix/Linux/macOS
                shell_cmd = ["bash", "-c", params.command]
            # Execute command (argv form, no extra shell interpolation layer).
            process = await asyncio.create_subprocess_exec(
                *shell_cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            try:
                stdout, stderr = await asyncio.wait_for(
                    process.communicate(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                # Kill the child and reap it so no zombie is left behind.
                process.kill()
                await process.wait()
                return ToolError(
                    message=f"Command timed out after {timeout} seconds.",
                )
            # Decode output; replace undecodable bytes rather than failing.
            stdout_str = stdout.decode("utf-8", errors="replace")
            stderr_str = stderr.decode("utf-8", errors="replace")
            # Build output: stdout first, stderr tagged separately.
            output_parts = []
            if stdout_str:
                output_parts.append(stdout_str)
            if stderr_str:
                output_parts.append(f"[stderr]\n{stderr_str}")
            output = "\n".join(output_parts)
            if process.returncode == 0:
                return ToolOk(
                    output=output,
                    message="Command executed successfully (exit code 0).",
                )
            else:
                # assumes ToolError accepts an `output` field like ToolOk —
                # confirm against the tool result type definitions.
                return ToolError(
                    message=f"Command failed with exit code {process.returncode}.",
                    output=output,
                )
        except Exception as e:
            # Catch-all boundary: tool failures are reported, never raised.
            return ToolError(
                message=f"Failed to execute command. Error: {e}",
            )

View File

@@ -0,0 +1,8 @@
"""Web tools for AgentLite.
This module provides tools for web access and search.
"""
from agentlite.tools.web.fetch import FetchURL
__all__ = ["FetchURL"]

View File

@@ -0,0 +1,173 @@
"""FetchURL tool for AgentLite.
This module provides a tool for fetching web page content.
"""
from __future__ import annotations
import urllib.request
import urllib.error
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolOk, ToolResult
class Params(BaseModel):
    """Parameters for the FetchURL tool.

    Validated by pydantic before the tool body runs.
    """
    # Target of the HTTP GET request (scheme included, e.g. "https://...").
    url: str = Field(description="The URL to fetch content from.")
class FetchURL(CallableTool2[Params]):
    """Tool for fetching web page content.

    This tool fetches the content of a web page and extracts the main text.
    Uses simple HTTP GET with configurable timeout.

    Example:
        >>> tool = FetchURL()
        >>> result = await tool({"url": "https://example.com"})
    """

    name: str = "FetchURL"
    description: str = (
        "Fetch the content of a web page. "
        "Returns the HTML content or extracts main text if possible. "
        "Useful for reading documentation, articles, or API responses."
    )
    params: type[Params] = Params

    def __init__(
        self,
        timeout: int = 30,
        user_agent: str = (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
            "(KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        ),
        max_content_length: int = 1024 * 1024,  # 1MB
    ):
        """Initialize the FetchURL tool.

        Args:
            timeout: Request timeout in seconds
            user_agent: User-Agent string sent with the request
            max_content_length: Maximum content length in bytes to fetch
        """
        super().__init__()
        self._timeout = timeout
        self._user_agent = user_agent
        self._max_content_length = max_content_length

    def _extract_text(self, html: str) -> str:
        """Simple HTML to text extraction.

        Strips <script>/<style> blocks and all remaining tags, unescapes
        HTML entities, and collapses runs of whitespace.

        Args:
            html: HTML content

        Returns:
            Extracted plain text
        """
        import re

        # Remove script and style elements (non-greedy, across newlines).
        html = re.sub(r"<script[^>]*>.*?</script>", "", html, flags=re.DOTALL)
        html = re.sub(r"<style[^>]*>.*?</style>", "", html, flags=re.DOTALL)
        # Remove remaining HTML tags.
        text = re.sub(r"<[^>]+>", "", html)
        # Decode HTML entities (&amp;, &#39;, ...).
        import html as html_module

        text = html_module.unescape(text)
        # Normalize whitespace.
        text = re.sub(r"\s+", " ", text)
        return text.strip()

    async def __call__(self, params: Params) -> ToolResult:
        """Execute the URL fetch.

        Args:
            params: The fetch parameters

        Returns:
            ToolResult with page content or error
        """
        if not params.url:
            return ToolError(
                message="URL cannot be empty.",
            )
        try:
            # Create request with browser-like headers; "identity" avoids
            # compressed bodies this tool would otherwise have to decompress.
            request = urllib.request.Request(
                params.url,
                headers={
                    "User-Agent": self._user_agent,
                    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
                    "Accept-Language": "en-US,en;q=0.5",
                    "Accept-Encoding": "identity",
                },
            )
            # Fetch URL.
            # NOTE(review): urllib is blocking, so this stalls the event loop
            # for up to `timeout` seconds — consider run_in_executor if that
            # matters for this runtime.
            with urllib.request.urlopen(request, timeout=self._timeout) as response:
                # Fast reject when the server declares an oversized body.
                content_length = response.headers.get("Content-Length")
                if content_length and int(content_length) > self._max_content_length:
                    return ToolError(
                        message=(
                            f"Content too large ({int(content_length)} bytes). "
                            f"Maximum is {self._max_content_length} bytes."
                        ),
                    )
                # Read at most max_content_length + 1 bytes so an oversized or
                # unlabelled body is detected without buffering the whole
                # response in memory (previously read() was unbounded when
                # Content-Length was absent or inaccurate).
                content = response.read(self._max_content_length + 1)
                if len(content) > self._max_content_length:
                    return ToolError(
                        message=(
                            f"Content too large ({len(content)} bytes). "
                            f"Maximum is {self._max_content_length} bytes."
                        ),
                    )
                # Decode content: UTF-8 first, then Latin-1, then lossy UTF-8.
                try:
                    text = content.decode("utf-8")
                except UnicodeDecodeError:
                    try:
                        text = content.decode("latin-1")
                    except UnicodeDecodeError:
                        # latin-1 decodes any byte sequence, so this branch is
                        # effectively unreachable; kept as a defensive fallback.
                        text = content.decode("utf-8", errors="replace")
                # Extract readable text when the response is HTML.
                content_type = response.headers.get("Content-Type", "")
                if "text/html" in content_type:
                    extracted = self._extract_text(text)
                    return ToolOk(
                        output=extracted,
                        message=f"Fetched and extracted content from {params.url}",
                    )
                else:
                    return ToolOk(
                        output=text,
                        message=f"Fetched content from {params.url}",
                    )
        except urllib.error.HTTPError as e:
            # HTTPError first: it subclasses URLError.
            return ToolError(
                message=f"HTTP error {e.code}: {e.reason}",
            )
        except urllib.error.URLError as e:
            return ToolError(
                message=f"URL error: {e.reason}",
            )
        except Exception as e:
            return ToolError(
                message=f"Failed to fetch {params.url}. Error: {e}",
            )

View File

@@ -0,0 +1,82 @@
"""SearchWeb tool for AgentLite.
This module provides a tool for web search.
Note: This is a placeholder implementation. A real implementation would
require integration with a search API like Google, Bing, or DuckDuckGo.
"""
from __future__ import annotations
from pydantic import BaseModel, Field
from agentlite.tool import CallableTool2, ToolError, ToolResult
class Params(BaseModel):
    """Parameters for the SearchWeb tool."""
    # Free-text search query forwarded to the (future) search backend.
    query: str = Field(description="The search query string.")
    # Bounded by pydantic validation: 1 <= num_results <= 10.
    num_results: int = Field(
        description="Number of search results to return (max 10).",
        default=5,
        ge=1,
        le=10,
    )
class SearchWeb(CallableTool2[Params]):
    """Web-search tool (disabled placeholder).

    A real implementation would call a search API (Google, Bing,
    DuckDuckGo, ...) with configured credentials; this runtime ships a
    stub that always reports the tool as disabled.

    Example:
        >>> tool = SearchWeb()
        >>> result = await tool({"query": "Python async programming"})
    """

    name: str = "SearchWeb"
    description: str = (
        "Search the web for information. "
        "Returns a list of relevant search results with titles and snippets. "
        "Note: Requires search API configuration to work properly."
    )
    params: type[Params] = Params

    def __init__(
        self,
        timeout: int = 30,
        user_agent: str = ("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"),
    ):
        """Store request settings for a future real implementation.

        Args:
            timeout: Request timeout in seconds
            user_agent: User-Agent string
        """
        super().__init__()
        self._timeout = timeout
        self._user_agent = user_agent

    async def __call__(self, params: Params) -> ToolResult:
        """Reject the search request.

        Args:
            params: The search parameters

        Returns:
            ToolError: empty-query error, or the "disabled" notice.
        """
        if not params.query:
            return ToolError(message="Search query cannot be empty.")
        disabled_notice = (
            "SearchWeb tool is disabled in this subagent runtime. "
            "Use FetchURL for direct URL content retrieval."
        )
        return ToolError(message=disabled_notice)

View File

329
agentlite/tests/conftest.py Normal file
View File

@@ -0,0 +1,329 @@
"""Test configuration and shared fixtures for AgentLite tests.
This module provides pytest configuration and fixtures that are shared
across all test modules.
"""
from __future__ import annotations
import asyncio
import json
from collections.abc import AsyncIterator, Sequence
from typing import Any, Optional
import pytest
from agentlite import (
Agent,
ContentPart,
Message,
TextPart,
ToolCall,
tool,
)
from agentlite.provider import StreamedMessage, TokenUsage
from agentlite.tool import Tool
# =============================================================================
# pytest Configuration
# =============================================================================
def pytest_configure(config):
    """Register the project's custom pytest markers.

    Keeps `pytest --strict-markers` happy by declaring every marker the
    test suite uses.
    """
    marker_lines = (
        "unit: Unit tests",
        "integration: Integration tests",
        "scenario: Real-world scenario tests",
        "slow: Slow tests that may take time",
    )
    for line in marker_lines:
        config.addinivalue_line("markers", line)
# =============================================================================
# Mock Provider Implementation
# =============================================================================
class MockStreamedMessage:
    """Mock streamed message for testing.

    Replays a fixed list of content parts through the async-iterator
    protocol and exposes a canned id and token usage.
    """

    def __init__(self, parts: list[ContentPart]):
        self._parts = parts
        self._id = "mock-msg-123"
        self._usage = TokenUsage(input_tokens=10, output_tokens=5)

    def __aiter__(self) -> AsyncIterator[ContentPart]:
        """Return an async iterator over the configured parts."""
        async def _replay() -> AsyncIterator[ContentPart]:
            for item in self._parts:
                yield item
        return _replay()

    @property
    def id(self) -> Optional[str]:
        """Identifier assigned to the mock message."""
        return self._id

    @property
    def usage(self) -> Optional[TokenUsage]:
        """Token accounting reported by the mock."""
        return self._usage
class MockProvider:
    """Mock provider for testing AgentLite without real API calls.

    This provider simulates OpenAI API responses and allows:
    - Configuring response sequences
    - Simulating tool calls
    - Simulating errors
    - Tracking all calls for verification

    Example:
        provider = MockProvider()
        provider.add_text_response("Hello!")
        provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
        agent = Agent(provider=provider)
        response = await agent.run("Hi")
        # Verify calls
        assert len(provider.calls) == 1
        assert provider.calls[0]["system_prompt"] == "You are helpful."
    """

    def __init__(self):
        # FIFO queue of canned responses consumed by generate().
        self.responses: list[dict[str, Any]] = []
        # Record of every generate() invocation, for later assertions.
        self.calls: list[dict[str, Any]] = []
        self.model = "mock-model"

    def add_text_response(self, text: str) -> None:
        """Add a text response to the queue."""
        self.responses.append({"type": "text", "content": text})

    def add_text_responses(self, *texts: str) -> None:
        """Add multiple text responses to the queue."""
        for text in texts:
            self.add_text_response(text)

    def add_tool_call(self, name: str, arguments: dict[str, Any], result: str) -> None:
        """Add a tool call response to the queue."""
        # NOTE(review): `result` is stored but never consumed by generate();
        # the agent's real tool produces the actual result during the loop.
        self.responses.append(
            {"type": "tool_call", "name": name, "arguments": arguments, "result": result}
        )

    def add_tool_calls(self, calls: list[dict[str, Any]]) -> None:
        """Add multiple tool calls to the queue."""
        for call in calls:
            self.add_tool_call(call["name"], call["arguments"], call.get("result", ""))

    def add_error(self, error: Exception) -> None:
        """Add an error response to the queue."""
        self.responses.append({"type": "error", "error": error})

    def clear_responses(self) -> None:
        """Clear all pending responses."""
        self.responses.clear()

    @property
    def model_name(self) -> str:
        """Model name."""
        return self.model

    async def generate(
        self,
        system_prompt: str,
        tools: Sequence[Tool],
        history: Sequence[Message],
    ) -> StreamedMessage:
        """Generate a mock response.

        Records the call, then pops and materializes the next queued
        response; with an empty queue, returns a default text message.
        """
        self.calls.append(
            {
                "system_prompt": system_prompt,
                "tools": list(tools),
                "history": list(history),
            }
        )
        if not self.responses:
            return MockStreamedMessage([TextPart(text="Mock response")])
        response = self.responses.pop(0)
        if response["type"] == "error":
            # Simulates a provider failure by raising the queued exception.
            raise response["error"]
        elif response["type"] == "text":
            return MockStreamedMessage([TextPart(text=response["content"])])
        elif response["type"] == "tool_call":
            return MockStreamedMessage(
                [
                    ToolCall(
                        id="call_123",
                        function=ToolCall.FunctionBody(
                            name=response["name"], arguments=json.dumps(response["arguments"])
                        ),
                    )
                ]
            )
        else:
            return MockStreamedMessage([TextPart(text="Unknown response type")])
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture
def mock_provider():
    """Create a mock provider with no responses configured."""
    return MockProvider()


@pytest.fixture
def mock_provider_with_response():
    """Create a mock provider that returns a simple text response."""
    provider = MockProvider()
    provider.add_text_response("Hello!")
    return provider


@pytest.fixture
def mock_provider_with_sequence():
    """Create a mock provider with multiple responses configured."""
    provider = MockProvider()
    provider.add_text_responses("Response 1", "Response 2", "Response 3")
    return provider


# =============================================================================
# Message Fixtures
# =============================================================================


@pytest.fixture
def sample_text_message():
    """Create a sample text message."""
    return Message(role="user", content="Hello!")


@pytest.fixture
def sample_assistant_message():
    """Create a sample assistant message."""
    return Message(role="assistant", content="Hi there!")


@pytest.fixture
def sample_tool_call():
    """Create a sample tool call."""
    # Arguments are a JSON string, matching the OpenAI tool-call wire format.
    return ToolCall(
        id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
    )


@pytest.fixture
def sample_tool_message():
    """Create a sample tool response message."""
    # tool_call_id links this result back to sample_tool_call's id.
    return Message(role="tool", content="3", tool_call_id="call_123")
# =============================================================================
# Tool Fixtures
# =============================================================================
@pytest.fixture
def add_tool():
    """Create a simple add tool."""

    @tool()
    async def add(a: float, b: float) -> float:
        """Add two numbers."""
        return a + b

    return add


@pytest.fixture
def multiply_tool():
    """Create a multiply tool."""

    @tool()
    async def multiply(a: float, b: float) -> float:
        """Multiply two numbers."""
        return a * b

    return multiply


@pytest.fixture
def error_tool():
    """Create a tool that raises an error."""

    @tool()
    async def error() -> str:
        """Always raises an error."""
        # Used to verify the agent's tool-failure handling path.
        raise ValueError("Test error")

    return error


@pytest.fixture
def slow_tool():
    """Create a tool that takes some time."""

    @tool()
    async def slow_operation(duration: float = 0.1) -> str:
        """Simulate a slow operation."""
        await asyncio.sleep(duration)
        return f"Completed after {duration}s"

    return slow_operation
# =============================================================================
# Agent Fixtures
# =============================================================================
# NOTE(review): the agent fixtures below are `async def` but contain no
# awaits — they rely on pytest-asyncio resolving async fixtures; confirm
# the plugin is configured, or drop `async` since it is not needed.
@pytest.fixture
async def simple_agent(mock_provider):
    """Create a simple agent with mocked provider."""
    return Agent(provider=mock_provider)


@pytest.fixture
async def agent_with_tools(mock_provider, add_tool):
    """Create an agent with tools."""
    return Agent(provider=mock_provider, tools=[add_tool])


@pytest.fixture
async def agent_with_multiple_tools(mock_provider, add_tool, multiply_tool):
    """Create an agent with multiple tools."""
    return Agent(provider=mock_provider, tools=[add_tool, multiply_tool])


# =============================================================================
# Utility Fixtures
# =============================================================================


@pytest.fixture
def sample_conversation():
    """Create a sample conversation history."""
    return [
        Message(role="user", content="Hello!"),
        Message(role="assistant", content="Hi there! How can I help?"),
        Message(role="user", content="What is 2+2?"),
        Message(role="assistant", content="2+2=4"),
    ]


@pytest.fixture
def event_loop():
    """Create an instance of the default event loop for each test case."""
    # NOTE(review): overriding `event_loop` is deprecated in pytest-asyncio
    # >= 0.23 — confirm the pinned plugin version before relying on this.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()

View File

View File

@@ -0,0 +1,286 @@
"""Integration tests for Agent class.
This module tests the Agent class with mocked providers to verify
core functionality without making real API calls.
"""
from __future__ import annotations
import pytest
from agentlite import Agent
@pytest.mark.integration
class TestAgentInitialization:
    """Tests for Agent initialization."""

    def test_agent_initialization(self, mock_provider):
        """Test basic agent creation."""
        agent = Agent(provider=mock_provider)
        assert agent.provider is mock_provider
        # These two asserts pin the Agent defaults (prompt, 80 iterations).
        assert agent.system_prompt == "You are a helpful assistant."
        assert agent.max_iterations == 80
        assert agent.history == []

    def test_agent_with_custom_system_prompt(self, mock_provider):
        """Test agent creation with custom system prompt."""
        agent = Agent(provider=mock_provider, system_prompt="You are a specialized assistant.")
        assert agent.system_prompt == "You are a specialized assistant."

    def test_agent_with_tools(self, mock_provider, add_tool):
        """Test agent creation with tools."""
        agent = Agent(provider=mock_provider, tools=[add_tool])
        assert len(agent.tools.tools) == 1
        assert agent.tools.tools[0].name == "add"

    def test_agent_with_custom_max_iterations(self, mock_provider):
        """Test agent with custom max_iterations."""
        agent = Agent(provider=mock_provider, max_iterations=5)
        assert agent.max_iterations == 5
@pytest.mark.integration
class TestAgentRun:
    """Tests for Agent.run() method."""

    @pytest.mark.asyncio
    async def test_agent_run_simple(self, mock_provider):
        """Test simple non-streaming run."""
        mock_provider.add_text_response("Hello there!")
        agent = Agent(provider=mock_provider)
        response = await agent.run("Hi")
        assert response == "Hello there!"

    @pytest.mark.asyncio
    async def test_agent_run_adds_to_history(self, mock_provider):
        """Test that run adds messages to history."""
        mock_provider.add_text_response("Response!")
        agent = Agent(provider=mock_provider)
        await agent.run("Hello")
        # History should have user message and assistant response
        assert len(agent.history) == 2
        assert agent.history[0].role == "user"
        assert agent.history[0].extract_text() == "Hello"
        assert agent.history[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_agent_run_multiple_messages(self, mock_provider):
        """Test multiple runs accumulate history."""
        mock_provider.add_text_responses("Response 1", "Response 2")
        agent = Agent(provider=mock_provider)
        await agent.run("Message 1")
        await agent.run("Message 2")
        # Should have 4 messages total, alternating user/assistant.
        assert len(agent.history) == 4
        assert agent.history[0].role == "user"
        assert agent.history[1].role == "assistant"
        assert agent.history[2].role == "user"
        assert agent.history[3].role == "assistant"

    @pytest.mark.asyncio
    async def test_agent_run_tracks_calls(self, mock_provider):
        """Test that provider.generate is called during run."""
        mock_provider.add_text_response("Response!")
        agent = Agent(provider=mock_provider)
        await agent.run("Hello")
        assert len(mock_provider.calls) == 1
        call = mock_provider.calls[0]
        # The default system prompt and the single user message are forwarded.
        assert call["system_prompt"] == "You are a helpful assistant."
        assert len(call["history"]) == 1  # User message
@pytest.mark.integration
class TestAgentGenerate:
    """Tests for Agent.generate() method."""

    @pytest.mark.asyncio
    async def test_agent_generate_returns_message(self, mock_provider):
        """Test that generate returns a Message."""
        mock_provider.add_text_response("Generated response")
        agent = Agent(provider=mock_provider)
        message = await agent.generate("Hello")
        assert message.role == "assistant"
        assert message.extract_text() == "Generated response"

    @pytest.mark.asyncio
    async def test_agent_generate_without_tool_loop(self, mock_provider):
        """Test that generate doesn't do tool calling loop."""
        # Add tool call response
        mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
        agent = Agent(provider=mock_provider, tools=[])
        message = await agent.generate("Calculate 1+2")
        # Should return the tool call without executing it — generate() is a
        # single provider round-trip, unlike run().
        assert message.has_tool_calls()
        assert len(message.tool_calls) == 1
        assert message.tool_calls[0].function.name == "add"

    @pytest.mark.asyncio
    async def test_agent_generate_adds_to_history(self, mock_provider):
        """Test that generate adds response to history."""
        mock_provider.add_text_response("Response!")
        agent = Agent(provider=mock_provider)
        await agent.generate("Hello")
        assert len(agent.history) == 2
        assert agent.history[1].role == "assistant"
@pytest.mark.integration
class TestAgentHistory:
    """Tests for Agent history management."""

    @pytest.mark.asyncio
    async def test_agent_history_property_returns_copy(self, mock_provider):
        """Test that history property returns a copy."""
        mock_provider.add_text_response("Response!")
        agent = Agent(provider=mock_provider)
        await agent.run("Hello")
        history = agent.history
        history.clear()  # Modify the copy
        # Original should still have messages (property returned a copy).
        assert len(agent.history) == 2

    @pytest.mark.asyncio
    async def test_agent_clear_history(self, mock_provider):
        """Test clearing history."""
        mock_provider.add_text_response("Response!")
        agent = Agent(provider=mock_provider)
        await agent.run("Hello")
        agent.clear_history()
        assert agent.history == []

    @pytest.mark.asyncio
    async def test_agent_add_message(self, mock_provider):
        """Test manually adding a message."""
        agent = Agent(provider=mock_provider)
        from agentlite import Message
        agent.add_message(Message(role="user", content="Manual message"))
        assert len(agent.history) == 1
        assert agent.history[0].extract_text() == "Manual message"
@pytest.mark.integration
class TestAgentWithTools:
    """Tests for Agent with tools."""

    @pytest.mark.asyncio
    async def test_agent_with_tools_initialization(self, mock_provider, add_tool):
        """Test agent initialization with tools."""
        agent = Agent(
            provider=mock_provider, tools=[add_tool], system_prompt="You have access to tools."
        )
        assert len(agent.tools.tools) == 1
        # Run to verify tools are passed to provider
        mock_provider.add_text_response("I have tools available")
        await agent.run("Hello")
        # Check that tools were passed to provider
        assert len(mock_provider.calls) == 1
        assert len(mock_provider.calls[0]["tools"]) == 1

    @pytest.mark.asyncio
    async def test_agent_tool_call_execution(self, mock_provider, add_tool):
        """Test that agent executes tool calls."""
        # First response: tool call; second: text after the tool result —
        # exercising one full iteration of the tool-calling loop.
        mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
        mock_provider.add_text_response("The sum is 3")
        agent = Agent(provider=mock_provider, tools=[add_tool])
        response = await agent.run("What is 1+2?")
        assert "3" in response
        # Should have made 2 calls to provider
        assert len(mock_provider.calls) == 2
@pytest.mark.integration
class TestAgentMaxIterations:
    """Tests for max_iterations behavior."""

    @pytest.mark.asyncio
    async def test_agent_respects_max_iterations(self, mock_provider, add_tool):
        """Test that agent stops after max_iterations."""
        # Always return tool calls to trigger iteration limit
        for _ in range(10):
            mock_provider.add_tool_call("add", {"a": 1, "b": 2}, "3")
        agent = Agent(provider=mock_provider, tools=[add_tool], max_iterations=3)
        response = await agent.run("Calculate")
        # Should stop after max_iterations and report the limit to the caller.
        assert len(mock_provider.calls) <= 3
        assert "Maximum tool call iterations reached" in response

    @pytest.mark.asyncio
    async def test_agent_no_iterations_for_simple_response(self, mock_provider):
        """Test that simple responses don't count as iterations."""
        mock_provider.add_text_response("Simple response")
        agent = Agent(provider=mock_provider, max_iterations=1)
        response = await agent.run("Hello")
        assert response == "Simple response"
@pytest.mark.integration
class TestAgentStreaming:
    """Tests for streaming mode."""

    @pytest.mark.asyncio
    async def test_agent_run_streaming(self, mock_provider):
        """Test streaming run."""
        mock_provider.add_text_response("Streamed response")
        agent = Agent(provider=mock_provider)
        stream = await agent.run("Hello", stream=True)
        # Collect stream chunks; concatenated they equal the full response.
        chunks = []
        async for chunk in stream:
            chunks.append(chunk)
        assert len(chunks) > 0
        assert "".join(chunks) == "Streamed response"

    @pytest.mark.asyncio
    async def test_agent_streaming_adds_to_history(self, mock_provider):
        """Test that streaming adds messages to history."""
        mock_provider.add_text_response("Response")
        agent = Agent(provider=mock_provider)
        stream = await agent.run("Hello", stream=True)
        # The stream must be fully consumed before history is finalized.
        async for _ in stream:
            pass
        assert len(agent.history) == 2

View File

@@ -0,0 +1,347 @@
"""Integration tests for AgentLite with real API.
This script runs comprehensive tests against the real OpenAI API.
Requires OPENAI_API_KEY environment variable to be set.
Usage:
export OPENAI_API_KEY="sk-..."
python tests/integration/test_with_api.py
"""
import asyncio
import os
import sys
from pathlib import Path
import pytest
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
from agentlite import Agent, OpenAIProvider, LLMClient
from agentlite.skills import discover_skills, SkillTool, index_skills_by_name
from agentlite.tools import ConfigurableToolset
# Test configuration
TEST_MODEL = "gpt-4o-mini" # Use mini for cost efficiency
HAS_OPENAI_API_KEY = bool(os.environ.get("OPENAI_API_KEY"))
pytestmark = pytest.mark.skipif(
not HAS_OPENAI_API_KEY, reason="OPENAI_API_KEY is required to run integration tests"
)
def get_provider():
    """Build an OpenAIProvider from the OPENAI_API_KEY env var.

    Exits the process with status 1 (after printing setup hints) when the
    key is not configured.
    """
    key = os.environ.get("OPENAI_API_KEY")
    if key:
        return OpenAIProvider(api_key=key, model=TEST_MODEL)
    print("❌ OPENAI_API_KEY not set!")
    print("Please set your OpenAI API key:")
    print(" export OPENAI_API_KEY='sk-...'")
    sys.exit(1)
async def test_basic_agent():
    """Test 1: Basic Agent functionality.

    Returns True/False for the run_all_tests() summary instead of relying
    on pytest assertion failures alone.
    """
    print("\n" + "=" * 60)
    print("Test 1: Basic Agent Functionality")
    print("=" * 60)
    try:
        provider = get_provider()
        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant. Be concise.",
        )
        response = await agent.run("What is 2+2?")
        print(f"✅ Agent responded: {response[:100]}...")
        assert "4" in response, "Expected '4' in response"
        print("✅ Basic Agent test PASSED")
        return True
    except Exception as e:
        # Best-effort reporting: failures are printed and summarized, not raised.
        print(f"❌ Basic Agent test FAILED: {e}")
        return False
async def test_agent_with_tools():
    """Test 2: Agent with tool suite."""
    print("\n" + "=" * 60)
    print("Test 2: Agent with Tool Suite")
    print("=" * 60)
    try:
        from agentlite.tools import ToolSuiteConfig
        provider = get_provider()
        # Create toolset with file tools rooted at the current directory.
        config = ToolSuiteConfig()
        toolset = ConfigurableToolset(config, work_dir=Path.cwd())
        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant with file access.",
            tools=toolset.tools,
        )
        print(f"✅ Agent created with {len(agent.tools.tools)} tools")
        # Test simple query (without requiring file access)
        # NOTE(review): the response content is not asserted here — the test
        # only verifies the call completes without raising.
        response = await agent.run("List the Python files in the current directory")
        print(f"✅ Agent with tools responded: {response[:100]}...")
        print("✅ Agent with Tools test PASSED")
        return True
    except Exception as e:
        print(f"❌ Agent with Tools test FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def test_llm_client():
    """Test 3: LLMClient functionality."""
    print("\n" + "=" * 60)
    print("Test 3: LLMClient Functionality")
    print("=" * 60)
    try:
        provider = get_provider()
        client = LLMClient(provider=provider)
        # Direct completion (no agent loop, no tools).
        response = await client.complete(
            user_prompt="What is the capital of France?",
            system_prompt="You are a helpful assistant. Be concise.",
        )
        print(f"✅ LLMClient responded: {response.content[:100]}...")
        assert "Paris" in response.content, "Expected 'Paris' in response"
        print("✅ LLMClient test PASSED")
        return True
    except Exception as e:
        print(f"❌ LLMClient test FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def test_llm_streaming():
    """Test 4: LLM streaming."""
    print("\n" + "=" * 60)
    print("Test 4: LLM Streaming")
    print("=" * 60)
    try:
        provider = get_provider()
        client = LLMClient(provider=provider)
        chunks = []
        # Consume the stream chunk by chunk; chunks concatenate to the reply.
        async for chunk in client.stream(
            user_prompt="Count from 1 to 3",
            system_prompt="You are a helpful assistant.",
        ):
            chunks.append(chunk)
            print(f" Chunk: {chunk[:20]}...")
        full_response = "".join(chunks)
        print(f"✅ Streamed response: {full_response[:100]}...")
        print("✅ LLM Streaming test PASSED")
        return True
    except Exception as e:
        print(f"❌ LLM Streaming test FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def test_subagents():
    """Test 5: Subagent functionality.

    Only verifies wiring (registration + Task tool) — no LLM round-trip
    through the subagent is performed.
    """
    print("\n" + "=" * 60)
    print("Test 5: Subagent Functionality")
    print("=" * 60)
    try:
        from agentlite.tools.multiagent.task import Task
        provider = get_provider()
        # Create parent agent
        parent = Agent(
            provider=provider,
            system_prompt="You are a coordinator agent.",
            name="coordinator",
        )
        # Create subagent
        coder = Agent(
            provider=provider,
            system_prompt="You are a coding specialist. Write clean, simple code.",
            name="coder",
        )
        # Add subagent to parent
        parent.add_subagent("coder", coder, "Writes code")
        # Add Task tool so the parent can delegate via its labor market.
        parent.tools.add(Task(labor_market=parent.labor_market))
        print(f"✅ Created parent with {len(parent.labor_market)} subagent(s)")
        print(f" Subagents: {parent.labor_market.list_subagents()}")
        print("✅ Subagent test PASSED")
        return True
    except Exception as e:
        print(f"❌ Subagent test FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def test_skills():
    """Test 6: Skills functionality.

    Treats a missing/empty skills directory as a soft pass (returns True
    with a warning) so the suite works on partial checkouts.
    """
    print("\n" + "=" * 60)
    print("Test 6: Skills Functionality")
    print("=" * 60)
    try:
        # Discover example skills shipped in the repository.
        skills_dir = Path(__file__).parent.parent.parent / "examples" / "skills"
        if not skills_dir.exists():
            print("⚠️ Skills directory not found, skipping")
            return True
        skills = discover_skills(skills_dir)
        print(f"✅ Discovered {len(skills)} skill(s)")
        for skill in skills:
            print(f" - {skill.name} ({skill.type})")
        if len(skills) == 0:
            print("⚠️ No skills found, skipping further tests")
            return True
        # Test with agent
        provider = get_provider()
        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant.",
        )
        skill_index = index_skills_by_name(skills)
        skill_tool = SkillTool(skill_index, parent_agent=agent)
        agent.tools.add(skill_tool)
        print("✅ Added SkillTool to agent")
        print("✅ Skills test PASSED")
        return True
    except Exception as e:
        print(f"❌ Skills test FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def test_conversation_history():
    """Test 7: Conversation history."""
    print("\n" + "=" * 60)
    print("Test 7: Conversation History")
    print("=" * 60)
    try:
        provider = get_provider()
        agent = Agent(
            provider=provider,
            system_prompt="You are a helpful assistant.",
        )
        # First message establishes a fact the model must remember.
        response1 = await agent.run("My name is Alice")
        print(f"✅ Response 1: {response1[:50]}...")
        # Second message (should remember context from the same agent history)
        response2 = await agent.run("What is my name?")
        print(f"✅ Response 2: {response2[:50]}...")
        assert "Alice" in response2, "Expected agent to remember name"
        print("✅ Conversation History test PASSED")
        return True
    except Exception as e:
        print(f"❌ Conversation History test FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False
async def run_all_tests():
    """Run all integration tests.

    Returns:
        List of (test name, passed) tuples; empty when no API key is set.
    """
    print("\n" + "=" * 60)
    print("AgentLite Integration Tests with Real API")
    print("=" * 60)
    print(f"Model: {TEST_MODEL}")
    # Check API key up front so we fail fast with setup instructions.
    if not os.environ.get("OPENAI_API_KEY"):
        print("\n❌ OPENAI_API_KEY not set!")
        print("\nTo run these tests, set your OpenAI API key:")
        print(" export OPENAI_API_KEY='sk-...'")
        print("\nGet your API key from: https://platform.openai.com/api-keys")
        return []
    results = []
    # Run all tests sequentially; each returns True/False rather than raising.
    results.append(("Basic Agent", await test_basic_agent()))
    results.append(("Agent with Tools", await test_agent_with_tools()))
    results.append(("LLMClient", await test_llm_client()))
    results.append(("LLM Streaming", await test_llm_streaming()))
    results.append(("Subagents", await test_subagents()))
    results.append(("Skills", await test_skills()))
    results.append(("Conversation History", await test_conversation_history()))
    # Print summary
    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)
    passed = sum(1 for _, result in results if result)
    total = len(results)
    for name, result in results:
        status = "✅ PASSED" if result else "❌ FAILED"
        print(f"{status}: {name}")
    print(f"\n{passed}/{total} tests passed")
    if passed == total:
        print("\n🎉 All tests passed!")
    else:
        print(f"\n⚠️ {total - passed} test(s) failed")
    return results
if __name__ == "__main__":
    results = asyncio.run(run_all_tests())
    # Exit with error code if any tests failed (or none ran at all).
    if results and not all(r for _, r in results):
        sys.exit(1)

View File

View File

View File

@@ -0,0 +1,140 @@
"""Debug script to find CLI test hang cause."""
from __future__ import annotations
import os
import sys
import asyncio
sys.path.insert(0, "/home/tcmofashi/proj/l2d_backend/agentlite/src")
from agentlite import Agent, OpenAIProvider
from agentlite.tools.shell.shell import Shell, Params
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
async def test_shell_directly():
    """Test shell tool without agent.

    Establishes a baseline: if this hangs, the problem is the tool itself,
    not the agent loop.
    """
    print("\n=== Test 1: Shell tool directly ===")
    shell = Shell(timeout=10)
    # Use Params dataclass
    result = await shell(Params(command="echo 'Hello'", timeout=5))
    print(f"Result: {result}")
    # ToolError has no `output` attribute, hence the hasattr guard.
    print(f"Output: {result.output if hasattr(result, 'output') else result}")
    return True
async def test_agent_no_tools():
    """Test agent without tools.

    Second baseline: isolates the LLM round-trip from tool execution.
    """
    print("\n=== Test 2: Agent without tools ===")
    api_key = os.environ.get("SILICONFLOW_API_KEY")
    if not api_key:
        print("SILICONFLOW_API_KEY not set")
        return False
    provider = OpenAIProvider(
        api_key=api_key,
        base_url=SILICONFLOW_BASE_URL,
        model=SILICONFLOW_MODEL,
        timeout=30.0,
    )
    agent = Agent(
        provider=provider,
        system_prompt="Reply briefly in one word.",
        max_iterations=3,
    )
    print("Sending message to LLM...")
    try:
        # Outer wait_for guards against the suspected hang.
        response = await asyncio.wait_for(
            agent.run("Say hello."),
            timeout=60.0,
        )
        print(f"Response: {response[:100]}...")
        return True
    except asyncio.TimeoutError:
        print("TIMEOUT in agent without tools!")
        return False
async def test_agent_with_shell():
    """Reproduce the problematic case: an agent run with the shell tool.

    On timeout, dumps the tail of the conversation history to show where
    the loop stalled.
    """
    print("\n=== Test 3: Agent WITH shell tool ===")
    key = os.environ.get("SILICONFLOW_API_KEY")
    if not key:
        print("SILICONFLOW_API_KEY not set")
        return False
    agent = Agent(
        provider=OpenAIProvider(
            api_key=key,
            base_url=SILICONFLOW_BASE_URL,
            model=SILICONFLOW_MODEL,
            timeout=60.0,
        ),
        system_prompt="You are a shell assistant. Execute commands when asked. Keep responses brief.",
        tools=[Shell(timeout=10)],
        max_iterations=5,  # Limit iterations
    )
    print("Sending message with tool request...")
    print("This is where it might hang...")
    try:
        reply = await asyncio.wait_for(
            agent.run("Run 'echo test' and tell me the result."),
            timeout=120.0,
        )
    except asyncio.TimeoutError:
        print("TIMEOUT! Agent hung for 120 seconds")
        # Inspect the last few history entries to see what happened.
        print(f"\nHistory length: {len(agent.history)}")
        for idx, msg in enumerate(agent.history[-5:]):
            preview = str(msg.content)[:100] if msg.content else "None"
            print(f"  [{idx}] {msg.role}: {preview}...")
        return False
    print(f"Response: {reply}")
    return True
async def main():
    """Run the three debug tests in order and print a pass/fail summary."""
    print("=" * 60)
    print("CLI Debug Test - Finding the hang cause")
    print("=" * 60)
    # Table-driven: (display name, coroutine function), run sequentially.
    cases = [
        ("Shell directly", test_shell_directly),
        ("Agent no tools", test_agent_no_tools),
        ("Agent with shell", test_agent_with_shell),
    ]
    results = []
    for name, case in cases:
        outcome = await case()
        results.append((name, outcome))
        print(f"Result: {'PASS' if outcome else 'FAIL'}")
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results:
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"  {name}: {status}")
    print("=" * 60)
# Script entry point: run the async debug suite.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,221 @@
"""Debug script with detailed logging to find CLI test hang cause."""
from __future__ import annotations
import os
import sys
import asyncio
import logging
import time
sys.path.insert(0, "/home/tcmofashi/proj/l2d_backend/agentlite/src")
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("debug")
# SiliconFlow DeepSeek-V3 (known good function calling support)
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Pro/deepseek-ai/DeepSeek-V3.2"
# SECURITY: never commit API keys to source control. The key must be
# supplied via the SILICONFLOW_API_KEY environment variable; this constant
# is only the last-resort fallback and is intentionally empty. (A real key
# was previously hard-coded here and must be revoked/rotated.)
SILICONFLOW_API_KEY = ""
async def main():
    """Manually drive one agent tool-call loop with per-step logging/timing.

    Re-implements the agent's run loop inline (provider call -> stream
    collection -> history append -> tool execution) so every stage gets its
    own timeout and log lines, letting us pinpoint which stage hangs.
    """
    from agentlite import Agent, OpenAIProvider
    from agentlite.tools.shell.shell import Shell
    from agentlite.message import Message
    logger.info("=" * 60)
    logger.info("CLI Debug Test with DeepSeek-V3 (SiliconFlow)")
    logger.info("=" * 60)
    api_key = os.environ.get("SILICONFLOW_API_KEY") or SILICONFLOW_API_KEY
    if not api_key:
        logger.error("SILICONFLOW_API_KEY not set")
        return
    logger.info(f"Using model: {SILICONFLOW_MODEL}")
    provider = OpenAIProvider(
        api_key=api_key,
        base_url=SILICONFLOW_BASE_URL,
        model=SILICONFLOW_MODEL,
        timeout=30.0,
    )
    agent = Agent(
        provider=provider,
        system_prompt="You are a shell assistant. Execute commands when asked. Reply briefly.",
        tools=[Shell(timeout=10)],
        max_iterations=5,
    )
    start_time = time.time()
    message = "Run 'echo test' and tell me the result."
    logger.info("\n=== Starting Agent Run ===")
    logger.info(f"Message: {message}")
    logger.info(f"Max iterations: {agent.max_iterations}")
    logger.info(f"Tools: {[t.name for t in agent.tools.tools]}")
    # NOTE(review): seeds the conversation via the private _history attribute
    # instead of agent.run(); intentional here for step-by-step debugging.
    agent._history.append(Message(role="user", content=message))
    iterations = 0
    final_response = None
    while iterations < agent.max_iterations:
        iterations += 1
        elapsed = time.time() - start_time
        logger.info(f"\n{'=' * 50}")
        logger.info(f"ITERATION {iterations}/{agent.max_iterations} (elapsed: {elapsed:.1f}s)")
        logger.info(f"{'=' * 50}")
        # Step 1: Call Provider
        logger.info(">>> Step 1: Calling provider.generate()...")
        step_start = time.time()
        try:
            stream = await asyncio.wait_for(
                provider.generate(
                    system_prompt=agent.system_prompt,
                    tools=agent.tools.tools,
                    history=agent._history,
                ),
                timeout=60.0,
            )
            logger.info(f"<<< Provider returned stream in {time.time() - step_start:.2f}s")
        except asyncio.TimeoutError:
            logger.error("!!! Provider call TIMEOUT after 60s")
            final_response = "ERROR: Provider timeout"
            break
        # Step 2: Collect stream parts
        logger.info(">>> Step 2: Collecting stream parts...")
        step_start = time.time()
        from agentlite.message import TextPart, ToolCall, ContentPart
        response_parts = []
        tool_calls = []
        chunk_count = 0
        try:
            async for part in stream:
                chunk_count += 1
                if chunk_count % 10 == 0:
                    logger.debug(f"  Received chunk #{chunk_count}")
                if isinstance(part, ToolCall):
                    tool_calls.append(part)
                    logger.info(
                        f"  ToolCall received: {part.function.name if hasattr(part, 'function') else part}"
                    )
                elif isinstance(part, ContentPart):
                    response_parts.append(part)
                    if isinstance(part, TextPart):
                        logger.debug(f"  Text: {part.text[:50]}...")
            logger.info(
                f"<<< Stream finished in {time.time() - step_start:.2f}s, {chunk_count} chunks"
            )
        except asyncio.TimeoutError:
            logger.error("!!! Stream reading TIMEOUT")
            final_response = "ERROR: Stream timeout"
            break
        except Exception as e:
            logger.error(f"!!! Stream error: {type(e).__name__}: {e}")
            final_response = f"ERROR: Stream error - {e}"
            break
        # Extract text
        response_text = ""
        for part in response_parts:
            if isinstance(part, TextPart):
                response_text += part.text
        logger.info(f"Response text ({len(response_text)} chars): {response_text[:100]}...")
        logger.info(f"Tool calls: {len(tool_calls)}")
        # Add to history
        agent._history.append(
            Message(
                role="assistant",
                content=response_parts,
                tool_calls=tool_calls if tool_calls else None,
            )
        )
        # Step 3: Check if done (no tool calls means the model produced a final answer)
        if not tool_calls:
            elapsed = time.time() - start_time
            logger.info(f"\n=== Agent completed in {elapsed:.2f}s, {iterations} iterations ===")
            final_response = response_text
            break
        # Step 4: Execute tool calls
        logger.info(f"\n>>> Step 3: Executing {len(tool_calls)} tool calls...")
        step_start = time.time()
        for i, tc in enumerate(tool_calls):
            func_name = tc.function.name if hasattr(tc, "function") else str(tc)
            func_args = tc.function.arguments if hasattr(tc, "function") else ""
            logger.info(f"  Tool #{i + 1}: {func_name}")
            logger.info(f"  Args: {func_args[:200]}...")
            try:
                result = await asyncio.wait_for(
                    agent.tools.handle(tc),
                    timeout=30.0,
                )
                output = result.output if hasattr(result, "output") else str(result)
                is_error = result.is_error if hasattr(result, "is_error") else False
                logger.info(
                    f"  Result: is_error={is_error}, output_len={len(output) if output else 0}"
                )
                output_preview = output[:100] if output else "None"
                logger.info(f"  Output preview: {output_preview}...")
            except asyncio.TimeoutError:
                logger.error("  !!! Tool execution TIMEOUT")
                output = "Tool execution timed out"
                is_error = True
            except Exception as e:
                logger.error(f"  !!! Tool error: {type(e).__name__}: {e}")
                output = str(e)
                is_error = True
            # Add tool result to history
            agent._history.append(
                Message(
                    role="tool",
                    content=output,
                    tool_call_id=tc.id if hasattr(tc, "id") else f"tc_{i}",
                )
            )
        logger.info(f"<<< Tool execution finished in {time.time() - step_start:.2f}s")
        # Check overall timeout
        elapsed = time.time() - start_time
        if elapsed > 90:
            logger.warning(f"!!! Overall timeout approaching ({elapsed:.1f}s)")
            final_response = f"Timeout after {iterations} iterations"
            break
    # NOTE(review): if the final answer arrives exactly on the last allowed
    # iteration, this branch overwrites final_response with the
    # "Max iterations" message — likely unintended; verify.
    if iterations >= agent.max_iterations:
        logger.warning(f"!!! Max iterations reached ({agent.max_iterations})")
        final_response = f"Max iterations ({agent.max_iterations}) reached"
    logger.info(f"\n{'=' * 60}")
    logger.info("FINAL RESULT:")
    logger.info(f"{'=' * 60}")
    logger.info(f"{final_response}")
    logger.info(f"Total iterations: {iterations}")
    logger.info(f"Total time: {time.time() - start_time:.2f}s")
    logger.info(f"History length: {len(agent._history)}")
# Script entry point: run the manual agent-loop debug session.
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,344 @@
"""End-to-end test for complex CLI operations with real API.
This test simulates a realistic complex CLI task where an agent:
1. Explores project structure using shell commands
2. Searches for specific patterns using grep/glob
3. Reads relevant files
4. Creates analysis reports
Uses real SiliconFlow qwen3.5-397B API (requires SILICONFLOW_API_KEY env var).
"""
from __future__ import annotations
import asyncio
import os
import tempfile
from pathlib import Path
import pytest
from agentlite import Agent, OpenAIProvider
from agentlite.tools import (
ConfigurableToolset,
ToolSuiteConfig,
)
# =============================================================================
# Configuration from model_config.toml
# =============================================================================
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
def get_siliconflow_provider() -> OpenAIProvider | None:
    """Build an OpenAIProvider for the SiliconFlow API.

    Returns None when SILICONFLOW_API_KEY is unset or empty, so callers
    can skip real-API tests gracefully.
    """
    key = os.environ.get("SILICONFLOW_API_KEY")
    if not key:
        return None
    return OpenAIProvider(
        api_key=key,
        base_url=SILICONFLOW_BASE_URL,
        model=SILICONFLOW_MODEL,
    )
@pytest.fixture
def real_provider():
    """Create real SiliconFlow provider.

    Skips (rather than fails) the test when SILICONFLOW_API_KEY is not
    configured, so the suite stays green without credentials.
    """
    provider = get_siliconflow_provider()
    if provider is None:
        pytest.skip("SILICONFLOW_API_KEY not set")
    return provider
@pytest.fixture
def test_project():
    """Create a throwaway mock project tree (src/tests/docs) for agent tests.

    Yields the project root Path; the whole tree is deleted when the test
    finishes (TemporaryDirectory context).
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        project_dir = Path(tmpdir) / "test_project"
        project_dir.mkdir()
        # Create project structure
        (project_dir / "src").mkdir()
        (project_dir / "src" / "utils").mkdir()
        (project_dir / "tests").mkdir()
        (project_dir / "docs").mkdir()
        # Create source files
        (project_dir / "src" / "main.py").write_text('''"""Main module."""
from src.utils.helper import process_data
from src.utils.logger import setup_logger
def main():
    """Main entry point."""
    logger = setup_logger()
    data = [1, 2, 3, 4, 5]
    result = process_data(data)
    logger.info(f"Result: {result}")
    return result
if __name__ == "__main__":
    main()
''')
        (project_dir / "src" / "__init__.py").write_text('"""Source package."""')
        (project_dir / "src" / "utils" / "helper.py").write_text('''"""Helper utilities."""
def process_data(data: list) -> list:
    """Process input data."""
    return [x * 2 for x in data]
def validate_data(data: list) -> bool:
    """Validate data format."""
    return all(isinstance(x, (int, float)) for x in data)
''')
        (project_dir / "src" / "utils" / "logger.py").write_text('''"""Logging utilities."""
import logging
def setup_logger(name: str = "app") -> logging.Logger:
    """Setup application logger."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    return logger
''')
        (project_dir / "src" / "utils" / "__init__.py").write_text('"""Utils package."""')
        # Create test files
        (project_dir / "tests" / "test_helper.py").write_text('''"""Tests for helper module."""
from src.utils.helper import process_data, validate_data
def test_process_data():
    assert process_data([1, 2, 3]) == [2, 4, 6]
def test_validate_data():
    assert validate_data([1, 2, 3]) == True
    assert validate_data(["a", "b"]) == False
''')
        # Create documentation
        (project_dir / "docs" / "README.md").write_text("""# Test Project
A sample project for testing CLI operations.
## Structure
- `src/` - Source code
- `tests/` - Unit tests
- `docs/` - Documentation
""")
        (project_dir / "README.md").write_text("""# Test Project
Simple data processing project.
## Usage
```bash
python -m src.main
```
""")
        # yield (not return) so the TemporaryDirectory survives the test body
        yield project_dir
@pytest.mark.scenario
@pytest.mark.slow
class TestComplexCLITasks:
    """End-to-end tests with complex CLI operations.

    NOTE(review): these tests call the real SiliconFlow API and execute real
    shell commands inside the temporary project dir; they require
    SILICONFLOW_API_KEY and are slow and billed.
    """
    @pytest.mark.asyncio
    async def test_explore_project_structure(self, real_provider, test_project):
        """Test exploring project structure using CLI tools.
        Task: Use shell commands to explore the project structure,
        then summarize what files exist.
        """
        # Create toolset with Shell tool
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )
        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个项目分析助手。使用 Shell 工具执行命令来探索项目结构。"
                "请使用 find、ls、tree 等命令来了解项目。"
            ),
            max_iterations=5,  # Limit iterations to prevent hanging
        )
        # Add overall timeout to prevent infinite hanging
        try:
            response = await asyncio.wait_for(
                agent.run(
                    f"探索项目目录 {test_project} 的结构,列出所有文件和目录,并总结项目的组织方式。"
                ),
                timeout=120.0,  # 2 minute overall timeout
            )
        except asyncio.TimeoutError:
            pytest.fail("Agent timed out after 120 seconds - possible infinite loop")
        assert response, "Agent should return a response"
        print(f"\n[项目结构探索结果]:\n{response}\n")
        # Verify response mentions key files
        response_lower = response.lower()
        assert any(
            word in response_lower for word in ["src", "tests", "main.py", "helper", "logger"]
        ), "Response should mention project files"
    @pytest.mark.asyncio
    async def test_search_and_analyze_code(self, real_provider, test_project):
        """Test searching for patterns and analyzing code.
        Task: Use grep/glob to find specific patterns,
        read the files, and create an analysis report.
        """
        # Create toolset with all file tools
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                file_tools=ToolSuiteConfig().file_tools,
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )
        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个代码分析助手。使用 Glob、Grep、ReadFile 等工具来搜索和分析代码。"
                "请使用 Shell 工具执行 grep、find 等命令。"
            ),
        )
        response = await agent.run(
            f"在项目 {test_project} 中搜索所有包含 'def ' 的 Python 文件,"
            f"列出找到的函数定义,并创建一个函数清单文件保存到 {test_project}/functions.txt。"
        )
        assert response, "Agent should return a response"
        print(f"\n[代码搜索分析结果]:\n{response}\n")
        # Check if analysis file was created (best-effort: the model may not comply)
        functions_file = test_project / "functions.txt"
        if functions_file.exists():
            content = functions_file.read_text()
            print(f"\n[函数清单文件]:\n{content}\n")
            assert len(content) > 0, "Functions file should not be empty"
    @pytest.mark.asyncio
    async def test_complex_multi_step_task(self, real_provider, test_project):
        """Test a complex multi-step CLI task.
        Task:
        1. Find all Python files using shell
        2. Search for TODO comments using grep
        3. Read files with TODOs
        4. Create a summary report
        """
        # Add some TODO comments
        todo_file = test_project / "src" / "utils" / "todo_items.py"
        todo_file.write_text('''"""Module with TODO items."""
# TODO: Implement error handling
def risky_operation(data):
    """Perform a risky operation."""
    return data / 0  # This will fail
# TODO: Add caching mechanism
def expensive_computation(n):
    """Perform expensive computation."""
    return sum(range(n))
# FIXME: Memory leak in this function
def process_large_file(path):
    """Process a large file."""
    with open(path) as f:
        return f.read()
''')
        # Create comprehensive toolset
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                file_tools=ToolSuiteConfig().file_tools,
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )
        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个项目维护助手。"
                "使用 Shell 工具执行命令(如 find、grep、ls 等)。"
                "使用 ReadFile 读取文件内容。"
                "使用 WriteFile 创建新文件。"
                "请一步一步完成任务。"
            ),
        )
        response = await agent.run(
            f"请完成以下任务:\n"
            f"1. 使用 'find' 命令找出项目 {test_project} 中所有的 .py 文件\n"
            f"2. 使用 'grep' 命令搜索所有包含 'TODO''FIXME' 的行\n"
            f"3. 读取包含 TODO 的文件内容\n"
            f"4. 创建一个 TODO 报告文件,保存到 {test_project}/todo_report.txt"
        )
        assert response, "Agent should return a response"
        print(f"\n[复杂任务结果]:\n{response}\n")
        # Verify report was created (best-effort)
        report_file = test_project / "todo_report.txt"
        if report_file.exists():
            content = report_file.read_text()
            print(f"\n[TODO 报告]:\n{content}\n")
    @pytest.mark.asyncio
    async def test_shell_pipes_and_chains(self, real_provider, test_project):
        """Test complex shell commands with pipes and chains.
        Task: Use shell pipes to perform complex data processing.
        """
        toolset = ConfigurableToolset(
            config=ToolSuiteConfig(
                shell_tools=ToolSuiteConfig().shell_tools,
            ),
            work_dir=str(test_project),
        )
        agent = Agent(
            provider=real_provider,
            tools=toolset.tools,
            system_prompt=(
                "你是一个 Shell 命令专家。"
                "使用复杂的 Shell 命令(管道、重定向、条件执行等)来完成任务。"
            ),
        )
        response = await agent.run(
            f"在项目目录 {test_project} 中执行以下操作:\n"
            f"1. 使用 'find . -name \"*.py\" | xargs wc -l' 统计所有 Python 文件的总行数\n"
            f'2. 使用 \'grep -r "def " --include="*.py" | wc -l\' 统计函数定义数量\n'
            f"3. 使用 'ls -la' 查看目录详情\n"
            f"报告你的发现。"
        )
        assert response, "Agent should return a response"
        print(f"\n[Shell 管道命令结果]:\n{response}\n")
        # Verify response contains relevant information
        # NOTE(review): the first list entry is the empty string, which makes
        # this assert always pass — probably a lost character; verify intent.
        response_lower = response.lower()
        assert any(
            word in response_lower for word in ["", "line", "函数", "function", "文件", "file"]
        ), "Response should mention analysis results"

View File

@@ -0,0 +1,374 @@
"""End-to-end scenario test for file operations.
This test simulates a realistic scenario where an agent:
1. Reads a file
2. Explains its content
3. Creates a new file with analysis results
This is a meaningful e2e test that demonstrates the agent's ability to
orchestrate multiple tool calls in sequence.
"""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
import pytest
from agentlite import Agent, tool
# =============================================================================
# File Operation Tools
# =============================================================================
@tool()
async def read_file(file_path: str) -> str:
    """Read the content of a file.

    Args:
        file_path: Path to the file to read.

    Returns:
        The content of the file as a string.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    return Path(file_path).read_text()
@tool()
async def write_file(file_path: str, content: str) -> str:
    """Write content to a file, creating it if it doesn't exist.

    Args:
        file_path: Path to the file to write.
        content: Content to write to the file.

    Returns:
        Success message confirming the file was written.
    """
    target = Path(file_path)
    # Create parent directories if they don't exist
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content)
    return f"File successfully written to {file_path}"
@tool()
async def list_files(directory: str) -> str:
    """List all files in a directory.

    Args:
        directory: Path to the directory to list.

    Returns:
        A newline-separated list of file names in the directory.
    """
    return "\n".join(os.listdir(directory))
# =============================================================================
# E2E Test
# =============================================================================
@pytest.mark.scenario
class TestFileOperationsScenario:
    """End-to-end test for file read/write operations.

    Uses the mock_provider fixture: tool calls and text responses are
    scripted, so no network access is required.
    """
    @pytest.mark.asyncio
    async def test_read_explain_and_write(self, mock_provider):
        """Test a complete workflow: read file -> explain -> write results."""
        # Setup: Create a temporary file with content
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a source file to read
            source_file = os.path.join(tmpdir, "source.txt")
            source_content = """Project Overview
================
This is a sample project document for testing.
Features:
- Feature A: Does something useful
- Feature B: Does something else
- Feature C: The most important feature
Conclusion: This project demonstrates file operations.
"""
            with open(source_file, "w") as f:
                f.write(source_content)
            # Configure mock provider responses
            # The agent should:
            # 1. Read the file
            # 2. Summarize it
            # 3. Write the summary to a new file
            mock_provider.add_text_response(
                f"I'll read the file at {source_file} and analyze it for you."
            )
            # Create agent with file tools
            tools = [read_file, write_file, list_files]
            agent = Agent(
                provider=mock_provider,
                tools=tools,
                system_prompt="You are a helpful file analysis assistant.",
            )
            # Step 1: Agent reads and analyzes the file
            mock_provider.clear_responses()
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": source_file},
                source_content,
            )
            # Agent analyzes the content
            mock_provider.add_text_response(
                "I've read the file. It's a project overview document with 3 features. "
                "Let me create a summary file."
            )
            # Step 2: Agent writes summary to a new file
            summary_file = os.path.join(tmpdir, "summary.txt")
            expected_summary = """Project Summary
================
This is a sample project with 3 main features:
- Feature A, - Feature B, - Feature C
The most important feature is Feature C.
"""
            mock_provider.clear_responses()
            mock_provider.add_tool_call(
                "write_file",
                {
                    "file_path": summary_file,
                    "content": expected_summary,
                },
                f"File successfully written to {summary_file}",
            )
            mock_provider.add_text_response(f"I've created a summary at {summary_file}")
            # Execute the agent
            response = await agent.run(
                f"Please read {source_file}, analyze it, and create a summary file at {summary_file}"
            )
            # Verify the interaction
            assert "summary" in response.lower()
            # Verify the provider was called correctly
            assert len(mock_provider.calls) >= 1
    @pytest.mark.asyncio
    async def test_list_files_scenario(self, mock_provider):
        """Test listing files in a directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create some test files
            for i in range(3):
                with open(os.path.join(tmpdir, f"file{i}.txt"), "w") as f:
                    f.write(f"Content {i}")
            # Configure agent to list files
            mock_provider.add_tool_call(
                "list_files",
                {"directory": tmpdir},
                "file0.txt\nfile1.txt\nfile2.txt",
            )
            mock_provider.add_text_response(
                f"I found 3 files in {tmpdir}: file0.txt, file1.txt, file2.txt"
            )
            agent = Agent(
                provider=mock_provider,
                tools=[list_files],
                system_prompt="You are a file system assistant.",
            )
            response = await agent.run(f"List all files in {tmpdir}")
            assert "3 files" in response
    @pytest.mark.asyncio
    async def test_multi_step_file_workflow(self, mock_provider):
        """Test a complex multi-step file workflow.
        Scenario:
        1. List files in directory
        2. Read each file
        3. Create a combined report
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create test files
            files_content = {
                "report1.txt": "Sales increased by 20%",
                "report2.txt": "Customer satisfaction at 85%",
                "report3.txt": "Bug fixes: 15 resolved",
            }
            for name, content in files_content.items():
                with open(os.path.join(tmpdir, name), "w") as f:
                    f.write(content)
            # Configure agent responses for multi-step workflow
            tools = [read_file, write_file, list_files]
            # Step 1: List files
            mock_provider.add_tool_call(
                "list_files",
                {"directory": tmpdir},
                "report1.txt\nreport2.txt\nreport3.txt",
            )
            # Step 2: Read all files
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": os.path.join(tmpdir, "report1.txt")},
                "Sales increased by 20%",
            )
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": os.path.join(tmpdir, "report2.txt")},
                "Customer satisfaction at 85%",
            )
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": os.path.join(tmpdir, "report3.txt")},
                "Bug fixes: 15 resolved",
            )
            # Step 3: Write combined report
            combined_report = """Combined Report
================
1. Sales: Increased by 20%
2. Customer Satisfaction: 85%
3. Development: 15 bugs resolved
"""
            mock_provider.add_tool_call(
                "write_file",
                {
                    "file_path": os.path.join(tmpdir, "combined_report.txt"),
                    "content": combined_report,
                },
                f"File successfully written to {os.path.join(tmpdir, 'combined_report.txt')}",
            )
            mock_provider.add_text_response(
                "I've created a combined report summarizing all three reports."
            )
            agent = Agent(
                provider=mock_provider,
                tools=tools,
                system_prompt="You are a report analyst assistant.",
            )
            response = await agent.run(
                f"List all files in {tmpdir}, read them all, and create a combined report at combined_report.txt"
            )
            assert "combined report" in response.lower()
# =============================================================================
# Additional Tools for Extended Scenarios
# =============================================================================
@tool()
async def count_words(file_path: str) -> str:
    """Count the number of words in a file.

    Args:
        file_path: Path to the file to analyze.

    Returns:
        The word count as a string.
    """
    words = Path(file_path).read_text().split()
    return f"Word count: {len(words)}"
@tool()
async def append_to_file(file_path: str, content: str) -> str:
    """Append content to an existing file.

    Args:
        file_path: Path to the file to append to.
        content: Content to append.

    Returns:
        Success message.
    """
    # A leading newline separates the appended text from existing content.
    with open(file_path, "a") as handle:
        handle.write("\n" + content)
    return f"Content appended to {file_path}"
@pytest.mark.scenario
class TestExtendedFileOperations:
    """Extended scenarios with more file operations.

    Scripted via the mock_provider fixture; no real LLM calls are made.
    """
    @pytest.mark.asyncio
    async def test_read_count_and_append(self, mock_provider):
        """Test reading a file, counting words, and appending a note."""
        with tempfile.TemporaryDirectory() as tmpdir:
            source_file = os.path.join(tmpdir, "document.txt")
            with open(source_file, "w") as f:
                f.write("This is a test document with several words in it.")
            tools = [read_file, write_file, count_words, append_to_file]
            # Step 1: Read file
            mock_provider.add_tool_call(
                "read_file",
                {"file_path": source_file},
                "This is a test document with several words in it.",
            )
            # Step 2: Count words
            mock_provider.add_tool_call(
                "count_words",
                {"file_path": source_file},
                "Word count: 10",
            )
            # Step 3: Append analysis
            mock_provider.add_tool_call(
                "append_to_file",
                {
                    "file_path": source_file,
                    "content": "\n\n[Analysis] This document contains 10 words.",
                },
                f"Content appended to {source_file}",
            )
            mock_provider.add_text_response(
                "I've analyzed the document and appended the word count analysis."
            )
            agent = Agent(
                provider=mock_provider,
                tools=tools,
                system_prompt="You are a document analysis assistant.",
            )
            response = await agent.run(
                f"Read {source_file}, count its words, and append the word count as an analysis note"
            )
            assert "analyzed" in response.lower()

View File

@@ -0,0 +1,226 @@
"""End-to-end scenario test for file operations with real API.
This test simulates a realistic scenario where an agent:
1. Reads a file
2. Explains its content
3. Creates a new file with analysis results
Uses real SiliconFlow qwen3.5-397B API (requires SILICONFLOW_API_KEY env var).
"""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
import pytest
from agentlite import Agent, OpenAIProvider, tool
# =============================================================================
# Configuration from model_config.toml
# =============================================================================
# SiliconFlow API configuration (matches qwen35_397b in model_config.toml)
SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
SILICONFLOW_MODEL = "Qwen/Qwen3.5-397B-A17B"
def get_siliconflow_provider() -> OpenAIProvider | None:
    """Create OpenAIProvider for SiliconFlow API.

    Returns None if SILICONFLOW_API_KEY is not set (or empty), so callers
    can skip real-API tests.
    """
    key = os.environ.get("SILICONFLOW_API_KEY")
    if not key:
        return None
    return OpenAIProvider(
        api_key=key,
        base_url=SILICONFLOW_BASE_URL,
        model=SILICONFLOW_MODEL,
    )
# =============================================================================
# File Operation Tools
# =============================================================================
@tool()
async def read_file(file_path: str) -> str:
    """Read the content of a file.

    Args:
        file_path: Path to the file to read.

    Returns:
        The content of the file as a string.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    with open(file_path) as handle:
        text = handle.read()
    return text
@tool()
async def write_file(file_path: str, content: str) -> str:
    """Write content to a file, creating it if it doesn't exist.

    Args:
        file_path: Path to the file to write.
        content: Content to write to the file.

    Returns:
        Success message confirming the file was written.
    """
    destination = Path(file_path)
    # Create parent directories if they don't exist
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(content)
    return f"File successfully written to {file_path}"
@tool()
async def list_files(directory: str) -> str:
    """List all files in a directory.

    Args:
        directory: Path to the directory to list.

    Returns:
        A newline-separated list of file names in the directory.
    """
    entries = os.listdir(directory)
    return "\n".join(entries)
# =============================================================================
# Real API E2E Tests
# =============================================================================
@pytest.fixture
def real_provider():
    """Create a real SiliconFlow provider.

    Skip tests if SILICONFLOW_API_KEY is not set.
    """
    # Skipping (not failing) keeps the suite green in environments
    # without credentials, e.g. CI.
    provider = get_siliconflow_provider()
    if provider is None:
        pytest.skip("SILICONFLOW_API_KEY not set, skipping real API tests")
    return provider
@pytest.mark.scenario
@pytest.mark.expensive
class TestFileOperationsWithRealAPI:
    """End-to-end tests with real SiliconFlow API.

    NOTE(review): requires SILICONFLOW_API_KEY and makes real (billed) API
    calls; file-creation checks are best-effort because the model may not
    comply exactly.
    """
    @pytest.mark.asyncio
    async def test_read_and_summarize(self, real_provider):
        """Test reading a file and creating a summary with real API."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a source file with meaningful content
            source_file = os.path.join(tmpdir, "source.txt")
            source_content = """AgentLite 项目概述
================
AgentLite 是一个轻量级的 Agent 组件库,主要特点:
- 异步优先设计
- OpenAI 兼容 API
- 工具系统 (支持 MCP)
- 流式响应支持
使用示例:
```python
from agentlite import Agent, OpenAIProvider
provider = OpenAIProvider(api_key="...", model="gpt-4")
agent = Agent(provider=provider)
response = await agent.run("Hello!")
```
"""
            with open(source_file, "w") as f:
                f.write(source_content)
            # Create agent with file tools
            tools = [read_file, write_file, list_files]
            agent = Agent(
                provider=real_provider,
                tools=tools,
                system_prompt="你是一个文件分析助手。请使用工具来完成任务。",
            )
            # Run the agent to read, analyze, and write summary
            output_file = os.path.join(tmpdir, "summary.txt")
            response = await agent.run(
                f"请读取 {source_file} 文件,分析其内容,并创建一个摘要文件保存到 {output_file}"
            )
            # Verify the agent responded
            assert response, "Agent should return a response"
            print(f"\n[Agent 响应]:\n{response}\n")
            # Verify the output file was created (best-effort)
            if os.path.exists(output_file):
                with open(output_file) as f:
                    output_content = f.read()
                print(f"\n[输出文件内容]:\n{output_content}\n")
                assert len(output_content) > 0, "Output file should not be empty"
    @pytest.mark.asyncio
    async def test_list_files_and_combine(self, real_provider):
        """Test listing files, reading them, and creating combined report."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create multiple files
            files = {
                "sales.txt": "销售额增长了 20%",
                "users.txt": "用户满意度达到 85%",
                "bugs.txt": "修复了 15 个问题",
            }
            for name, content in files.items():
                with open(os.path.join(tmpdir, name), "w") as f:
                    f.write(content)
            # Create agent with file tools
            tools = [read_file, write_file, list_files]
            agent = Agent(
                provider=real_provider,
                tools=tools,
                system_prompt="你是一个数据分析助手。请使用工具来完成任务。",
            )
            # Run the agent
            report_file = os.path.join(tmpdir, "report.txt")
            response = await agent.run(
                f"列出 {tmpdir} 目录中的所有文件,读取每个文件的内容,然后创建一份综合报告保存到 {report_file}"
            )
            # Verify the agent responded
            assert response, "Agent should return a response"
            print(f"\n[Agent 响应]:\n{response}\n")
            # The agent should have created the report file (best-effort)
            if os.path.exists(report_file):
                with open(report_file) as f:
                    report_content = f.read()
                print(f"\n[报告文件内容]:\n{report_content}\n")
    @pytest.mark.asyncio
    async def test_simple_conversation(self, real_provider):
        """Test basic conversation without tools."""
        agent = Agent(
            provider=real_provider,
            system_prompt="你是一个有帮助的助手。请用中文回答。",
        )
        response = await agent.run("你好!请简单介绍一下你自己。")
        assert response, "Agent should return a response"
        print(f"\n[Agent 自我介绍]:\n{response}\n")
        assert len(response) > 10, "Response should be meaningful"

View File

@@ -0,0 +1,521 @@
"""工具专项测试 - 文档提取和知识图谱工具
本模块测试基于数据基底的工具功能,包括:
1. 文档读取和解析工具
2. 实体提取工具
3. 知识图谱查询工具
4. 推理工具
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import pytest
import yaml
from agentlite import Agent, tool
def tool_output(result: Any) -> Any:
    """Normalize a tool return value.

    Supports both legacy plain return values and ToolResult-style objects:
    when *result* exposes an ``output`` attribute, that attribute is
    returned; otherwise *result* itself is returned unchanged.
    """
    if hasattr(result, "output"):
        return result.output
    return result
# =============================================================================
# 数据加载 fixtures
# =============================================================================
@pytest.fixture
def data_dir() -> Path:
    """Return the path to the test data directory."""
    return Path(__file__).parent.parent / "data"
@pytest.fixture
def sample_article(data_dir: Path) -> str:
    """UTF-8 contents of the sample article document."""
    article_path = data_dir / "documents" / "sample_article.md"
    return article_path.read_text(encoding="utf-8")
@pytest.fixture
def technical_spec(data_dir: Path) -> str:
    """UTF-8 contents of the technical specification document."""
    spec_path = data_dir / "documents" / "technical_spec.md"
    return spec_path.read_text(encoding="utf-8")
@pytest.fixture
def meeting_notes(data_dir: Path) -> str:
    """UTF-8 contents of the meeting notes file."""
    notes_path = data_dir / "documents" / "meeting_notes.txt"
    return notes_path.read_text(encoding="utf-8")
@pytest.fixture
def knowledge_graph_entities(data_dir: Path) -> dict[str, Any]:
    """Knowledge-graph entity records loaded from ``entities.json``.

    Decoded explicitly as UTF-8: the previous code relied on the platform
    default encoding, which fails on non-UTF-8 locales (the data contains
    Chinese names).
    """
    raw = (data_dir / "knowledge_base" / "entities.json").read_text(encoding="utf-8")
    return json.loads(raw)
@pytest.fixture
def knowledge_graph_relations(data_dir: Path) -> dict[str, Any]:
    """Knowledge-graph relation records loaded from ``relations.json``.

    Decoded explicitly as UTF-8: the previous code relied on the platform
    default encoding, which fails on non-UTF-8 locales.
    """
    raw = (data_dir / "knowledge_base" / "relations.json").read_text(encoding="utf-8")
    return json.loads(raw)
@pytest.fixture
def graph_queries(data_dir: Path) -> list[dict[str, Any]]:
    """Graph-query test cases from ``graph_queries.yaml``.

    Opened explicitly as UTF-8 (the previous code relied on the platform
    default encoding). Returns an empty list when the file defines no
    ``queries`` key.
    """
    with open(data_dir / "knowledge_base" / "graph_queries.yaml", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return data.get("queries", [])
# =============================================================================
# 知识图谱工具实现
# =============================================================================
class KnowledgeGraph:
    """Simple in-memory knowledge-graph store.

    Entities are dicts with at least ``id``, ``type`` and ``name`` keys;
    relations are dicts with ``from``, ``to`` and ``type`` keys. Secondary
    indexes by type and by (exact) name are built once at construction.
    """

    def __init__(self, entities: list[dict], relations: list[dict]):
        self._entities = {e["id"]: e for e in entities}
        self._relations = relations
        # Secondary indexes: type -> [entity ids], name -> entity id.
        # On duplicate names the last entity wins (same as before).
        self._index_by_type: dict[str, list[str]] = {}
        self._index_by_name: dict[str, str] = {}
        for entity_id, entity in self._entities.items():
            entity_type = entity.get("type", "Unknown")
            self._index_by_type.setdefault(entity_type, []).append(entity_id)
            name = entity.get("name", "")
            if name:
                self._index_by_name[name] = entity_id

    def get_entity(self, entity_id: str) -> dict | None:
        """Return the entity with ``entity_id``, or None if unknown."""
        return self._entities.get(entity_id)

    def get_entity_by_name(self, name: str) -> dict | None:
        """Return the entity whose name matches exactly, or None."""
        entity_id = self._index_by_name.get(name)
        return self._entities.get(entity_id) if entity_id else None

    def get_entities_by_type(self, entity_type: str) -> list[dict]:
        """Return all entities of the given type (empty list if none)."""
        entity_ids = self._index_by_type.get(entity_type, [])
        return [self._entities[eid] for eid in entity_ids if eid in self._entities]

    def get_relations(
        self, from_id: str | None = None, to_id: str | None = None, relation_type: str | None = None
    ) -> list[dict]:
        """Return relations matching every filter that was supplied.

        A filter left as None (or any falsy value) does not constrain the
        result, matching the original truthiness-based behavior.
        """
        return [
            rel
            for rel in self._relations
            if (not from_id or rel.get("from") == from_id)
            and (not to_id or rel.get("to") == to_id)
            and (not relation_type or rel.get("type") == relation_type)
        ]

    def get_neighbors(self, entity_id: str, relation_type: str | None = None) -> list[dict]:
        """Return outgoing neighbors as ``{"entity": ..., "relation": ...}`` dicts.

        Relations pointing at ids with no entity record are skipped.
        """
        neighbors = []
        for rel in self.get_relations(from_id=entity_id, relation_type=relation_type):
            target_id = rel.get("to")
            if target_id and target_id in self._entities:
                neighbors.append({"entity": self._entities[target_id], "relation": rel})
        return neighbors

    def find_path(self, start_id: str, end_id: str, max_depth: int = 3) -> list[list[str]] | None:
        """Find all simple directed paths from ``start_id`` to ``end_id``.

        Paths are limited to at most ``max_depth`` edges; each is returned as
        a list of entity ids. Returns None when no path exists.

        Fix: the previous BFS pruned with a single global ``visited`` set
        while collecting *all* paths, so valid alternate paths through an
        already-visited intermediate node were silently dropped. This DFS
        tracks visited nodes per path (simple paths only, so no cycles).
        """
        if start_id == end_id:
            return [[start_id]]
        if max_depth <= 0:
            return None

        all_paths: list[list[str]] = []

        def _dfs(current_id: str, path: list[str]) -> None:
            # `path` currently holds len(path) - 1 edges.
            if len(path) - 1 >= max_depth:
                return
            for rel in self.get_relations(from_id=current_id):
                next_id = rel.get("to")
                if not next_id or next_id in path:  # keep paths simple
                    continue
                if next_id == end_id:
                    all_paths.append(path + [next_id])
                else:
                    _dfs(next_id, path + [next_id])

        _dfs(start_id, [start_id])
        return all_paths or None
@pytest.fixture
def knowledge_graph(knowledge_graph_entities, knowledge_graph_relations) -> KnowledgeGraph:
    """A KnowledgeGraph built from the JSON fixture data."""
    entity_records = knowledge_graph_entities.get("entities", [])
    relation_records = knowledge_graph_relations.get("relations", [])
    return KnowledgeGraph(entities=entity_records, relations=relation_records)
# =============================================================================
# 工具定义
# =============================================================================
@tool()
async def read_document(file_path: str) -> str:
    """Read a document and return its text.

    Args:
        file_path: Path of the document to read.

    Returns:
        The UTF-8 file contents, or an ``Error ...`` string when the file is
        missing or unreadable (errors are reported as strings, not raised).
    """
    doc = Path(file_path)
    if not doc.exists():
        return f"Error: File not found: {file_path}"
    try:
        return doc.read_text(encoding="utf-8")
    except Exception as e:
        return f"Error reading file: {e}"
@tool()
async def extract_entities(text: str) -> str:
    """Extract entities from text.

    Args:
        text: Input text.

    Returns:
        JSON-encoded list of ``{"type", "name"}`` entity dicts.

    Note:
        Simplified keyword/regex extraction for tests; a real implementation
        would use an NLP model. Fix: person names are now de-duplicated
        (repeated mentions used to yield duplicate entities, unlike the
        org/tech branches), and org/tech ordering is made deterministic.
    """
    import re

    entities = []
    # Person names: runs of 2-4 CJK characters, kept when whitelisted or
    # exactly 3 characters long (rough heuristic for Chinese full names).
    person_pattern = r"[\u4e00-\u9fa5]{2,4}"
    potential_names = re.findall(person_pattern, text)
    common_names = ["张三", "李四", "王五", "赵六", "李飞飞", "吴恩达", "Yann LeCun"]
    seen_persons: set[str] = set()
    for name in potential_names:
        if name in seen_persons:
            continue  # fix: one entity per distinct name
        if name in common_names or len(name) == 3:
            seen_persons.add(name)
            entities.append({"type": "Person", "name": name})
    # Company/organization names (dict.fromkeys dedupes while preserving
    # first-seen order, unlike the previous unordered set()).
    org_pattern = r"(TechCorp|OpenAI|GitHub|Google)"
    for org in dict.fromkeys(re.findall(org_pattern, text)):
        entities.append({"type": "Organization", "name": org})
    # Technology terms.
    tech_pattern = r"(Python|TensorFlow|PyTorch|GPT-4|AI|LLM)"
    for tech in dict.fromkeys(re.findall(tech_pattern, text)):
        entities.append({"type": "Technology", "name": tech})
    return json.dumps(entities, ensure_ascii=False)
@tool()
async def query_knowledge_graph(query_type: str, params: str) -> str:
    """Query the knowledge graph.

    Args:
        query_type: Kind of query (person_relations, company_employees,
            technology_users, etc.).
        params: JSON-encoded query parameters.

    Returns:
        JSON-encoded query result.
    """
    # NOTE: relies on a global kg instance in the real design; it should be
    # injected when the Agent is initialized.
    try:
        parsed = json.loads(params)
    except json.JSONDecodeError:
        return json.dumps({"error": "Invalid JSON params"})
    # Placeholder implementation -- a real version would query the graph.
    payload = {"query_type": query_type, "params": parsed, "results": []}
    return json.dumps(payload, ensure_ascii=False)
@tool()
async def reason_about_path(start_entity: str, end_entity: str) -> str:
    """Reason about the relation path between two entities.

    Args:
        start_entity: Name of the starting entity.
        end_entity: Name of the target entity.

    Returns:
        JSON-encoded reasoning result (placeholder: empty path).
    """
    payload = {
        "start": start_entity,
        "end": end_entity,
        "reasoning": f"分析 {start_entity} 与 {end_entity} 的关系链...",
        "path": [],
    }
    return json.dumps(payload, ensure_ascii=False)
# =============================================================================
# 测试用例
# =============================================================================
@pytest.mark.tools
class TestDocumentTools:
    """Tests for the document tools."""

    @pytest.mark.asyncio
    async def test_read_document(self, data_dir: Path, sample_article: str):
        """The read tool returns the article's full text."""
        article_path = data_dir / "documents" / "sample_article.md"
        content = tool_output(await read_document(str(article_path)))
        for expected in ("人工智能", "GitHub Copilot", "张三"):
            assert expected in content

    @pytest.mark.asyncio
    async def test_read_document_not_found(self):
        """Reading a missing file yields an error string, not an exception."""
        message = tool_output(await read_document("/nonexistent/file.md"))
        assert "Error" in message
        assert "not found" in message.lower()

    @pytest.mark.asyncio
    async def test_extract_entities_from_article(self, sample_article: str):
        """Entity extraction finds the expected names in the article."""
        extracted = json.loads(tool_output(await extract_entities(sample_article)))
        assert len(extracted) > 0
        names = [entity["name"] for entity in extracted]
        assert "张三" in names
        assert "TechCorp" in names or "OpenAI" in names
@pytest.mark.tools
class TestKnowledgeGraphTools:
    """Tests for the KnowledgeGraph store and the KG tools."""

    def test_knowledge_graph_initialization(self, knowledge_graph: KnowledgeGraph):
        """Known entities can be looked up by name after construction."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None
        assert person["type"] == "Person"
        company = knowledge_graph.get_entity_by_name("TechCorp")
        assert company is not None
        assert company["type"] == "Company"

    def test_get_entities_by_type(self, knowledge_graph: KnowledgeGraph):
        """Type index returns all entities of a type."""
        # At least 张三 / 李四 / 李飞飞 ...
        assert len(knowledge_graph.get_entities_by_type("Person")) >= 3
        # ... and at least Python / OpenAI API.
        assert len(knowledge_graph.get_entities_by_type("Technology")) >= 2

    def test_get_relations(self, knowledge_graph: KnowledgeGraph):
        """Outgoing relations carry the expected types."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None
        outgoing = knowledge_graph.get_relations(from_id=person["id"])
        assert len(outgoing) >= 2
        kinds = [rel["type"] for rel in outgoing]
        assert "works_for" in kinds
        assert "uses" in kinds

    def test_get_neighbors(self, knowledge_graph: KnowledgeGraph):
        """Neighbor expansion includes the employer entity."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None
        neighbors = knowledge_graph.get_neighbors(person["id"])
        assert len(neighbors) >= 2
        assert "TechCorp" in [n["entity"]["name"] for n in neighbors]

    def test_find_path(self, knowledge_graph: KnowledgeGraph):
        """A direct edge yields a two-node path."""
        person = knowledge_graph.get_entity_by_name("张三")
        company = knowledge_graph.get_entity_by_name("TechCorp")
        assert person is not None
        assert company is not None
        paths = knowledge_graph.find_path(person["id"], company["id"])
        assert paths is not None
        assert len(paths) > 0
        assert len(paths[0]) == 2  # direct edge: person -> company

    @pytest.mark.asyncio
    async def test_query_knowledge_graph(self):
        """The query tool echoes its query type and parameters."""
        query_params = json.dumps({"entity_name": "张三"})
        raw = tool_output(await query_knowledge_graph("person_relations", query_params))
        payload = json.loads(raw)
        assert payload["query_type"] == "person_relations"
        assert "params" in payload

    @pytest.mark.asyncio
    async def test_reason_about_path(self):
        """The reasoning tool echoes both endpoints and a reasoning field."""
        payload = json.loads(tool_output(await reason_about_path("张三", "OpenAI")))
        assert payload["start"] == "张三"
        assert payload["end"] == "OpenAI"
        assert "reasoning" in payload
@pytest.mark.tools
class TestDataIntegrity:
    """Sanity checks on the checked-in test data files."""

    def test_entities_json_valid(self, knowledge_graph_entities: dict):
        """Every entity record carries the required fields."""
        assert "entities" in knowledge_graph_entities
        records = knowledge_graph_entities["entities"]
        assert len(records) > 0
        for record in records:
            for field in ("id", "type", "name"):
                assert field in record

    def test_relations_json_valid(
        self, knowledge_graph_relations: dict, knowledge_graph_entities: dict
    ):
        """Relations are well-formed and only reference existing entities."""
        assert "relations" in knowledge_graph_relations
        known_ids = {e["id"] for e in knowledge_graph_entities["entities"]}
        for relation in knowledge_graph_relations["relations"]:
            for field in ("from", "to", "type"):
                assert field in relation
            assert relation["from"] in known_ids, f"Entity {relation['from']} not found"
            assert relation["to"] in known_ids, f"Entity {relation['to']} not found"

    def test_graph_queries_yaml_valid(self, graph_queries: list):
        """Every query case declares the expected keys."""
        assert len(graph_queries) > 0
        for query in graph_queries:
            for key in ("id", "description", "query", "expected_results"):
                assert key in query

    def test_documents_exist(self, data_dir: Path):
        """All test documents exist and are non-empty."""
        docs_dir = data_dir / "documents"
        for filename in ("sample_article.md", "technical_spec.md", "meeting_notes.txt"):
            doc = docs_dir / filename
            assert doc.exists()
            assert doc.stat().st_size > 0
@pytest.mark.tools
class TestAgentWithTools:
    """Integration tests: Agent wired up with the tools above."""

    @pytest.mark.asyncio
    async def test_agent_with_document_tools(self, mock_provider, data_dir: Path):
        """An agent holding read_document answers a document request."""
        mock_provider.add_text_response("我已经读取了文档")
        doc_agent = Agent(
            provider=mock_provider,
            tools=[read_document],
            system_prompt="你是一个文档助手,可以读取和分析文档。",
        )
        response = await doc_agent.run(f"请读取文档 {data_dir / 'documents' / 'sample_article.md'}")
        assert "文档" in response or "读取" in response

    @pytest.mark.asyncio
    async def test_agent_with_kg_tools(self, mock_provider):
        """An agent holding the KG tools produces a non-empty answer."""
        mock_provider.add_text_response("张三在 TechCorp 工作")
        kg_agent = Agent(
            provider=mock_provider,
            tools=[query_knowledge_graph, reason_about_path],
            system_prompt="你是一个知识图谱助手,可以查询实体关系。",
        )
        response = await kg_agent.run("张三在哪里工作?")
        assert response is not None
        assert len(response) > 0

View File

View File

@@ -0,0 +1,330 @@
"""Unit tests for configuration models.
This module tests all Pydantic configuration models including
ProviderConfig, ModelConfig, ToolConfig, and AgentConfig.
"""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from agentlite import ProviderConfig, ModelConfig, AgentConfig
class TestProviderConfig:
    """Unit tests for ProviderConfig."""

    def test_provider_config_valid(self):
        """A fully specified config keeps every field."""
        cfg = ProviderConfig(
            type="openai",
            base_url="https://api.openai.com/v1",
            api_key="sk-test123",
        )
        assert cfg.type == "openai"
        assert cfg.base_url == "https://api.openai.com/v1"
        assert cfg.api_key.get_secret_value() == "sk-test123"

    def test_provider_config_default_type(self):
        """The provider type defaults to openai."""
        cfg = ProviderConfig(base_url="https://api.openai.com/v1", api_key="sk-test")
        assert cfg.type == "openai"

    def test_provider_config_default_url(self):
        """base_url defaults to the OpenAI endpoint."""
        cfg = ProviderConfig(api_key="sk-test")
        assert cfg.base_url == "https://api.openai.com/v1"

    def test_provider_config_invalid_url_http(self):
        """Non-HTTP(S) URL schemes are rejected with a clear message."""
        with pytest.raises(ValidationError) as exc_info:
            ProviderConfig(type="openai", base_url="ftp://invalid.com", api_key="sk-test")
        assert "base_url must start with http:// or https://" in str(exc_info.value)

    def test_provider_config_invalid_url_no_scheme(self):
        """A URL without any scheme is rejected."""
        with pytest.raises(ValidationError):
            ProviderConfig(base_url="api.openai.com/v1", api_key="sk-test")

    def test_provider_config_custom_headers(self):
        """Caller-supplied headers are preserved verbatim."""
        cfg = ProviderConfig(api_key="sk-test", headers={"X-Custom": "value"})
        assert cfg.headers == {"X-Custom": "value"}

    def test_provider_config_default_headers(self):
        """Headers default to an empty mapping."""
        cfg = ProviderConfig(api_key="sk-test")
        assert cfg.headers == {}

    def test_provider_config_timeout(self):
        """An explicit timeout is stored as given."""
        cfg = ProviderConfig(api_key="sk-test", timeout=30.0)
        assert cfg.timeout == 30.0

    def test_provider_config_default_timeout(self):
        """Timeout defaults to 60 seconds."""
        cfg = ProviderConfig(api_key="sk-test")
        assert cfg.timeout == 60.0

    def test_provider_config_api_key_is_secret_str(self):
        """The API key is a SecretStr: masked in str(), readable on demand."""
        cfg = ProviderConfig(api_key="sk-secret")
        assert "sk-secret" not in str(cfg.api_key)
        assert cfg.api_key.get_secret_value() == "sk-secret"
class TestModelConfig:
    """Unit tests for ModelConfig."""

    def test_model_config_valid(self):
        """A minimal provider/model pair is accepted."""
        cfg = ModelConfig(provider="openai", model="gpt-4")
        assert cfg.provider == "openai"
        assert cfg.model == "gpt-4"

    def test_model_config_with_all_fields(self):
        """All optional tuning fields round-trip."""
        cfg = ModelConfig(
            provider="openai",
            model="gpt-4",
            max_tokens=1000,
            temperature=0.7,
            top_p=0.9,
            capabilities={"streaming", "tool_calling"},
        )
        assert cfg.max_tokens == 1000
        assert cfg.temperature == 0.7
        assert cfg.top_p == 0.9
        assert cfg.capabilities == {"streaming", "tool_calling"}

    def test_model_config_empty_provider(self):
        """An empty provider string is rejected with a clear message."""
        with pytest.raises(ValidationError) as exc_info:
            ModelConfig(provider="", model="gpt-4")
        assert "provider must not be empty" in str(exc_info.value)

    def test_model_config_temperature_bounds(self):
        """Temperature accepts the inclusive range [0.0, 2.0] only."""
        assert ModelConfig(provider="openai", model="gpt-4", temperature=0.0).temperature == 0.0
        assert ModelConfig(provider="openai", model="gpt-4", temperature=2.0).temperature == 2.0
        for out_of_range in (-0.1, 2.1):
            with pytest.raises(ValidationError):
                ModelConfig(provider="openai", model="gpt-4", temperature=out_of_range)

    def test_model_config_top_p_bounds(self):
        """top_p accepts the inclusive range [0.0, 1.0] only."""
        assert ModelConfig(provider="openai", model="gpt-4", top_p=0.0).top_p == 0.0
        assert ModelConfig(provider="openai", model="gpt-4", top_p=1.0).top_p == 1.0
        for out_of_range in (-0.1, 1.1):
            with pytest.raises(ValidationError):
                ModelConfig(provider="openai", model="gpt-4", top_p=out_of_range)

    def test_model_config_default_capabilities(self):
        """Capabilities default to the empty set."""
        assert ModelConfig(provider="openai", model="gpt-4").capabilities == set()
class TestAgentConfig:
    """Unit tests for AgentConfig."""

    @staticmethod
    def _single_provider() -> dict:
        """Fresh providers mapping with one OpenAI entry (key 'sk-test')."""
        return {"openai": ProviderConfig(api_key="sk-test")}

    def test_agent_config_minimal(self):
        """Only providers/models are required; the rest have defaults."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
        )
        assert cfg.name == "agent"
        assert cfg.system_prompt == "You are a helpful assistant."
        assert cfg.default_model == "default"

    def test_agent_config_full(self):
        """Every field round-trips when set explicitly."""
        cfg = AgentConfig(
            name="my_agent",
            system_prompt="Custom system prompt",
            providers=self._single_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
            max_history=50,
        )
        assert cfg.name == "my_agent"
        assert cfg.system_prompt == "Custom system prompt"
        assert cfg.default_model == "gpt4"
        assert cfg.max_history == 50

    def test_agent_config_missing_default_model(self):
        """default_model must name an entry of models."""
        with pytest.raises(ValidationError) as exc_info:
            AgentConfig(
                providers=self._single_provider(),
                models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
                default_model="nonexistent",
            )
        assert "not found in models" in str(exc_info.value)

    def test_agent_config_unknown_provider(self):
        """A model may only reference a declared provider."""
        with pytest.raises(ValidationError) as exc_info:
            AgentConfig(
                providers=self._single_provider(),
                models={"default": ModelConfig(provider="unknown", model="gpt-4")},
            )
        assert "unknown provider" in str(exc_info.value)

    def test_agent_config_get_provider_config(self):
        """get_provider_config resolves an explicit model name."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        assert cfg.get_provider_config("gpt4").api_key.get_secret_value() == "sk-test"

    def test_agent_config_get_provider_config_default(self):
        """get_provider_config falls back to the default model."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        assert cfg.get_provider_config().api_key.get_secret_value() == "sk-test"

    def test_agent_config_get_provider_config_not_found(self):
        """An unknown model name raises ValueError."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
        )
        with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
            cfg.get_provider_config("nonexistent")

    def test_agent_config_get_model_config(self):
        """get_model_config resolves an explicit model name."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        assert cfg.get_model_config("gpt4").model == "gpt-4"

    def test_agent_config_get_model_config_default(self):
        """get_model_config falls back to the default model."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"gpt4": ModelConfig(provider="openai", model="gpt-4")},
            default_model="gpt4",
        )
        assert cfg.get_model_config().model == "gpt-4"

    def test_agent_config_get_model_config_not_found(self):
        """An unknown model name raises ValueError."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
        )
        with pytest.raises(ValueError, match="Model 'nonexistent' not found"):
            cfg.get_model_config("nonexistent")

    def test_agent_config_multiple_providers(self):
        """Several providers and models can coexist."""
        cfg = AgentConfig(
            providers={
                "openai": ProviderConfig(api_key="sk-openai"),
                "anthropic": ProviderConfig(
                    type="anthropic",
                    base_url="https://api.anthropic.com/v1",
                    api_key="sk-anthropic",
                ),
            },
            models={
                "default": ModelConfig(provider="openai", model="gpt-4"),
                "claude": ModelConfig(provider="anthropic", model="claude-3"),
            },
        )
        assert len(cfg.providers) == 2
        assert len(cfg.models) == 2

    def test_agent_config_max_history_validation(self):
        """max_history must be a positive integer (minimum 1)."""
        cfg = AgentConfig(
            providers=self._single_provider(),
            models={"default": ModelConfig(provider="openai", model="gpt-4")},
            max_history=1,
        )
        assert cfg.max_history == 1
        for invalid in (0, -1):
            with pytest.raises(ValidationError):
                AgentConfig(
                    providers=self._single_provider(),
                    models={"default": ModelConfig(provider="openai", model="gpt-4")},
                    max_history=invalid,
                )

View File

@@ -0,0 +1,297 @@
"""Unit tests for message types.
This module tests all message-related types including ContentPart,
Message, ToolCall, and their various subclasses.
"""
from __future__ import annotations
import pytest
from agentlite import (
ContentPart,
Message,
TextPart,
ImageURLPart,
AudioURLPart,
ToolCall,
ToolCallPart,
)
class TestContentPart:
    """Tests for the ContentPart base class and its subclass registry."""

    def test_content_part_registry_auto_registers_subclasses(self):
        """Defining a subclass registers it under its type tag."""
        registry = ContentPart._ContentPart__content_part_registry
        for tag in ("text", "image_url", "audio_url"):
            assert tag in registry

    def test_text_part_creation(self):
        """TextPart stores its text and reports the 'text' type."""
        chunk = TextPart(text="Hello, world!")
        assert chunk.type == "text"
        assert chunk.text == "Hello, world!"

    def test_text_part_model_dump(self):
        """Serialization yields exactly the type/text mapping."""
        assert TextPart(text="Hello").model_dump() == {"type": "text", "text": "Hello"}

    def test_text_part_merge_success(self):
        """Streaming merge appends the incoming text in place."""
        head = TextPart(text="Hello ")
        tail = TextPart(text="world!")
        assert head.merge_in_place(tail) is True
        assert head.text == "Hello world!"

    def test_text_part_merge_failure(self):
        """Merging an incompatible value fails and leaves the part untouched."""
        chunk = TextPart(text="Hello")
        assert chunk.merge_in_place("not a part") is False
        assert chunk.text == "Hello"
class TestImageURLPart:
    """Tests for ImageURLPart."""

    def test_image_url_part_creation(self):
        """The URL is stored and the type tag is 'image_url'."""
        pic = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
        assert pic.type == "image_url"
        assert pic.image_url.url == "https://example.com/image.png"

    def test_image_url_part_with_detail(self):
        """An explicit detail level is kept."""
        pic = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="https://example.com/image.png", detail="high")
        )
        assert pic.image_url.detail == "high"

    def test_image_url_part_default_detail(self):
        """detail defaults to None when not supplied."""
        pic = ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png"))
        assert pic.image_url.detail is None
class TestAudioURLPart:
    """Tests for AudioURLPart."""

    def test_audio_url_part_creation(self):
        """The URL is stored and the type tag is 'audio_url'."""
        clip = AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3"))
        assert clip.type == "audio_url"
        assert clip.audio_url.url == "https://example.com/audio.mp3"
class TestToolCall:
    """Tests for ToolCall."""

    def test_tool_call_creation(self):
        """All fields round-trip; the type tag is always 'function'."""
        invocation = ToolCall(
            id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1, "b": 2}')
        )
        assert invocation.type == "function"
        assert invocation.id == "call_123"
        assert invocation.function.name == "add"
        assert invocation.function.arguments == '{"a": 1, "b": 2}'

    def test_tool_call_merge_with_part(self):
        """Merging a ToolCallPart appends streamed argument text in place."""
        invocation = ToolCall(
            id="call_123", function=ToolCall.FunctionBody(name="add", arguments='{"a": 1')
        )
        delta = ToolCallPart(arguments_part=', "b": 2}')
        assert invocation.merge_in_place(delta) is True
        assert invocation.function.arguments == '{"a": 1, "b": 2}'

    def test_tool_call_merge_failure(self):
        """Merging an incompatible value is rejected."""
        invocation = ToolCall(
            id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}")
        )
        assert invocation.merge_in_place("not a part") is False
class TestToolCallPart:
    """Tests for ToolCallPart."""

    def test_tool_call_part_creation(self):
        """The streamed argument fragment is stored verbatim."""
        delta = ToolCallPart(arguments_part='{"a": 1}')
        assert delta.arguments_part == '{"a": 1}'

    def test_tool_call_part_none(self):
        """A None fragment is allowed."""
        delta = ToolCallPart(arguments_part=None)
        assert delta.arguments_part is None

    def test_tool_call_part_merge(self):
        """Merging concatenates the fragments in place."""
        head = ToolCallPart(arguments_part='{"a":')
        tail = ToolCallPart(arguments_part=" 1}")
        assert head.merge_in_place(tail) is True
        assert head.arguments_part == '{"a": 1}'

    def test_tool_call_part_merge_none(self):
        """Merging into a None fragment adopts the incoming text."""
        head = ToolCallPart(arguments_part=None)
        tail = ToolCallPart(arguments_part='{"a": 1}')
        assert head.merge_in_place(tail) is True
        assert head.arguments_part == '{"a": 1}'
class TestMessage:
    """Tests for Message."""

    def test_message_string_content_coercion(self):
        """Plain-string content becomes a single TextPart."""
        note = Message(role="user", content="Hello!")
        assert len(note.content) == 1
        assert isinstance(note.content[0], TextPart)
        assert note.content[0].text == "Hello!"

    def test_message_part_content(self):
        """A single ContentPart is wrapped in a one-element list."""
        note = Message(role="user", content=TextPart(text="Hello!"))
        assert len(note.content) == 1
        assert note.content[0].text == "Hello!"

    def test_message_list_content(self):
        """A list of parts is stored as-is."""
        note = Message(role="user", content=[TextPart(text="Hello"), TextPart(text=" world!")])
        assert len(note.content) == 2

    def test_message_extract_text(self):
        """extract_text concatenates the text parts."""
        assert Message(role="user", content="Hello world!").extract_text() == "Hello world!"

    def test_message_extract_text_with_separator(self):
        """extract_text honours a custom separator between parts."""
        note = Message(role="user", content=[TextPart(text="Hello"), TextPart(text="world!")])
        assert note.extract_text(sep=" ") == "Hello world!"
        assert note.extract_text(sep="-") == "Hello-world!"

    def test_message_has_tool_calls_false(self):
        """No tool_calls field means has_tool_calls() is False."""
        assert Message(role="assistant", content="Hello!").has_tool_calls() is False

    def test_message_has_tool_calls_true(self):
        """A populated tool_calls list flips has_tool_calls() to True."""
        invocation = ToolCall(
            id="call_123", function=ToolCall.FunctionBody(name="add", arguments="{}")
        )
        note = Message(
            role="assistant", content="Let me calculate that.", tool_calls=[invocation]
        )
        assert note.has_tool_calls() is True

    def test_message_has_tool_calls_empty_list(self):
        """An empty tool_calls list still counts as no tool calls."""
        note = Message(role="assistant", content="Hello!", tool_calls=[])
        assert note.has_tool_calls() is False

    def test_message_tool_response(self):
        """Tool-role messages carry the originating tool_call_id."""
        note = Message(role="tool", content="Result: 42", tool_call_id="call_123")
        assert note.role == "tool"
        assert note.tool_call_id == "call_123"

    def test_message_serialization(self):
        """model_dump exposes role and content."""
        dumped = Message(role="user", content="Hello!").model_dump()
        assert dumped["role"] == "user"
        assert "content" in dumped

    def test_message_all_roles(self):
        """Every supported role constructs cleanly."""
        for role in ["system", "user", "assistant", "tool"]:
            assert Message(role=role, content="Test").role == role
class TestPolymorphicContentPart:
    """Tests for polymorphic validation on ContentPart."""

    def test_polymorphic_validation_text(self):
        """A payload tagged 'text' validates into TextPart."""
        validated = ContentPart.model_validate({"type": "text", "text": "Hello"})
        assert isinstance(validated, TextPart)
        assert validated.text == "Hello"

    def test_polymorphic_validation_image(self):
        """A payload tagged 'image_url' validates into ImageURLPart."""
        payload = {"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}
        validated = ContentPart.model_validate(payload)
        assert isinstance(validated, ImageURLPart)
        assert validated.image_url.url == "https://example.com/image.png"

    def test_polymorphic_validation_unknown_type(self):
        """An unregistered type tag raises with a clear message."""
        with pytest.raises(ValueError, match="Unknown content part type"):
            ContentPart.model_validate({"type": "unknown_type", "content": "test"})

    def test_polymorphic_validation_no_type(self):
        """A payload without any type tag is rejected."""
        with pytest.raises(ValueError):
            ContentPart.model_validate({"content": "test"})
class TestMessageEdgeCases:
    """Edge cases in Message handling."""

    def test_empty_string_content(self):
        """An empty string still becomes one (empty) TextPart."""
        note = Message(role="user", content="")
        assert note.content[0].text == ""

    def test_message_with_name(self):
        """The optional name field round-trips."""
        note = Message(role="user", content="Hello", name="user1")
        assert note.name == "user1"

    def test_message_history_isolation(self):
        """Appending to `content` mutates the message in place (no defensive copy)."""
        note = Message(role="user", content="Hello")
        note.content.append(TextPart(text="Extra"))
        # The content list is the live object, so the append is visible.
        assert len(note.content) == 2

View File

@@ -0,0 +1,166 @@
"""Unit tests for provider protocol and exceptions.
This module tests the ChatProvider protocol, StreamedMessage protocol,
and all exception types.
"""
from __future__ import annotations
from agentlite.provider import (
TokenUsage,
ChatProviderError,
APIConnectionError,
APITimeoutError,
APIStatusError,
APIEmptyResponseError,
ChatProvider,
StreamedMessage,
)
class TestTokenUsage:
    """Tests for TokenUsage."""

    def test_token_usage_creation(self):
        """Input/output counts round-trip; cached defaults to 0."""
        usage = TokenUsage(input_tokens=100, output_tokens=50)
        assert usage.input_tokens == 100
        assert usage.output_tokens == 50
        assert usage.cached_tokens == 0

    def test_token_usage_with_cached(self):
        """An explicit cached-token count is stored."""
        usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
        assert usage.cached_tokens == 20

    def test_token_usage_total(self):
        """total is input plus output."""
        assert TokenUsage(input_tokens=100, output_tokens=50).total == 150

    def test_token_usage_total_with_cached(self):
        """Cached tokens are tracked separately and excluded from total."""
        usage = TokenUsage(input_tokens=100, output_tokens=50, cached_tokens=20)
        assert usage.total == 150
class TestChatProviderError:
    """Construction and inheritance of the provider exception family."""

    def test_base_error_creation(self):
        """The base error keeps its message and uses it as str()."""
        exc = ChatProviderError("Something went wrong")
        assert exc.message == "Something went wrong"
        assert str(exc) == "Something went wrong"

    def test_api_connection_error(self):
        """Connection failures subclass the base error and keep the message."""
        exc = APIConnectionError("Connection failed")
        assert isinstance(exc, ChatProviderError)
        assert exc.message == "Connection failed"

    def test_api_timeout_error(self):
        """Timeouts subclass the base error and keep the message."""
        exc = APITimeoutError("Request timed out")
        assert isinstance(exc, ChatProviderError)
        assert exc.message == "Request timed out"

    def test_api_status_error(self):
        """Status errors record both the HTTP code and the message."""
        exc = APIStatusError(429, "Rate limit exceeded")
        assert isinstance(exc, ChatProviderError)
        assert exc.status_code == 429
        assert exc.message == "Rate limit exceeded"

    def test_api_status_error_different_codes(self):
        """Any HTTP status code round-trips through the error unchanged."""
        for code in (400, 401, 403, 404, 429, 500, 502, 503):
            assert APIStatusError(code, f"Error {code}").status_code == code

    def test_api_empty_response_error(self):
        """Empty-response failures subclass the base error."""
        exc = APIEmptyResponseError("Empty response from API")
        assert isinstance(exc, ChatProviderError)
        assert exc.message == "Empty response from API"

    def test_exception_hierarchy(self):
        """Every concrete error type derives from ChatProviderError."""
        samples = (
            APIConnectionError("test"),
            APITimeoutError("test"),
            APIStatusError(500, "test"),
            APIEmptyResponseError("test"),
        )
        for sample in samples:
            assert isinstance(sample, ChatProviderError)
class TestChatProviderProtocol:
    """Structural checks against the ChatProvider protocol."""

    def test_protocol_is_runtime_checkable(self):
        """ChatProvider carries the @runtime_checkable protocol metadata."""
        assert hasattr(ChatProvider, "__protocol_attrs__")

    def test_mock_provider_implements_protocol(self, mock_provider):
        """The shared mock provider satisfies the protocol structurally."""
        assert isinstance(mock_provider, ChatProvider)

    def test_mock_provider_has_model_name(self, mock_provider):
        """A string ``model_name`` attribute is exposed."""
        assert hasattr(mock_provider, "model_name")
        assert isinstance(mock_provider.model_name, str)

    def test_mock_provider_has_generate_method(self, mock_provider):
        """A callable ``generate`` attribute is exposed."""
        assert hasattr(mock_provider, "generate")
        assert callable(mock_provider.generate)
class TestStreamedMessageProtocol:
    """Structural checks against the StreamedMessage protocol."""

    def test_protocol_is_runtime_checkable(self):
        """StreamedMessage carries the @runtime_checkable protocol metadata."""
        assert hasattr(StreamedMessage, "__protocol_attrs__")

    def test_mock_streamed_message_implements_protocol(self):
        """The mock stream satisfies the protocol structurally."""
        from tests.conftest import MockStreamedMessage
        from agentlite import TextPart

        mock_stream = MockStreamedMessage([TextPart(text="Hello")])
        assert isinstance(mock_stream, StreamedMessage)

    def test_streamed_message_has_id_property(self):
        """The stream exposes its fixed mock id."""
        from tests.conftest import MockStreamedMessage
        from agentlite import TextPart

        mock_stream = MockStreamedMessage([TextPart(text="Hello")])
        assert hasattr(mock_stream, "id")
        assert mock_stream.id == "mock-msg-123"

    def test_streamed_message_has_usage_property(self):
        """The stream exposes a populated TokenUsage instance."""
        from tests.conftest import MockStreamedMessage
        from agentlite import TextPart

        mock_stream = MockStreamedMessage([TextPart(text="Hello")])
        assert hasattr(mock_stream, "usage")
        assert mock_stream.usage is not None
        assert isinstance(mock_stream.usage, TokenUsage)

    def test_streamed_message_is_async_iterable(self):
        """The stream implements ``__aiter__``."""
        from tests.conftest import MockStreamedMessage
        from agentlite import TextPart

        mock_stream = MockStreamedMessage([TextPart(text="Hello")])
        assert hasattr(mock_stream, "__aiter__")

View File

@@ -0,0 +1,209 @@
"""Unit tests for tool decorator and CallableTool.
This module tests the @tool() decorator and related tool functionality.
"""
from __future__ import annotations
import pytest
from agentlite.tool import tool, CallableTool, ToolOk, ToolError
class TestToolDecorator:
    """Schema-generation behaviour of the @tool() decorator."""

    def test_tool_decorator_basic(self):
        """A bare @tool() wraps an async function into a CallableTool."""

        @tool()
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        assert isinstance(add, CallableTool)
        assert add.name == "add"
        assert add.description == "Add two numbers."
        schema = add.parameters
        assert schema["type"] == "object"
        props = schema["properties"]
        assert "a" in props
        assert "b" in props
        assert props["a"]["type"] == "number"
        assert props["b"]["type"] == "number"
        assert schema["required"] == ["a", "b"]

    def test_tool_decorator_with_default_params(self):
        """Parameters with defaults are optional in the schema."""

        @tool()
        async def greet(name: str, greeting: str = "Hello") -> str:
            """Greet someone."""
            return f"{greeting}, {name}!"

        required = greet.parameters["required"]
        assert greet.name == "greet"
        assert "name" in required
        assert "greeting" not in required

    def test_tool_decorator_custom_name(self):
        """An explicit name overrides the function name."""

        @tool(name="custom_add")
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        assert add.name == "custom_add"

    def test_tool_decorator_custom_description(self):
        """An explicit description overrides the docstring."""

        @tool(description="Custom description")
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        assert add.description == "Custom description"

    def test_tool_decorator_no_docstring(self):
        """A docstring-less function gets the fallback description."""

        @tool()
        async def no_doc(a: float) -> float:
            return a

        assert no_doc.description == "No description provided"

    def test_tool_decorator_param_types(self):
        """Python scalar annotations map onto JSON-schema types."""

        @tool()
        async def multi_types(
            s: str,
            i: int,
            f: float,
            b: bool,
        ) -> dict:
            """Multiple types."""
            return {"s": s, "i": i, "f": f, "b": b}

        expected = {"s": "string", "i": "integer", "f": "number", "b": "boolean"}
        props = multi_types.parameters["properties"]
        for param, json_type in expected.items():
            assert props[param]["type"] == json_type

    def test_tool_decorator_no_type_hints(self):
        """An unannotated parameter falls back to the string type."""

        @tool()
        async def no_types(param) -> str:
            """No type hints."""
            return str(param)

        assert no_types.parameters["properties"]["param"]["type"] == "string"
class TestToolDecoratorExecution:
    """Executing decorated tools and inspecting their results."""

    @pytest.mark.asyncio
    async def test_tool_execution_success(self):
        """A successful call is wrapped in ToolOk with stringified output."""

        @tool()
        async def add(a: float, b: float) -> float:
            """Add two numbers."""
            return a + b

        outcome = await add(1.0, 2.0)
        assert isinstance(outcome, ToolOk)
        assert outcome.output == "3.0"

    @pytest.mark.asyncio
    async def test_tool_execution_error(self):
        """An exception raised inside the tool surfaces as ToolError."""

        @tool()
        async def divide(a: float, b: float) -> float:
            """Divide two numbers."""
            return a / b

        outcome = await divide(1.0, 0.0)
        assert isinstance(outcome, ToolError)
        assert "division by zero" in outcome.message

    @pytest.mark.asyncio
    async def test_tool_execution_with_kwargs(self):
        """Keyword arguments are forwarded to the wrapped function."""

        @tool()
        async def greet(name: str, greeting: str = "Hello") -> str:
            """Greet someone."""
            return f"{greeting}, {name}!"

        outcome = await greet(name="World", greeting="Hi")
        assert isinstance(outcome, ToolOk)
        assert outcome.output == "Hi, World!"
class TestToolDecoratorMemorixBug:
    """Regression tests for the @tool() bug reported by the Memorix project."""

    def test_tool_decorator_memorix_case(self):
        """The exact Memorix signature decorates cleanly.

        An async function with a required string parameter and an optional
        float parameter must produce a correct CallableTool schema.
        """

        @tool()
        async def add_memory(content: str, importance: float = 0.5) -> dict:
            """存储记忆"""
            return {"status": "ok"}

        assert isinstance(add_memory, CallableTool)
        assert add_memory.name == "add_memory"
        assert add_memory.description == "存储记忆"

        schema = add_memory.parameters
        assert schema["type"] == "object"
        props = schema["properties"]
        assert "content" in props
        assert "importance" in props
        assert props["content"]["type"] == "string"
        assert props["importance"]["type"] == "number"
        # Only the parameter without a default value is required.
        assert "content" in schema["required"]
        assert "importance" not in schema["required"]

    @pytest.mark.asyncio
    async def test_tool_decorator_memorix_execution(self):
        """Calling the Memorix-style tool returns a ToolOk result."""

        @tool()
        async def add_memory(content: str, importance: float = 0.5) -> dict:
            """存储记忆"""
            return {"status": "ok", "content": content, "importance": importance}

        outcome = await add_memory("test content", 0.8)
        assert isinstance(outcome, ToolOk)
        assert "ok" in outcome.output

    def test_tool_decorator_can_be_used_in_agent(self):
        """Decorated tools expose the ``base`` attribute Agent relies on."""

        @tool()
        async def add_memory(content: str, importance: float = 0.5) -> dict:
            """存储记忆"""
            return {"status": "ok"}

        assert hasattr(add_memory, "base")
        underlying = add_memory.base
        assert underlying.name == "add_memory"
        assert underlying.description == "存储记忆"
        assert underlying.parameters == add_memory.parameters

98
agentlite/tests/utils.py Normal file
View File

@@ -0,0 +1,98 @@
"""Test utilities and helpers for AgentLite tests.
This module provides utility functions and helpers used across test modules.
"""
from __future__ import annotations
import asyncio
from typing import Any, TypeVar
T = TypeVar("T")
async def run_async(coro: asyncio.Coroutine[Any, Any, T]) -> T:
    """Await *coro* and hand back its result.

    Thin awaitable pass-through for tests that already run inside an
    event loop and just want the coroutine's value.

    Args:
        coro: Coroutine to await.

    Returns:
        Whatever the coroutine produces.
    """
    result = await coro
    return result
def run_sync(coro: asyncio.Coroutine[Any, Any, T]) -> T:
    """Drive *coro* to completion on a fresh event loop.

    Args:
        coro: Coroutine to execute synchronously.

    Returns:
        The coroutine's result.
    """
    outcome = asyncio.run(coro)
    return outcome
async def collect_stream(stream) -> list[Any]:
    """Drain an async iterable, returning every yielded item in order.

    Args:
        stream: Any object supporting ``async for``.

    Returns:
        A list of all items the stream produced.
    """
    return [element async for element in stream]
async def collect_stream_text(stream) -> str:
    """Concatenate the textual content of an async stream.

    TextPart items contribute their ``.text``; plain strings contribute
    themselves; anything else is ignored.

    Args:
        stream: Any object supporting ``async for``.

    Returns:
        All collected text joined into one string.
    """
    from agentlite import TextPart

    chunks: list[str] = []
    async for piece in stream:
        if isinstance(piece, TextPart):
            chunks.append(piece.text)
        elif isinstance(piece, str):
            chunks.append(piece)
    return "".join(chunks)
def create_tool_schema(
    name: str,
    description: str,
    properties: dict[str, Any],
    required: list[str] | None = None,
) -> dict[str, Any]:
    """Build a JSON-schema ``object`` describing a tool's parameters.

    Args:
        name: Tool name.
        description: Tool description.
        properties: JSON-schema property definitions.
        required: Names of required properties; omitted from the schema
            when empty or ``None``.

    Returns:
        The JSON schema dict for the tool's parameters.
    """
    # NOTE(review): *name* and *description* are accepted but never embedded
    # in the returned schema — callers appear to attach them separately;
    # confirm this is intended.
    tool_schema: dict[str, Any] = {"type": "object", "properties": properties}
    if required:
        tool_schema["required"] = required
    return tool_schema

View File

@@ -15,7 +15,7 @@ ROOT_PATH = Path(__file__).resolve().parent.parent
if str(ROOT_PATH) not in sys_path: if str(ROOT_PATH) not in sys_path:
sys_path.insert(0, str(ROOT_PATH)) sys_path.insert(0, str(ROOT_PATH))
from src.common.database.database_model import Expression, Jargon, ModifiedBy from src.common.database.database_model import Expression, Jargon, ModifiedBy # noqa: E402
def build_argument_parser() -> ArgumentParser: def build_argument_parser() -> ArgumentParser:

View File

@@ -239,8 +239,7 @@ def load_utils_via_file(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_utils(monkeypatch): async def test_message_utils(monkeypatch):
load_message_via_file(monkeypatch) load_message_via_file(monkeypatch)
utils_module = load_utils_via_file(monkeypatch) load_utils_via_file(monkeypatch)
MessageUtils = utils_module.MessageUtils
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -22,9 +22,9 @@ if str(_root) not in sys.path:
if str(_maisaka_path) not in sys.path: if str(_maisaka_path) not in sys.path:
sys.path.insert(0, str(_maisaka_path)) sys.path.insert(0, str(_maisaka_path))
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager # noqa: E402
from src.maisaka.cli import BufferCLI from src.maisaka.cli import BufferCLI # noqa: E402
from src.maisaka.config import console from src.maisaka.config import console # noqa: E402
def main(): def main():

View File

@@ -36,8 +36,8 @@ def get_chat_name(chat_id: str) -> str:
elif chat_stream.user_nickname: elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊" return f"{chat_stream.user_nickname}的私聊"
if get_chat_manager: if _script_chat_manager:
chat_manager = get_chat_manager() chat_manager = _script_chat_manager
stream_name = chat_manager.get_stream_name(chat_id) stream_name = chat_manager.get_stream_name(chat_id)
if stream_name: if stream_name:
return stream_name return stream_name

View File

@@ -5,6 +5,7 @@ import sys
import time import time
import json import json
import importlib import importlib
from dataclasses import dataclass
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime from datetime import datetime
@@ -23,7 +24,17 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import LLMUsage from src.common.database.database_model import LLMUsage
from maim_message import UserInfo, GroupInfo from src.common.data_models.mai_message_data_model import UserInfo, GroupInfo
try:
from maim_message import ChatStream, UserInfo, GroupInfo
except Exception:
@dataclass
class ChatStream:
stream_id: str
platform: str
user_info: UserInfo
group_info: GroupInfo
logger = get_logger("test_memory_retrieval") logger = get_logger("test_memory_retrieval")

View File

@@ -205,7 +205,7 @@ class HeartFChatting:
# TODO: Planner逻辑 # TODO: Planner逻辑
# TODO: 动作执行逻辑 # TODO: 动作执行逻辑
cycle_detail = self._end_cycle(current_cycle_detail) self._end_cycle(current_cycle_detail)
await asyncio.sleep(0.1) # 最小等待时间,避免过快循环 await asyncio.sleep(0.1) # 最小等待时间,避免过快循环
return True return True

View File

@@ -12,7 +12,6 @@ from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiv
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils from src.common.utils.utils_session import SessionUtils
from src.config.config import global_config
from src.platform_io.route_key_factory import RouteKeyFactory from src.platform_io.route_key_factory import RouteKeyFactory
from src.core.announcement_manager import global_announcement_manager from src.core.announcement_manager import global_announcement_manager
from src.plugin_runtime.component_query import component_query_service from src.plugin_runtime.component_query import component_query_service

View File

@@ -1135,9 +1135,6 @@ class DefaultReplyer:
return content, reasoning_content, model_name, tool_calls return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str): async def get_prompt_info(self, message: str, sender: str, target: str):
del message
del sender
del target
return "" return ""
related_info = "" related_info = ""
start_time = time.time() start_time = time.time()

View File

@@ -1058,12 +1058,9 @@ class StatisticOutputTask(AsyncTask):
from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager
if chat_id in _stat_chat_manager.sessions: if chat_id in _stat_chat_manager.sessions:
session = _stat_chat_manager.sessions[chat_id]
name = _stat_chat_manager.get_session_name(chat_id) name = _stat_chat_manager.get_session_name(chat_id)
if name and name.strip(): if name and name.strip():
return name.strip() return name.strip()
if user_name and user_name.strip():
return user_name.strip()
# 如果从chat_stream获取失败尝试解析chat_id格式 # 如果从chat_stream获取失败尝试解析chat_id格式
if chat_id.startswith("g"): if chat_id.startswith("g"):

View File

@@ -92,7 +92,6 @@ class UniversalMessageSender:
""" """
# TODO: 重构至新的发送模型 # TODO: 重构至新的发送模型
message_preview = (message.processed_plain_text or "")[:200] message_preview = (message.processed_plain_text or "")[:200]
platform = message.platform
try: try:
# 尝试通过主 API 发送 # 尝试通过主 API 发送