feat: add a subagent frame
This commit is contained in:
521
agentlite/tests/tools/test_document_kg_tools.py
Normal file
521
agentlite/tests/tools/test_document_kg_tools.py
Normal file
@@ -0,0 +1,521 @@
|
||||
"""工具专项测试 - 文档提取和知识图谱工具
|
||||
|
||||
本模块测试基于数据基底的工具功能,包括:
|
||||
1. 文档读取和解析工具
|
||||
2. 实体提取工具
|
||||
3. 知识图谱查询工具
|
||||
4. 推理工具
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from agentlite import Agent, Message, TextPart, tool
|
||||
|
||||
|
||||
def tool_output(result: Any) -> Any:
|
||||
"""兼容旧式返回值和 ToolResult 返回值."""
|
||||
return getattr(result, "output", result)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 数据加载 fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir() -> Path:
|
||||
"""返回测试数据目录路径."""
|
||||
return Path(__file__).parent.parent / "data"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_article(data_dir: Path) -> str:
|
||||
"""加载样例文章."""
|
||||
return (data_dir / "documents" / "sample_article.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def technical_spec(data_dir: Path) -> str:
|
||||
"""加载技术规范文档."""
|
||||
return (data_dir / "documents" / "technical_spec.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def meeting_notes(data_dir: Path) -> str:
|
||||
"""加载会议记录."""
|
||||
return (data_dir / "documents" / "meeting_notes.txt").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def knowledge_graph_entities(data_dir: Path) -> dict[str, Any]:
|
||||
"""加载知识图谱实体数据."""
|
||||
with open(data_dir / "knowledge_base" / "entities.json") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def knowledge_graph_relations(data_dir: Path) -> dict[str, Any]:
|
||||
"""加载知识图谱关系数据."""
|
||||
with open(data_dir / "knowledge_base" / "relations.json") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_queries(data_dir: Path) -> list[dict[str, Any]]:
|
||||
"""加载图谱查询测试用例."""
|
||||
with open(data_dir / "knowledge_base" / "graph_queries.yaml") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return data.get("queries", [])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 知识图谱工具实现
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KnowledgeGraph:
|
||||
"""知识图谱内存存储."""
|
||||
|
||||
def __init__(self, entities: list[dict], relations: list[dict]):
|
||||
self._entities = {e["id"]: e for e in entities}
|
||||
self._relations = relations
|
||||
self._index_by_type: dict[str, list[str]] = {}
|
||||
self._index_by_name: dict[str, str] = {}
|
||||
|
||||
# 构建索引
|
||||
for entity_id, entity in self._entities.items():
|
||||
entity_type = entity.get("type", "Unknown")
|
||||
if entity_type not in self._index_by_type:
|
||||
self._index_by_type[entity_type] = []
|
||||
self._index_by_type[entity_type].append(entity_id)
|
||||
|
||||
name = entity.get("name", "")
|
||||
if name:
|
||||
self._index_by_name[name] = entity_id
|
||||
|
||||
def get_entity(self, entity_id: str) -> dict | None:
|
||||
"""获取实体."""
|
||||
return self._entities.get(entity_id)
|
||||
|
||||
def get_entity_by_name(self, name: str) -> dict | None:
|
||||
"""通过名称获取实体."""
|
||||
entity_id = self._index_by_name.get(name)
|
||||
if entity_id:
|
||||
return self._entities.get(entity_id)
|
||||
return None
|
||||
|
||||
def get_entities_by_type(self, entity_type: str) -> list[dict]:
|
||||
"""获取特定类型的所有实体."""
|
||||
entity_ids = self._index_by_type.get(entity_type, [])
|
||||
return [self._entities[eid] for eid in entity_ids if eid in self._entities]
|
||||
|
||||
def get_relations(
|
||||
self, from_id: str | None = None, to_id: str | None = None, relation_type: str | None = None
|
||||
) -> list[dict]:
|
||||
"""获取关系."""
|
||||
results = []
|
||||
for rel in self._relations:
|
||||
if from_id and rel.get("from") != from_id:
|
||||
continue
|
||||
if to_id and rel.get("to") != to_id:
|
||||
continue
|
||||
if relation_type and rel.get("type") != relation_type:
|
||||
continue
|
||||
results.append(rel)
|
||||
return results
|
||||
|
||||
def get_neighbors(self, entity_id: str, relation_type: str | None = None) -> list[dict]:
|
||||
"""获取邻居实体."""
|
||||
relations = self.get_relations(from_id=entity_id, relation_type=relation_type)
|
||||
neighbors = []
|
||||
for rel in relations:
|
||||
target_id = rel.get("to")
|
||||
if target_id and target_id in self._entities:
|
||||
neighbors.append({"entity": self._entities[target_id], "relation": rel})
|
||||
return neighbors
|
||||
|
||||
def find_path(self, start_id: str, end_id: str, max_depth: int = 3) -> list[list[str]] | None:
|
||||
"""查找两个实体之间的路径."""
|
||||
if start_id == end_id:
|
||||
return [[start_id]]
|
||||
|
||||
if max_depth <= 0:
|
||||
return None
|
||||
|
||||
# BFS
|
||||
from collections import deque
|
||||
|
||||
queue = deque([(start_id, [start_id])])
|
||||
visited = {start_id}
|
||||
all_paths = []
|
||||
|
||||
while queue:
|
||||
current_id, path = queue.popleft()
|
||||
|
||||
if len(path) > max_depth + 1:
|
||||
continue
|
||||
|
||||
relations = self.get_relations(from_id=current_id)
|
||||
for rel in relations:
|
||||
next_id = rel.get("to")
|
||||
if not next_id:
|
||||
continue
|
||||
|
||||
new_path = path + [next_id]
|
||||
|
||||
if next_id == end_id:
|
||||
all_paths.append(new_path)
|
||||
elif next_id not in visited and len(new_path) <= max_depth:
|
||||
visited.add(next_id)
|
||||
queue.append((next_id, new_path))
|
||||
|
||||
return all_paths if all_paths else None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def knowledge_graph(knowledge_graph_entities, knowledge_graph_relations) -> KnowledgeGraph:
|
||||
"""创建知识图谱实例."""
|
||||
return KnowledgeGraph(
|
||||
entities=knowledge_graph_entities.get("entities", []),
|
||||
relations=knowledge_graph_relations.get("relations", []),
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 工具定义
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@tool()
|
||||
async def read_document(file_path: str) -> str:
|
||||
"""读取文档内容.
|
||||
|
||||
Args:
|
||||
file_path: 文档路径
|
||||
|
||||
Returns:
|
||||
文档内容
|
||||
"""
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
return f"Error: File not found: {file_path}"
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
|
||||
@tool()
|
||||
async def extract_entities(text: str) -> str:
|
||||
"""从文本中提取实体.
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
JSON 格式的实体列表
|
||||
"""
|
||||
# 简化的实体提取 - 实际应使用 NLP 模型
|
||||
import re
|
||||
|
||||
entities = []
|
||||
|
||||
# 提取人名(简单的中文姓名匹配)
|
||||
person_pattern = r"[\u4e00-\u9fa5]{2,4}"
|
||||
potential_names = re.findall(person_pattern, text)
|
||||
common_names = ["张三", "李四", "王五", "赵六", "李飞飞", "吴恩达", "Yann LeCun"]
|
||||
|
||||
for name in potential_names:
|
||||
if name in common_names or len(name) == 3:
|
||||
entities.append({"type": "Person", "name": name})
|
||||
|
||||
# 提取公司/组织名
|
||||
org_pattern = r"(TechCorp|OpenAI|GitHub|Google)"
|
||||
orgs = re.findall(org_pattern, text)
|
||||
for org in set(orgs):
|
||||
entities.append({"type": "Organization", "name": org})
|
||||
|
||||
# 提取技术术语
|
||||
tech_pattern = r"(Python|TensorFlow|PyTorch|GPT-4|AI|LLM)"
|
||||
techs = re.findall(tech_pattern, text)
|
||||
for tech in set(techs):
|
||||
entities.append({"type": "Technology", "name": tech})
|
||||
|
||||
return json.dumps(entities, ensure_ascii=False)
|
||||
|
||||
|
||||
@tool()
|
||||
async def query_knowledge_graph(query_type: str, params: str) -> str:
|
||||
"""查询知识图谱.
|
||||
|
||||
Args:
|
||||
query_type: 查询类型 (person_relations, company_employees, technology_users, etc.)
|
||||
params: JSON 格式的查询参数
|
||||
|
||||
Returns:
|
||||
查询结果
|
||||
"""
|
||||
# 这里使用全局的 kg 实例,实际应在 Agent 初始化时注入
|
||||
try:
|
||||
params_dict = json.loads(params)
|
||||
except json.JSONDecodeError:
|
||||
return json.dumps({"error": "Invalid JSON params"})
|
||||
|
||||
# 简化实现 - 实际应基于知识图谱查询
|
||||
result = {"query_type": query_type, "params": params_dict, "results": []}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
@tool()
|
||||
async def reason_about_path(start_entity: str, end_entity: str) -> str:
|
||||
"""推理两个实体之间的关系路径.
|
||||
|
||||
Args:
|
||||
start_entity: 起始实体名称
|
||||
end_entity: 目标实体名称
|
||||
|
||||
Returns:
|
||||
推理结果
|
||||
"""
|
||||
return json.dumps(
|
||||
{
|
||||
"start": start_entity,
|
||||
"end": end_entity,
|
||||
"reasoning": f"分析 {start_entity} 到 {end_entity} 的关系链...",
|
||||
"path": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.tools
|
||||
class TestDocumentTools:
|
||||
"""文档工具测试."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_document(self, data_dir: Path, sample_article: str):
|
||||
"""测试文档读取工具."""
|
||||
result = tool_output(await read_document(str(data_dir / "documents" / "sample_article.md")))
|
||||
|
||||
assert "人工智能" in result
|
||||
assert "GitHub Copilot" in result
|
||||
assert "张三" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_document_not_found(self):
|
||||
"""测试读取不存在的文档."""
|
||||
result = tool_output(await read_document("/nonexistent/file.md"))
|
||||
|
||||
assert "Error" in result
|
||||
assert "not found" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_entities_from_article(self, sample_article: str):
|
||||
"""测试从文章中提取实体."""
|
||||
result = tool_output(await extract_entities(sample_article))
|
||||
entities = json.loads(result)
|
||||
|
||||
# 验证提取到实体
|
||||
assert len(entities) > 0
|
||||
|
||||
# 验证实体类型
|
||||
entity_names = [e["name"] for e in entities]
|
||||
assert "张三" in entity_names
|
||||
assert "TechCorp" in entity_names or "OpenAI" in entity_names
|
||||
|
||||
|
||||
@pytest.mark.tools
|
||||
class TestKnowledgeGraphTools:
|
||||
"""知识图谱工具测试."""
|
||||
|
||||
def test_knowledge_graph_initialization(self, knowledge_graph: KnowledgeGraph):
|
||||
"""测试知识图谱初始化."""
|
||||
# 验证实体数量
|
||||
entity = knowledge_graph.get_entity_by_name("张三")
|
||||
assert entity is not None
|
||||
assert entity["type"] == "Person"
|
||||
|
||||
# 验证公司实体
|
||||
company = knowledge_graph.get_entity_by_name("TechCorp")
|
||||
assert company is not None
|
||||
assert company["type"] == "Company"
|
||||
|
||||
def test_get_entities_by_type(self, knowledge_graph: KnowledgeGraph):
|
||||
"""测试按类型获取实体."""
|
||||
persons = knowledge_graph.get_entities_by_type("Person")
|
||||
assert len(persons) >= 3 # 张三、李四、李飞飞
|
||||
|
||||
technologies = knowledge_graph.get_entities_by_type("Technology")
|
||||
assert len(technologies) >= 2 # Python、OpenAI API
|
||||
|
||||
def test_get_relations(self, knowledge_graph: KnowledgeGraph):
|
||||
"""测试获取关系."""
|
||||
# 获取张三的所有关系
|
||||
zhangsan = knowledge_graph.get_entity_by_name("张三")
|
||||
assert zhangsan is not None
|
||||
|
||||
relations = knowledge_graph.get_relations(from_id=zhangsan["id"])
|
||||
assert len(relations) >= 2 # works_for, uses
|
||||
|
||||
# 验证关系类型
|
||||
relation_types = [r["type"] for r in relations]
|
||||
assert "works_for" in relation_types
|
||||
assert "uses" in relation_types
|
||||
|
||||
def test_get_neighbors(self, knowledge_graph: KnowledgeGraph):
|
||||
"""测试获取邻居节点."""
|
||||
zhangsan = knowledge_graph.get_entity_by_name("张三")
|
||||
assert zhangsan is not None
|
||||
|
||||
neighbors = knowledge_graph.get_neighbors(zhangsan["id"])
|
||||
assert len(neighbors) >= 2
|
||||
|
||||
# 验证邻居包含 TechCorp
|
||||
neighbor_names = [n["entity"]["name"] for n in neighbors]
|
||||
assert "TechCorp" in neighbor_names
|
||||
|
||||
def test_find_path(self, knowledge_graph: KnowledgeGraph):
|
||||
"""测试查找路径."""
|
||||
zhangsan = knowledge_graph.get_entity_by_name("张三")
|
||||
techcorp = knowledge_graph.get_entity_by_name("TechCorp")
|
||||
|
||||
assert zhangsan is not None
|
||||
assert techcorp is not None
|
||||
|
||||
paths = knowledge_graph.find_path(zhangsan["id"], techcorp["id"])
|
||||
assert paths is not None
|
||||
assert len(paths) > 0
|
||||
|
||||
# 验证路径长度
|
||||
first_path = paths[0]
|
||||
assert len(first_path) == 2 # 张三 -> TechCorp
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_knowledge_graph(self):
|
||||
"""测试知识图谱查询工具."""
|
||||
params = json.dumps({"entity_name": "张三"})
|
||||
result = tool_output(await query_knowledge_graph("person_relations", params))
|
||||
|
||||
data = json.loads(result)
|
||||
assert data["query_type"] == "person_relations"
|
||||
assert "params" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reason_about_path(self):
|
||||
"""测试路径推理工具."""
|
||||
result = tool_output(await reason_about_path("张三", "OpenAI"))
|
||||
|
||||
data = json.loads(result)
|
||||
assert data["start"] == "张三"
|
||||
assert data["end"] == "OpenAI"
|
||||
assert "reasoning" in data
|
||||
|
||||
|
||||
@pytest.mark.tools
|
||||
class TestDataIntegrity:
|
||||
"""数据完整性测试."""
|
||||
|
||||
def test_entities_json_valid(self, knowledge_graph_entities: dict):
|
||||
"""验证实体 JSON 格式正确."""
|
||||
assert "entities" in knowledge_graph_entities
|
||||
entities = knowledge_graph_entities["entities"]
|
||||
assert len(entities) > 0
|
||||
|
||||
# 验证每个实体都有必需的字段
|
||||
for entity in entities:
|
||||
assert "id" in entity
|
||||
assert "type" in entity
|
||||
assert "name" in entity
|
||||
|
||||
def test_relations_json_valid(
|
||||
self, knowledge_graph_relations: dict, knowledge_graph_entities: dict
|
||||
):
|
||||
"""验证关系 JSON 格式正确且引用的实体存在."""
|
||||
assert "relations" in knowledge_graph_relations
|
||||
relations = knowledge_graph_relations["relations"]
|
||||
|
||||
entity_ids = {e["id"] for e in knowledge_graph_entities["entities"]}
|
||||
|
||||
for relation in relations:
|
||||
assert "from" in relation
|
||||
assert "to" in relation
|
||||
assert "type" in relation
|
||||
|
||||
# 验证引用的实体存在
|
||||
assert relation["from"] in entity_ids, f"Entity {relation['from']} not found"
|
||||
assert relation["to"] in entity_ids, f"Entity {relation['to']} not found"
|
||||
|
||||
def test_graph_queries_yaml_valid(self, graph_queries: list):
|
||||
"""验证查询 YAML 格式正确."""
|
||||
assert len(graph_queries) > 0
|
||||
|
||||
for query in graph_queries:
|
||||
assert "id" in query
|
||||
assert "description" in query
|
||||
assert "query" in query
|
||||
assert "expected_results" in query
|
||||
|
||||
def test_documents_exist(self, data_dir: Path):
|
||||
"""验证测试文档存在且非空."""
|
||||
docs_dir = data_dir / "documents"
|
||||
|
||||
sample_article = docs_dir / "sample_article.md"
|
||||
assert sample_article.exists()
|
||||
assert sample_article.stat().st_size > 0
|
||||
|
||||
tech_spec = docs_dir / "technical_spec.md"
|
||||
assert tech_spec.exists()
|
||||
assert tech_spec.stat().st_size > 0
|
||||
|
||||
meeting_notes = docs_dir / "meeting_notes.txt"
|
||||
assert meeting_notes.exists()
|
||||
assert meeting_notes.stat().st_size > 0
|
||||
|
||||
|
||||
@pytest.mark.tools
|
||||
class TestAgentWithTools:
|
||||
"""Agent 集成工具测试."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_with_document_tools(self, mock_provider, data_dir: Path):
|
||||
"""测试带有文档工具的 Agent."""
|
||||
mock_provider.add_text_response("我已经读取了文档")
|
||||
|
||||
agent = Agent(
|
||||
provider=mock_provider,
|
||||
tools=[read_document],
|
||||
system_prompt="你是一个文档助手,可以读取和分析文档。",
|
||||
)
|
||||
|
||||
response = await agent.run(f"请读取文档 {data_dir / 'documents' / 'sample_article.md'}")
|
||||
|
||||
assert "文档" in response or "读取" in response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_with_kg_tools(self, mock_provider):
|
||||
"""测试带有知识图谱工具的 Agent."""
|
||||
mock_provider.add_text_response("张三在 TechCorp 工作")
|
||||
|
||||
agent = Agent(
|
||||
provider=mock_provider,
|
||||
tools=[query_knowledge_graph, reason_about_path],
|
||||
system_prompt="你是一个知识图谱助手,可以查询实体关系。",
|
||||
)
|
||||
|
||||
response = await agent.run("张三在哪里工作?")
|
||||
|
||||
assert response is not None
|
||||
assert len(response) > 0
|
||||
Reference in New Issue
Block a user