feat: 添加嵌入服务层和任务解析工具,重构文本嵌入逻辑
This commit is contained in:
@@ -1,31 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Dict, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
from rich.progress import (
|
||||
Progress,
|
||||
BarColumn,
|
||||
MofNCompleteColumn,
|
||||
Progress,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
TimeElapsedColumn,
|
||||
TimeRemainingColumn,
|
||||
TaskProgressColumn,
|
||||
MofNCompleteColumn,
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.services.embedding_service import EmbeddingServiceClient
|
||||
|
||||
from .global_logger import logger
|
||||
from .utils.hash import get_sha256
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -133,19 +134,20 @@ class EmbeddingStore:
|
||||
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
||||
# 创建新的事件循环并在完成后立即关闭
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
"""以同步方式获取单条字符串的嵌入向量。
|
||||
|
||||
Args:
|
||||
s: 待编码的文本内容。
|
||||
|
||||
Returns:
|
||||
List[float]: 嵌入向量;失败时返回空列表。
|
||||
"""
|
||||
try:
|
||||
# 创建新的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
# 使用新的事件循环运行异步方法
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding_client = EmbeddingServiceClient(
|
||||
task_name="embedding",
|
||||
request_type="embedding",
|
||||
)
|
||||
embedding_result = embedding_client.embed_text_sync(s)
|
||||
embedding = embedding_result.embedding
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
@@ -157,17 +159,15 @@ class EmbeddingStore:
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
return []
|
||||
finally:
|
||||
# 确保事件循环被正确关闭
|
||||
try:
|
||||
loop.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_embeddings_batch_threaded(
|
||||
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
||||
self,
|
||||
strs: List[str],
|
||||
chunk_size: int = 10,
|
||||
max_workers: int = 10,
|
||||
progress_callback: Callable[[int], None] | None = None,
|
||||
) -> List[Tuple[str, List[float]]]:
|
||||
"""使用多线程批量获取嵌入向量
|
||||
"""使用多线程批量获取嵌入向量。
|
||||
|
||||
Args:
|
||||
strs: 要获取嵌入的字符串列表
|
||||
@@ -190,53 +190,42 @@ class EmbeddingStore:
|
||||
# 结果存储,使用字典按索引存储以保证顺序
|
||||
results = {}
|
||||
|
||||
def process_chunk(chunk_data):
|
||||
"""处理单个数据块的函数"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results = []
|
||||
def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]:
|
||||
"""处理单个数据块。
|
||||
|
||||
# 为每个线程创建独立的服务层实例
|
||||
from src.services.llm_service import LLMServiceClient
|
||||
Args:
|
||||
chunk_data: 数据块起始索引与字符串列表。
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str, List[float]]]: 带原始索引的处理结果。
|
||||
"""
|
||||
start_idx, chunk_strs = chunk_data
|
||||
chunk_results: List[Tuple[int, str, List[float]]] = []
|
||||
|
||||
try:
|
||||
# 创建线程专用的服务层实例
|
||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
||||
|
||||
for i, s in enumerate(chunk_strs):
|
||||
try:
|
||||
# 在线程中创建独立的事件循环
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
||||
embedding = embedding_result.embedding
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 每完成一个嵌入立即更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||
embedding_client = EmbeddingServiceClient(
|
||||
task_name="embedding",
|
||||
request_type="embedding",
|
||||
)
|
||||
embedding_results = embedding_client.embed_texts_sync(
|
||||
chunk_strs,
|
||||
max_concurrent=1,
|
||||
)
|
||||
for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)):
|
||||
embedding = embedding_result.embedding
|
||||
if embedding and len(embedding) > 0:
|
||||
chunk_results.append((start_idx + i, s, embedding))
|
||||
else:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建LLM实例失败: {e}")
|
||||
# 如果创建LLM实例失败,返回空结果
|
||||
logger.error(f"创建 EmbeddingService 实例失败: {e}")
|
||||
for i, s in enumerate(chunk_strs):
|
||||
chunk_results.append((start_idx + i, s, []))
|
||||
# 即使失败也要更新进度
|
||||
if progress_callback:
|
||||
progress_callback(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user