Files
mai-bot/agentlite/tests/tools/test_document_kg_tools.py
tcmofashi 7b9e1cf925 ruff
2026-04-03 23:18:30 +08:00

522 lines
17 KiB
Python

"""工具专项测试 - 文档提取和知识图谱工具
本模块测试基于数据基底的工具功能,包括:
1. 文档读取和解析工具
2. 实体提取工具
3. 知识图谱查询工具
4. 推理工具
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import pytest
import yaml
from agentlite import Agent, tool
def tool_output(result: Any) -> Any:
    """Unwrap the value produced by a tool call.

    Works with both legacy tools that return a plain value and tools that
    return a result object exposing an ``output`` attribute (such as a
    ``ToolResult``).

    Args:
        result: Whatever the tool call returned.

    Returns:
        ``result.output`` when that attribute exists, otherwise ``result``
        itself.
    """
    try:
        return result.output
    except AttributeError:
        # Legacy style: the tool returned the bare value.
        return result
# =============================================================================
# 数据加载 fixtures
# =============================================================================
@pytest.fixture
def data_dir() -> Path:
    """Return the ``data`` directory located in the parent of this file's directory."""
    this_dir = Path(__file__).parent
    return this_dir.parent / "data"
@pytest.fixture
def sample_article(data_dir: Path) -> str:
    """Return the UTF-8 decoded text of ``documents/sample_article.md``."""
    article_path = data_dir.joinpath("documents", "sample_article.md")
    return article_path.read_text(encoding="utf-8")
@pytest.fixture
def technical_spec(data_dir: Path) -> str:
    """Return the UTF-8 decoded text of ``documents/technical_spec.md``."""
    spec_path = data_dir.joinpath("documents", "technical_spec.md")
    return spec_path.read_text(encoding="utf-8")
@pytest.fixture
def meeting_notes(data_dir: Path) -> str:
    """Return the UTF-8 decoded text of ``documents/meeting_notes.txt``."""
    notes_path = data_dir.joinpath("documents", "meeting_notes.txt")
    return notes_path.read_text(encoding="utf-8")
@pytest.fixture
def knowledge_graph_entities(data_dir: Path) -> dict[str, Any]:
    """Load and return the parsed contents of ``knowledge_base/entities.json``.

    Args:
        data_dir: Root directory of the test data.

    Returns:
        The decoded JSON document.
    """
    entities_path = data_dir / "knowledge_base" / "entities.json"
    # Explicit UTF-8: without it open() uses the platform locale encoding,
    # which can mis-decode the non-ASCII entity names on systems whose
    # default encoding is not UTF-8.
    with open(entities_path, encoding="utf-8") as f:
        return json.load(f)
@pytest.fixture
def knowledge_graph_relations(data_dir: Path) -> dict[str, Any]:
    """Load and return the parsed contents of ``knowledge_base/relations.json``.

    Args:
        data_dir: Root directory of the test data.

    Returns:
        The decoded JSON document.
    """
    relations_path = data_dir / "knowledge_base" / "relations.json"
    # Explicit UTF-8: without it open() uses the platform locale encoding,
    # which is not guaranteed to be UTF-8 on every system.
    with open(relations_path, encoding="utf-8") as f:
        return json.load(f)
@pytest.fixture
def graph_queries(data_dir: Path) -> list[dict[str, Any]]:
    """Load the graph query test cases from ``knowledge_base/graph_queries.yaml``.

    Args:
        data_dir: Root directory of the test data.

    Returns:
        The list stored under the top-level ``queries`` key, or an empty list
        when the document is empty or the key is missing or null.
    """
    queries_path = data_dir / "knowledge_base" / "graph_queries.yaml"
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(queries_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document; fall back to an empty
        # mapping so the lookup below cannot fail with AttributeError.
        data = yaml.safe_load(f) or {}
    # A bare "queries:" key parses to None; normalise it to an empty list so
    # the fixture always yields a list as annotated.
    return data.get("queries") or []
# =============================================================================
# 知识图谱工具实现
# =============================================================================
class KnowledgeGraph:
    """In-memory knowledge graph built from entity and relation records.

    Entities are dicts that must contain an ``"id"`` key and may contain
    ``"type"`` and ``"name"``. Relations are dicts read through their
    ``"from"``, ``"to"`` and ``"type"`` keys and are treated as directed
    edges from ``"from"`` to ``"to"``.

    ``None`` is the only value that means "no filter" or "missing id";
    falsy ids such as ``0`` or ``""`` are handled like any other id.
    """

    def __init__(self, entities: list[dict], relations: list[dict]):
        """Store the records and build the lookup indexes.

        Args:
            entities: Entity records; each must contain ``"id"``.
            relations: Relation records. The list is kept by reference.
        """
        self._entities = {e["id"]: e for e in entities}
        self._relations = relations
        self._index_by_type: dict[str, list[str]] = {}
        self._index_by_name: dict[str, str] = {}

        # Build the type -> ids and name -> id indexes.
        for entity_id, entity in self._entities.items():
            entity_type = entity.get("type", "Unknown")
            self._index_by_type.setdefault(entity_type, []).append(entity_id)
            name = entity.get("name", "")
            if name:
                # When several entities share a name, the last one wins.
                self._index_by_name[name] = entity_id

    def get_entity(self, entity_id: str) -> dict | None:
        """Return the entity with the given id, or ``None`` if unknown."""
        return self._entities.get(entity_id)

    def get_entity_by_name(self, name: str) -> dict | None:
        """Return the entity registered under ``name``, or ``None`` if unknown."""
        # Membership test instead of truthiness so that an entity whose id is
        # falsy (e.g. 0) is still found.
        if name not in self._index_by_name:
            return None
        return self._entities.get(self._index_by_name[name])

    def get_entities_by_type(self, entity_type: str) -> list[dict]:
        """Return all entities whose ``"type"`` equals ``entity_type``.

        Entities without a ``"type"`` key are indexed under ``"Unknown"``.
        """
        entity_ids = self._index_by_type.get(entity_type, [])
        return [self._entities[eid] for eid in entity_ids if eid in self._entities]

    def get_relations(
        self, from_id: str | None = None, to_id: str | None = None, relation_type: str | None = None
    ) -> list[dict]:
        """Return the relations matching every filter that is not ``None``.

        Args:
            from_id: Required value of the relation's ``"from"`` key.
            to_id: Required value of the relation's ``"to"`` key.
            relation_type: Required value of the relation's ``"type"`` key.

        Returns:
            Matching relation records in their original order; all relations
            when no filter is given.
        """
        results = []
        for rel in self._relations:
            if from_id is not None and rel.get("from") != from_id:
                continue
            if to_id is not None and rel.get("to") != to_id:
                continue
            if relation_type is not None and rel.get("type") != relation_type:
                continue
            results.append(rel)
        return results

    def get_neighbors(self, entity_id: str, relation_type: str | None = None) -> list[dict]:
        """Return the entities reachable over one outgoing relation.

        Args:
            entity_id: Id of the source entity.
            relation_type: Optional relation type to restrict the search to.

        Returns:
            A list of ``{"entity": <target entity>, "relation": <relation>}``
            dicts. Relations whose target is missing or unknown are skipped.
        """
        neighbors = []
        for rel in self.get_relations(from_id=entity_id, relation_type=relation_type):
            target_id = rel.get("to")
            if target_id is not None and target_id in self._entities:
                neighbors.append({"entity": self._entities[target_id], "relation": rel})
        return neighbors

    def find_path(self, start_id: str, end_id: str, max_depth: int = 3) -> list[list[str]] | None:
        """Find paths from ``start_id`` to ``end_id`` by breadth-first search.

        Args:
            start_id: Id of the entity to start from.
            end_id: Id of the entity to reach.
            max_depth: Maximum number of relations (edges) a path may use.

        Returns:
            A list of paths, each a list of entity ids from start to end, in
            the order the search discovers them, or ``None`` when no path is
            found. ``[[start_id]]`` is returned when start and end are equal.
            Intermediate nodes are expanded only once, so this is not an
            exhaustive enumeration of all simple paths; parallel relations
            between the same two nodes produce repeated identical paths.
        """
        from collections import deque

        if start_id == end_id:
            return [[start_id]]
        if max_depth <= 0:
            return None

        queue = deque([[start_id]])
        visited = {start_id}
        all_paths = []
        while queue:
            path = queue.popleft()
            for rel in self.get_relations(from_id=path[-1]):
                next_id = rel.get("to")
                if next_id is None:
                    continue
                new_path = path + [next_id]
                if next_id == end_id:
                    all_paths.append(new_path)
                elif next_id not in visited and len(new_path) <= max_depth:
                    # Only paths of at most max_depth nodes are extended, so a
                    # completed path never uses more than max_depth edges.
                    visited.add(next_id)
                    queue.append(new_path)
        return all_paths if all_paths else None
@pytest.fixture
def knowledge_graph(knowledge_graph_entities, knowledge_graph_relations) -> KnowledgeGraph:
    """Build a ``KnowledgeGraph`` from the loaded entity and relation data.

    Missing ``"entities"`` or ``"relations"`` keys are treated as empty lists.
    """
    entity_records = knowledge_graph_entities.get("entities", [])
    relation_records = knowledge_graph_relations.get("relations", [])
    return KnowledgeGraph(entities=entity_records, relations=relation_records)
# =============================================================================
# 工具定义
# =============================================================================
@tool()
async def read_document(file_path: str) -> str:
    """Read a document and return its text.

    Args:
        file_path: Path of the document to read.

    Returns:
        The file content decoded as UTF-8, or a message starting with
        ``"Error"`` when the path does not exist or reading fails.
    """
    doc_path = Path(file_path)
    if doc_path.exists():
        try:
            content = doc_path.read_text(encoding="utf-8")
        except Exception as exc:
            # Failures are reported as text instead of being raised.
            return f"Error reading file: {exc}"
        return content
    return f"Error: File not found: {file_path}"
@tool()
async def extract_entities(text: str) -> str:
    """Extract person, organization and technology entities from text.

    This is a deliberately simple rule-based extractor; a real implementation
    would use an NLP model.

    Args:
        text: Input text.

    Returns:
        A JSON array of ``{"type": ..., "name": ...}`` objects. Persons come
        first, then organizations, then technologies; within each group every
        name appears once, in order of first appearance.
    """
    import re

    known_names = ("张三", "李四", "王五", "赵六", "李飞飞", "吴恩达", "Yann LeCun")

    # Dict keys serve as an insertion-ordered set, so the output is
    # duplicate-free and identical from run to run (iterating a set of
    # strings has no stable order).
    persons: dict[str, None] = {}

    # Heuristic pass: runs of 2-4 CJK characters that are a known name or are
    # exactly three characters long are treated as person names.
    for chunk in re.findall(r"[\u4e00-\u9fa5]{2,4}", text):
        if chunk in known_names or len(chunk) == 3:
            persons.setdefault(chunk, None)

    # Direct pass: the greedy CJK pattern can split a known name across two
    # chunks and can never match a Latin-script name such as "Yann LeCun", so
    # known names are also searched for verbatim.
    for name in known_names:
        if name in text:
            persons.setdefault(name, None)

    # Organization names.
    orgs = dict.fromkeys(re.findall(r"(TechCorp|OpenAI|GitHub|Google)", text))

    # Technology terms. "AI" must not directly follow a Latin letter,
    # otherwise the tail of "OpenAI" would be reported as a technology.
    techs = dict.fromkeys(
        re.findall(r"(Python|TensorFlow|PyTorch|GPT-4|(?<![A-Za-z])AI|LLM)", text)
    )

    entities = [{"type": "Person", "name": name} for name in persons]
    entities.extend({"type": "Organization", "name": org} for org in orgs)
    entities.extend({"type": "Technology", "name": tech} for tech in techs)
    return json.dumps(entities, ensure_ascii=False)
@tool()
async def query_knowledge_graph(query_type: str, params: str) -> str:
    """Query the knowledge graph (placeholder implementation).

    Args:
        query_type: Kind of query (person_relations, company_employees,
            technology_users, etc.).
        params: Query parameters encoded as a JSON string.

    Returns:
        A JSON string echoing the query type and the parsed parameters with
        an empty ``results`` list, or ``{"error": "Invalid JSON params"}``
        when ``params`` is not valid JSON.
    """
    # NOTE: no graph instance is consulted here; a real implementation would
    # have one injected when the Agent is set up.
    try:
        parsed_params = json.loads(params)
    except json.JSONDecodeError:
        error_payload = {"error": "Invalid JSON params"}
        return json.dumps(error_payload)

    # Simplified: the results are always empty.
    payload = {"query_type": query_type, "params": parsed_params, "results": []}
    return json.dumps(payload, ensure_ascii=False)
@tool()
async def reason_about_path(start_entity: str, end_entity: str) -> str:
    """Reason about the relation path between two entities (placeholder).

    No graph is consulted: the returned ``path`` is always empty and
    ``reasoning`` is a fixed message naming both entities.

    Args:
        start_entity: Name of the starting entity.
        end_entity: Name of the target entity.

    Returns:
        A JSON object with the keys ``start``, ``end``, ``reasoning`` and
        ``path``.
    """
    return json.dumps(
        {
            "start": start_entity,
            "end": end_entity,
            # Separate the two names; written back to back they fuse into a
            # single unreadable token in the message.
            "reasoning": f"分析 {start_entity} -> {end_entity} 的关系链...",
            "path": [],
        },
        ensure_ascii=False,
    )
# =============================================================================
# 测试用例
# =============================================================================
@pytest.mark.tools
class TestDocumentTools:
    """Tests for the document reading and entity extraction tools."""

    @pytest.mark.asyncio
    async def test_read_document(self, data_dir: Path, sample_article: str):
        """Reading the sample article returns text containing the expected phrases."""
        article_path = data_dir / "documents" / "sample_article.md"
        content = tool_output(await read_document(str(article_path)))
        for expected in ("人工智能", "GitHub Copilot", "张三"):
            assert expected in content

    @pytest.mark.asyncio
    async def test_read_document_not_found(self):
        """Reading a missing file yields a "not found" error message."""
        message = tool_output(await read_document("/nonexistent/file.md"))
        assert "Error" in message
        assert "not found" in message.lower()

    @pytest.mark.asyncio
    async def test_extract_entities_from_article(self, sample_article: str):
        """Entity extraction on the sample article finds the expected names."""
        raw = tool_output(await extract_entities(sample_article))
        entities = json.loads(raw)
        # At least one entity must have been extracted.
        assert len(entities) > 0
        # The expected person and at least one expected organization appear.
        names = [item["name"] for item in entities]
        assert "张三" in names
        assert any(org in names for org in ("TechCorp", "OpenAI"))
@pytest.mark.tools
class TestKnowledgeGraphTools:
    """Tests for the ``KnowledgeGraph`` store and the graph-related tools."""

    def test_knowledge_graph_initialization(self, knowledge_graph: KnowledgeGraph):
        """Entities from the fixture data can be looked up by name."""
        # A person entity first, then a company entity.
        for name, expected_type in (("张三", "Person"), ("TechCorp", "Company")):
            found = knowledge_graph.get_entity_by_name(name)
            assert found is not None
            assert found["type"] == expected_type

    def test_get_entities_by_type(self, knowledge_graph: KnowledgeGraph):
        """Type lookups return at least the expected number of entities."""
        person_count = len(knowledge_graph.get_entities_by_type("Person"))
        assert person_count >= 3  # Zhang San, Li Si, Li Feifei
        technology_count = len(knowledge_graph.get_entities_by_type("Technology"))
        assert technology_count >= 2  # Python, OpenAI API

    def test_get_relations(self, knowledge_graph: KnowledgeGraph):
        """Outgoing relations of Zhang San include the expected types."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None
        outgoing = knowledge_graph.get_relations(from_id=person["id"])
        assert len(outgoing) >= 2  # works_for, uses
        # Both relation types must be present.
        kinds = [rel["type"] for rel in outgoing]
        for expected_kind in ("works_for", "uses"):
            assert expected_kind in kinds

    def test_get_neighbors(self, knowledge_graph: KnowledgeGraph):
        """Neighbors of Zhang San include TechCorp."""
        person = knowledge_graph.get_entity_by_name("张三")
        assert person is not None
        adjacent = knowledge_graph.get_neighbors(person["id"])
        assert len(adjacent) >= 2
        adjacent_names = [item["entity"]["name"] for item in adjacent]
        assert "TechCorp" in adjacent_names

    def test_find_path(self, knowledge_graph: KnowledgeGraph):
        """A direct path from Zhang San to TechCorp is found."""
        source = knowledge_graph.get_entity_by_name("张三")
        target = knowledge_graph.get_entity_by_name("TechCorp")
        assert source is not None
        assert target is not None
        found_paths = knowledge_graph.find_path(source["id"], target["id"])
        assert found_paths is not None
        assert len(found_paths) > 0
        # The first path is the direct hop: Zhang San -> TechCorp.
        assert len(found_paths[0]) == 2

    @pytest.mark.asyncio
    async def test_query_knowledge_graph(self):
        """The query tool echoes the query type and parameters."""
        encoded_params = json.dumps({"entity_name": "张三"})
        raw = tool_output(await query_knowledge_graph("person_relations", encoded_params))
        payload = json.loads(raw)
        assert payload["query_type"] == "person_relations"
        assert "params" in payload

    @pytest.mark.asyncio
    async def test_reason_about_path(self):
        """The reasoning tool echoes both entities and includes a reasoning field."""
        raw = tool_output(await reason_about_path("张三", "OpenAI"))
        payload = json.loads(raw)
        assert payload["start"] == "张三"
        assert payload["end"] == "OpenAI"
        assert "reasoning" in payload
@pytest.mark.tools
class TestDataIntegrity:
    """Integrity checks for the test data files."""

    def test_entities_json_valid(self, knowledge_graph_entities: dict):
        """The entity JSON has a non-empty ``entities`` list with the required fields."""
        assert "entities" in knowledge_graph_entities
        records = knowledge_graph_entities["entities"]
        assert len(records) > 0
        # Every entity must carry all required fields.
        for record in records:
            for field in ("id", "type", "name"):
                assert field in record

    def test_relations_json_valid(
        self, knowledge_graph_relations: dict, knowledge_graph_entities: dict
    ):
        """The relation JSON is well-formed and references only existing entities."""
        assert "relations" in knowledge_graph_relations
        known_ids = {record["id"] for record in knowledge_graph_entities["entities"]}
        for rel in knowledge_graph_relations["relations"]:
            for field in ("from", "to", "type"):
                assert field in rel
            # Both endpoints must refer to a known entity.
            assert rel["from"] in known_ids, f"Entity {rel['from']} not found"
            assert rel["to"] in known_ids, f"Entity {rel['to']} not found"

    def test_graph_queries_yaml_valid(self, graph_queries: list):
        """The query YAML has at least one query and each has the required keys."""
        assert len(graph_queries) > 0
        for case in graph_queries:
            for field in ("id", "description", "query", "expected_results"):
                assert field in case

    def test_documents_exist(self, data_dir: Path):
        """Every test document exists and is not empty."""
        documents_root = data_dir / "documents"
        for file_name in ("sample_article.md", "technical_spec.md", "meeting_notes.txt"):
            document = documents_root / file_name
            assert document.exists()
            assert document.stat().st_size > 0
@pytest.mark.tools
class TestAgentWithTools:
    """Integration tests running an Agent equipped with the tools."""

    @pytest.mark.asyncio
    async def test_agent_with_document_tools(self, mock_provider, data_dir: Path):
        """An Agent holding the document tool returns the queued response."""
        mock_provider.add_text_response("我已经读取了文档")
        doc_agent = Agent(
            provider=mock_provider,
            tools=[read_document],
            system_prompt="你是一个文档助手,可以读取和分析文档。",
        )
        article_path = data_dir / "documents" / "sample_article.md"
        reply = await doc_agent.run(f"请读取文档 {article_path}")
        assert "文档" in reply or "读取" in reply

    @pytest.mark.asyncio
    async def test_agent_with_kg_tools(self, mock_provider):
        """An Agent holding the knowledge-graph tools returns a non-empty reply."""
        mock_provider.add_text_response("张三在 TechCorp 工作")
        kg_agent = Agent(
            provider=mock_provider,
            tools=[query_knowledge_graph, reason_about_path],
            system_prompt="你是一个知识图谱助手,可以查询实体关系。",
        )
        reply = await kg_agent.run("张三在哪里工作?")
        assert reply is not None
        assert len(reply) > 0