Files
mai-bot/src/services/embedding_service.py

161 lines
5.4 KiB
Python

"""Embedding 服务层。
该模块负责在宿主侧收口统一的文本嵌入请求,并将其转发到
`src.llm_models` 中的底层嵌入调度器。
"""
from __future__ import annotations
from typing import Any, Coroutine, List, TypeVar
import asyncio
from src.common.data_models.embedding_service_data_models import EmbeddingResult
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMOrchestrator
from src.services.service_task_resolver import resolve_task_name
logger = get_logger("embedding_service")
_CoroutineReturnT = TypeVar("_CoroutineReturnT")
class EmbeddingServiceClient:
"""面向上层模块的 Embedding 服务对象式门面。"""
def __init__(self, task_name: str = "embedding", request_type: str = "") -> None:
"""初始化 Embedding 服务门面。
Args:
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
request_type: 当前请求的业务类型标识。
"""
self.task_name = resolve_task_name(task_name)
self.request_type = request_type
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
async def embed_text(self, embedding_input: str) -> EmbeddingResult:
"""生成单条文本的嵌入向量。
Args:
embedding_input: 待编码的文本内容。
Returns:
EmbeddingResult: 统一嵌入结果对象。
"""
raw_result = await self._orchestrator.get_embedding(embedding_input)
return EmbeddingResult(
embedding=list(raw_result.embedding),
model_name=raw_result.model_name,
)
async def embed_texts(
self,
embedding_inputs: List[str],
max_concurrent: int | None = None,
) -> List[EmbeddingResult]:
"""批量生成文本嵌入向量。
Args:
embedding_inputs: 待编码的文本列表。
max_concurrent: 最大并发数;未提供时按串行执行。
Returns:
List[EmbeddingResult]: 与输入顺序一致的嵌入结果列表。
"""
if not embedding_inputs:
return []
safe_max_concurrent = max(1, int(max_concurrent or 1))
if safe_max_concurrent == 1:
results: List[EmbeddingResult] = []
for embedding_input in embedding_inputs:
results.append(await self.embed_text(embedding_input))
return results
semaphore = asyncio.Semaphore(safe_max_concurrent)
async def _embed_one(index: int, embedding_input: str) -> tuple[int, EmbeddingResult]:
"""执行单条嵌入并保留原始顺序索引。
Args:
index: 原始输入索引。
embedding_input: 待编码的文本内容。
Returns:
tuple[int, EmbeddingResult]: 输入索引与对应嵌入结果。
"""
async with semaphore:
result = await self.embed_text(embedding_input)
return index, result
ordered_results = await asyncio.gather(
*[_embed_one(index, embedding_input) for index, embedding_input in enumerate(embedding_inputs)]
)
ordered_results.sort(key=lambda item: item[0])
return [result for _, result in ordered_results]
def embed_text_sync(self, embedding_input: str) -> EmbeddingResult:
"""以同步方式生成单条文本的嵌入向量。
Args:
embedding_input: 待编码的文本内容。
Returns:
EmbeddingResult: 统一嵌入结果对象。
"""
return self._run_coroutine_sync(self.embed_text(embedding_input))
def embed_texts_sync(
self,
embedding_inputs: List[str],
max_concurrent: int | None = None,
) -> List[EmbeddingResult]:
"""以同步方式批量生成文本嵌入向量。
Args:
embedding_inputs: 待编码的文本列表。
max_concurrent: 最大并发数;未提供时按串行执行。
Returns:
List[EmbeddingResult]: 与输入顺序一致的嵌入结果列表。
"""
return self._run_coroutine_sync(
self.embed_texts(
embedding_inputs,
max_concurrent=max_concurrent,
)
)
@staticmethod
def _run_coroutine_sync(coroutine: Coroutine[Any, Any, _CoroutineReturnT]) -> _CoroutineReturnT:
"""在独立事件循环中执行协程。
Args:
coroutine: 需要同步执行的协程对象。
Returns:
_CoroutineReturnT: 协程返回值。
Raises:
RuntimeError: 当前线程已有运行中的事件循环时抛出。
"""
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("当前线程存在运行中的事件循环,请改用异步 Embedding 接口")
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(coroutine)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
except Exception as exc:
logger.debug(f"关闭 EmbeddingService 临时异步生成器失败: {exc}")
asyncio.set_event_loop(None)
loop.close()