feat: 添加嵌入服务层和任务解析工具,重构文本嵌入逻辑

This commit is contained in:
DrSmoothl
2026-04-03 23:35:16 +08:00
parent 8bfe3e7036
commit a2431e677e
8 changed files with 483 additions and 162 deletions

View File

@@ -1,31 +1,32 @@
from dataclasses import dataclass
import json
import os
import math
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple
import json
import math
import os
import faiss
import numpy as np
import pandas as pd
# import tqdm
import faiss
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.config.config import global_config
from src.services.embedding_service import EmbeddingServiceClient
from .global_logger import logger
from .utils.hash import get_sha256
install(extra_lines=3)
@@ -133,19 +134,20 @@ class EmbeddingStore:
return [f"{namespace}-{get_sha256(t)}" for t in texts]
def _get_embedding(self, s: str) -> List[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
"""以同步方式获取单条字符串的嵌入向量。
Args:
s: 待编码的文本内容。
Returns:
List[float]: 嵌入向量;失败时返回空列表。
"""
try:
# 创建新的服务层实例
from src.services.llm_service import LLMServiceClient
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
# 使用新的事件循环运行异步方法
embedding_result = loop.run_until_complete(llm.embed_text(s))
embedding_client = EmbeddingServiceClient(
task_name="embedding",
request_type="embedding",
)
embedding_result = embedding_client.embed_text_sync(s)
embedding = embedding_result.embedding
if embedding and len(embedding) > 0:
@@ -157,17 +159,15 @@ class EmbeddingStore:
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return []
finally:
# 确保事件循环被正确关闭
try:
loop.close()
except Exception:
pass
def _get_embeddings_batch_threaded(
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
self,
strs: List[str],
chunk_size: int = 10,
max_workers: int = 10,
progress_callback: Callable[[int], None] | None = None,
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
"""使用多线程批量获取嵌入向量
Args:
strs: 要获取嵌入的字符串列表
@@ -190,53 +190,42 @@ class EmbeddingStore:
# 结果存储,使用字典按索引存储以保证顺序
results = {}
def process_chunk(chunk_data):
"""处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data
chunk_results = []
def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]:
"""处理单个数据块
# 为每个线程创建独立的服务层实例
from src.services.llm_service import LLMServiceClient
Args:
chunk_data: 数据块起始索引与字符串列表。
Returns:
List[Tuple[int, str, List[float]]]: 带原始索引的处理结果。
"""
start_idx, chunk_strs = chunk_data
chunk_results: List[Tuple[int, str, List[float]]] = []
try:
# 创建线程专用的服务层实例
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
for i, s in enumerate(chunk_strs):
try:
# 在线程中创建独立的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embedding_result = loop.run_until_complete(llm.embed_text(s))
embedding = embedding_result.embedding
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
embedding_client = EmbeddingServiceClient(
task_name="embedding",
request_type="embedding",
)
embedding_results = embedding_client.embed_texts_sync(
chunk_strs,
max_concurrent=1,
)
for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)):
embedding = embedding_result.embedding
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding))
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果
logger.error(f"创建 EmbeddingService 实例失败: {e}")
for i, s in enumerate(chunk_strs):
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)