# File listing metadata: 398 lines, 14 KiB, Python
"""知识库图谱可视化 API 路由"""
|
||
|
||
import logging
|
||
from typing import Any, List, Optional, Tuple
|
||
|
||
from fastapi import APIRouter, Depends, Query
|
||
from pydantic import BaseModel
|
||
|
||
from src.config.config import global_config
|
||
from src.webui.dependencies import require_auth
|
||
|
||
# Logger named after this module.
logger = logging.getLogger(__name__)

# Router for the knowledge-graph endpoints. ``require_auth`` is attached as a
# router-level dependency, so it applies to every route registered below.
router = APIRouter(
    prefix="/api/webui/knowledge",
    tags=["knowledge"],
    dependencies=[Depends(require_auth)],
)

# Lazily initialised paragraph embedding store (read-only; used only to fetch
# the full text of paragraphs). Remains None until a load succeeds.
_paragraph_store_cache: Any = None
|
||
|
||
|
||
def _get_embedding_dir() -> str:
    """Return the path of the embedding data directory.

    The result is ``data/embedding`` joined onto the directory two levels
    above the directory that contains this file.
    """
    import os

    here = os.path.dirname(os.path.abspath(__file__))
    # Walk two levels up from this file's directory to reach the root used for data.
    project_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir))
    return os.path.join(project_root, "data/embedding")
|
||
|
||
|
||
def _get_paragraph_store():
    """Lazily load the paragraph embedding store (read-only, lightweight).

    Returns:
        EmbeddingStore | None: the cached (or freshly loaded) store, or None
        when ``global_config.webui.enable_paragraph_content`` is falsy or the
        load raised an exception.
    """
    global _paragraph_store_cache

    # The feature switch is evaluated on every call, before the cache is consulted.
    if not global_config.webui.enable_paragraph_content:
        return None

    if _paragraph_store_cache is None:
        try:
            # Imported here rather than at module level (deferred import).
            from src.chat.knowledge.embedding_store import EmbeddingStore

            # Only the paragraph namespace is loaded; one worker is enough for read-only use.
            store = EmbeddingStore(
                namespace="paragraph",
                dir_path=_get_embedding_dir(),
                max_workers=1,
                chunk_size=100,
            )
            store.load_from_file()

            _paragraph_store_cache = store
            logger.info(f"成功加载段落 embedding store,包含 {len(store.store)} 个段落")
        except Exception as exc:
            # A failed load is not remembered, so the next call tries again.
            logger.warning(f"加载段落 embedding store 失败: {exc}")
            return None

    return _paragraph_store_cache
|
||
|
||
|
||
def _get_paragraph_content(node_id: str) -> Tuple[Optional[str], bool]:
    """Look up the full text of a paragraph in the embedding store.

    Args:
        node_id: paragraph node ID, used as the key into the store's ``store``
            mapping (the original documentation gives the format as
            'paragraph-{hash}').

    Returns:
        tuple[str | None, bool]: ``(text, available)``. ``available`` is False
        only when no store could be obtained (feature disabled or load
        failed). ``text`` is None when the entry is missing, has an empty
        ``str`` attribute, or the lookup raised.
    """
    try:
        store = _get_paragraph_store()
        if store is None:
            # No store available: report the feature as not enabled.
            return None, False

        item = store.store.get(node_id)
        # The stored item carries the complete paragraph text in its ``str`` attribute.
        text = getattr(item, "str", "") if item is not None else ""
        return (text, True) if text else (None, True)
    except Exception as exc:
        logger.debug(f"获取段落内容失败: {exc}")
        return None, True
|
||
|
||
|
||
class KnowledgeNode(BaseModel):
    """A single node of the knowledge graph as returned by the API."""

    # Identifier of the node in the graph.
    id: str
    # Node category: 'entity' or 'paragraph'.
    type: str  # 'entity' or 'paragraph'
    # Text shown for the node; the endpoints fall back to the node id when the
    # graph node has no "content" attribute.
    content: str
    # Value of the graph node's "create_time" attribute, or None when absent.
    create_time: Optional[float] = None
|
||
|
||
|
||
class KnowledgeEdge(BaseModel):
    """A single edge of the knowledge graph as returned by the API."""

    # Node id at the start of the edge.
    source: str
    # Node id at the end of the edge.
    target: str
    # Edge weight; the endpoints use 1.0 when the edge has no "weight" attribute.
    weight: float
    # Value of the edge's "create_time" attribute, or None when absent.
    create_time: Optional[float] = None
    # Value of the edge's "update_time" attribute, or None when absent.
    update_time: Optional[float] = None
|
||
|
||
|
||
class KnowledgeGraph(BaseModel):
    """Knowledge graph payload: a list of nodes plus the edges between them."""

    # Nodes included in this response.
    nodes: List[KnowledgeNode]
    # Edges included in this response.
    edges: List[KnowledgeEdge]
|
||
|
||
|
||
class KnowledgeStats(BaseModel):
    """Summary statistics of the knowledge base graph."""

    # Number of nodes in the graph.
    total_nodes: int
    # Number of edges in the graph.
    total_edges: int
    # Number of nodes counted as entities.
    entity_nodes: int
    # Number of nodes counted as paragraphs.
    paragraph_nodes: int
    # Average number of edge endpoints per node.
    avg_connections: float
|
||
|
||
|
||
def _load_kg_manager():
    """Create a KGManager and load its data from file.

    The import is deferred to call time. Any exception raised while importing,
    constructing or loading is logged as an error and turned into None.

    Returns:
        KGManager | None: the loaded manager, or None on failure.
    """
    try:
        from src.chat.knowledge.kg_manager import KGManager

        manager = KGManager()
        manager.load_from_file()
    except Exception as exc:
        logger.error(f"加载 KGManager 失败: {exc}")
        return None
    return manager
|
||
|
||
|
||
def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
    """Convert the manager's DiGraph into a KnowledgeGraph response model.

    Args:
        kg_manager: object exposing a ``graph`` attribute; may be None.

    Returns:
        KnowledgeGraph: every node and edge that could be converted. Items
        whose conversion raises are logged and skipped. An empty graph is
        returned when the manager or its graph is None.
    """
    if kg_manager is None or kg_manager.graph is None:
        return KnowledgeGraph(nodes=[], edges=[])

    graph = kg_manager.graph
    node_models = []
    edge_models = []

    # --- nodes ---
    for node_id in graph.get_node_list():
        try:
            attrs = graph[node_id]
            # Stored type "ent" maps to "entity"; everything else is treated as "paragraph".
            if "type" in attrs and attrs["type"] == "ent":
                kind = "entity"
                text = None
            else:
                kind = "paragraph"
                # Paragraphs prefer the full text from the embedding store.
                text, _ = _get_paragraph_content(node_id)

            if text is None:
                text = attrs["content"] if "content" in attrs else node_id

            created = attrs["create_time"] if "create_time" in attrs else None
            node_models.append(KnowledgeNode(id=node_id, type=kind, content=text, create_time=created))
        except Exception as exc:
            logger.warning(f"跳过节点 {node_id}: {exc}")

    # --- edges ---
    for edge in graph.get_edge_list():
        try:
            # The first two items of each edge entry are the source and target ids.
            src, dst = edge[0], edge[1]
            # Edge attributes are fetched by indexing the graph with the (source, target) pair.
            props = graph[src, dst]

            # Edge data supports ``in`` and ``[]`` but not ``.get()``, hence the explicit checks.
            fields = {"weight": 1.0, "create_time": None, "update_time": None}
            for key in fields:
                if key in props:
                    fields[key] = props[key]

            edge_models.append(KnowledgeEdge(source=src, target=dst, **fields))
        except Exception as exc:
            logger.warning(f"跳过边 {edge}: {exc}")

    return KnowledgeGraph(nodes=node_models, edges=edge_models)
|
||
|
||
|
||
def _graph_node_has_type(graph, node_id, raw_type) -> bool:
    """Return True when the node is in the graph and its stored "type" equals ``raw_type``."""
    if node_id not in graph:
        return False
    node_data = graph[node_id]  # fetched once per node
    return "type" in node_data and node_data["type"] == raw_type


def _build_knowledge_node(graph, node_id):
    """Build the KnowledgeNode response model for one graph node."""
    node_data = graph[node_id]
    # Stored type "ent" maps to "entity"; everything else is reported as "paragraph".
    is_entity = "type" in node_data and node_data["type"] == "ent"

    content = None
    if not is_entity:
        # Paragraph nodes prefer the full text from the embedding store.
        content, _ = _get_paragraph_content(node_id)
    if content is None:
        content = node_data["content"] if "content" in node_data else node_id

    create_time = node_data["create_time"] if "create_time" in node_data else None
    return KnowledgeNode(
        id=node_id,
        type="entity" if is_entity else "paragraph",
        content=content,
        create_time=create_time,
    )


def _build_knowledge_edge(graph, source, target):
    """Build the KnowledgeEdge response model for the edge between two nodes."""
    edge_data = graph[source, target]
    # Edge data supports ``in`` and ``[]`` but not ``.get()``, hence the explicit checks.
    weight = edge_data["weight"] if "weight" in edge_data else 1.0
    create_time = edge_data["create_time"] if "create_time" in edge_data else None
    update_time = edge_data["update_time"] if "update_time" in edge_data else None
    return KnowledgeEdge(
        source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
    )


@router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph(
    limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
    node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
):
    """Return the knowledge graph, capped at ``limit`` nodes.

    Args:
        limit: maximum number of nodes to return (default 100, at most 10000).
        node_type: node filter - ``entity`` keeps nodes whose stored type is
            "ent", ``paragraph`` keeps nodes whose stored type is "pg"; any
            other value (including the default ``all``) applies no filter.

    Returns:
        KnowledgeGraph: the first ``limit`` matching nodes and the edges whose
        two endpoints are both among the returned nodes. An empty graph is
        returned when the manager or its graph is unavailable, or when an
        unexpected error occurs (the error is logged, not raised).
    """
    try:
        # NOTE(review): the manager is rebuilt on every request and, judging by
        # ``load_from_file``, presumably reads from disk inside this async
        # handler - confirm whether caching or a worker thread is wanted.
        kg_manager = _load_kg_manager()
        # Also treat a missing graph as "not initialised", like /stats and /search do.
        if kg_manager is None or kg_manager.graph is None:
            logger.warning("KGManager 未初始化,返回空图谱")
            return KnowledgeGraph(nodes=[], edges=[])

        graph = kg_manager.graph
        all_node_list = graph.get_node_list()

        # Filter by node type; unknown filter values fall through to "no filter".
        raw_type = {"entity": "ent", "paragraph": "pg"}.get(node_type)
        if raw_type is not None:
            all_node_list = [n for n in all_node_list if _graph_node_has_type(graph, n, raw_type)]

        # Cap the number of nodes (slicing is a no-op when the list is shorter).
        total_nodes = len(all_node_list)
        node_list = all_node_list[:limit]

        logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")

        # Convert nodes; a node that fails to convert is logged and skipped.
        nodes = []
        node_ids = set()
        for node_id in node_list:
            try:
                nodes.append(_build_knowledge_node(graph, node_id))
                node_ids.add(node_id)
            except Exception as e:
                logger.warning(f"跳过节点 {node_id}: {e}")

        # Keep only edges whose two endpoints were both returned, so the
        # response never references a node that is not in ``nodes``.
        edges = []
        for edge_tuple in graph.get_edge_list():
            try:
                source, target = edge_tuple[0], edge_tuple[1]
                if source not in node_ids or target not in node_ids:
                    continue
                edges.append(_build_knowledge_edge(graph, source, target))
            except Exception as e:
                logger.warning(f"跳过边 {edge_tuple}: {e}")

        graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
        logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
        return graph_data

    except Exception as e:
        # Best-effort endpoint: log with traceback and answer with an empty graph.
        logger.error(f"获取知识图谱失败: {e}", exc_info=True)
        return KnowledgeGraph(nodes=[], edges=[])
|
||
|
||
|
||
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats():
    """Return summary statistics of the knowledge base graph.

    Returns:
        KnowledgeStats: node/edge totals, per-type node counts and the average
        number of connections per node (``2 * edges / nodes``, rounded to two
        decimals). All-zero statistics are returned when the manager or its
        graph is unavailable, or when an unexpected error occurs (the error is
        logged, not raised).
    """
    try:
        kg_manager = _load_kg_manager()
        if kg_manager is None or kg_manager.graph is None:
            return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)

        graph = kg_manager.graph
        node_list = graph.get_node_list()
        edge_list = graph.get_edge_list()

        total_nodes = len(node_list)
        total_edges = len(edge_list)

        # Count nodes per stored type. A node without a "type" attribute is
        # counted as an entity; types other than "ent"/"pg" are counted in
        # neither bucket.
        # NOTE(review): /graph and /search label type-less nodes "paragraph" -
        # confirm which rule is intended.
        entity_nodes = 0
        paragraph_nodes = 0
        for node_id in node_list:
            try:
                node_data = graph[node_id]
                node_type = node_data["type"] if "type" in node_data else "ent"
                if node_type == "ent":
                    entity_nodes += 1
                elif node_type == "pg":
                    paragraph_nodes += 1
            except Exception as e:
                # Skip the node but leave a trace instead of failing silently.
                logger.warning(f"跳过节点 {node_id}: {e}")
                continue

        # Each edge contributes two endpoints; guard against an empty graph.
        avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0

        return KnowledgeStats(
            total_nodes=total_nodes,
            total_edges=total_edges,
            entity_nodes=entity_nodes,
            paragraph_nodes=paragraph_nodes,
            avg_connections=round(avg_connections, 2),
        )

    except Exception as e:
        # Best-effort endpoint: log with traceback and answer with zeros.
        logger.error(f"获取统计信息失败: {e}", exc_info=True)
        return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||
|
||
|
||
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)):
    """Search knowledge nodes by keyword.

    The match is a case-insensitive substring test against the node's content
    and against its id.

    Args:
        query: search keyword (at least one character).

    Returns:
        List[KnowledgeNode]: at most 50 matching nodes. An empty list is
        returned when the manager or its graph is unavailable, or when an
        unexpected error occurs (the error is logged, not raised).
    """
    max_results = 50  # cap on the number of nodes returned

    try:
        kg_manager = _load_kg_manager()
        if kg_manager is None or kg_manager.graph is None:
            return []

        graph = kg_manager.graph
        node_list = graph.get_node_list()
        results = []
        query_lower = query.lower()

        # Scan every node; a node that cannot be processed is logged and skipped.
        for node_id in node_list:
            try:
                node_data = graph[node_id]
                # Stored type "ent" maps to "entity"; everything else is reported as "paragraph".
                node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"

                # Paragraph nodes prefer the full text from the embedding store.
                content = None
                if node_type == "paragraph":
                    content, _ = _get_paragraph_content(node_id)
                if content is None:
                    content = node_data["content"] if "content" in node_data else node_id

                if query_lower in content.lower() or query_lower in node_id.lower():
                    create_time = node_data["create_time"] if "create_time" in node_data else None
                    results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
            except Exception as e:
                # Skip the node but leave a trace instead of failing silently.
                logger.warning(f"跳过节点 {node_id}: {e}")
                continue

        # The log reports the total number of matches, before the cap is applied.
        logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
        return results[:max_results]

    except Exception as e:
        # Best-effort endpoint: log with traceback and answer with an empty list.
        logger.error(f"搜索节点失败: {e}", exc_info=True)
        return []
|