feat: 添加嵌入服务层和任务解析工具,重构文本嵌入逻辑
This commit is contained in:
160
src/services/embedding_service.py
Normal file
160
src/services/embedding_service.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Embedding 服务层。
|
||||
|
||||
该模块负责在宿主侧收口统一的文本嵌入请求,并将其转发到
|
||||
`src.llm_models` 中的底层嵌入调度器。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Coroutine, List, TypeVar
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMOrchestrator
|
||||
from src.services.service_task_resolver import resolve_task_name
|
||||
|
||||
logger = get_logger("embedding_service")
|
||||
|
||||
_CoroutineReturnT = TypeVar("_CoroutineReturnT")
|
||||
|
||||
|
||||
class EmbeddingServiceClient:
    """Object-style facade over the embedding service for upper-layer modules.

    Each instance resolves its task name once at construction time and owns a
    single ``LLMOrchestrator`` to which every embedding request is forwarded.
    Both coroutine-based and blocking (``*_sync``) entry points are provided.
    """

    def __init__(self, task_name: str = "embedding", request_type: str = "") -> None:
        """Initialize the embedding service facade.

        Args:
            task_name: Task configuration name (a field name under
                `model_task_config`); validated through `resolve_task_name`.
            request_type: Business-type identifier of the current request.
        """
        self.task_name = resolve_task_name(task_name)
        self.request_type = request_type
        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)

    async def embed_text(self, embedding_input: str) -> EmbeddingResult:
        """Generate the embedding vector for a single text.

        Args:
            embedding_input: Text to encode.

        Returns:
            EmbeddingResult: Unified result object holding a fresh list copy of
            the orchestrator's embedding and the reported model name.
        """
        raw_result = await self._orchestrator.get_embedding(embedding_input)
        return EmbeddingResult(
            embedding=list(raw_result.embedding),
            model_name=raw_result.model_name,
        )

    async def embed_texts(
        self,
        embedding_inputs: List[str],
        max_concurrent: int | None = None,
    ) -> List[EmbeddingResult]:
        """Generate embedding vectors for a batch of texts.

        Args:
            embedding_inputs: Texts to encode.
            max_concurrent: Maximum number of in-flight requests. ``None``, ``0``
                and any value below 1 fall back to strictly serial execution.

        Returns:
            List[EmbeddingResult]: Results in the same order as the inputs.
        """
        if not embedding_inputs:
            return []

        safe_max_concurrent = max(1, int(max_concurrent or 1))
        if safe_max_concurrent == 1:
            # Serial path: each request is awaited before the next one starts.
            return [await self.embed_text(embedding_input) for embedding_input in embedding_inputs]

        semaphore = asyncio.Semaphore(safe_max_concurrent)

        async def _embed_one(embedding_input: str) -> EmbeddingResult:
            """Run one embedding request while holding a semaphore slot."""
            async with semaphore:
                return await self.embed_text(embedding_input)

        # asyncio.gather returns results in the order the awaitables were
        # passed, so no index bookkeeping or re-sorting is required.
        gathered = await asyncio.gather(
            *(_embed_one(embedding_input) for embedding_input in embedding_inputs)
        )
        return list(gathered)

    def embed_text_sync(self, embedding_input: str) -> EmbeddingResult:
        """Blocking variant of `embed_text`.

        Args:
            embedding_input: Text to encode.

        Returns:
            EmbeddingResult: Unified embedding result object.

        Raises:
            RuntimeError: If the current thread already has a running event loop.
        """
        return self._run_coroutine_sync(self.embed_text(embedding_input))

    def embed_texts_sync(
        self,
        embedding_inputs: List[str],
        max_concurrent: int | None = None,
    ) -> List[EmbeddingResult]:
        """Blocking variant of `embed_texts`.

        Args:
            embedding_inputs: Texts to encode.
            max_concurrent: Maximum number of in-flight requests; serial when
                not provided.

        Returns:
            List[EmbeddingResult]: Results in the same order as the inputs.

        Raises:
            RuntimeError: If the current thread already has a running event loop.
        """
        return self._run_coroutine_sync(
            self.embed_texts(
                embedding_inputs,
                max_concurrent=max_concurrent,
            )
        )

    @staticmethod
    def _run_coroutine_sync(coroutine: Coroutine[Any, Any, _CoroutineReturnT]) -> _CoroutineReturnT:
        """Run a coroutine to completion on a dedicated, temporary event loop.

        Args:
            coroutine: Coroutine object to execute synchronously.

        Returns:
            _CoroutineReturnT: Value returned by the coroutine.

        Raises:
            RuntimeError: If the current thread already has a running event loop.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to drive our own loop.
            pass
        else:
            # The coroutine will never be awaited on this path; close it so the
            # interpreter does not emit a "coroutine ... was never awaited" warning.
            coroutine.close()
            raise RuntimeError("当前线程存在运行中的事件循环,请改用异步 Embedding 接口")

        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coroutine)
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as exc:
                # Best-effort cleanup: a failure here must not mask the result
                # or the original exception of the coroutine.
                logger.debug(f"关闭 EmbeddingService 临时异步生成器失败: {exc}")
            finally:
                # Always detach and close the temporary loop, even if the
                # async-generator shutdown was interrupted.
                asyncio.set_event_loop(None)
                loop.close()
|
||||
@@ -8,9 +8,9 @@ from typing import Any, Dict, List, Tuple
|
||||
|
||||
import json
|
||||
|
||||
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||||
from src.common.data_models.llm_service_data_models import (
|
||||
LLMAudioTranscriptionResult,
|
||||
LLMEmbeddingResult,
|
||||
LLMGenerationOptions,
|
||||
LLMImageOptions,
|
||||
LLMResponseResult,
|
||||
@@ -21,15 +21,20 @@ from src.common.data_models.llm_service_data_models import (
|
||||
PromptMessage,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMOrchestrator
|
||||
from src.services.embedding_service import EmbeddingServiceClient
|
||||
from src.services.service_task_resolver import (
|
||||
get_available_models as _get_available_models,
|
||||
resolve_task_name as _resolve_task_name,
|
||||
resolve_task_name_from_model_config as _resolve_task_name_from_model_config,
|
||||
)
|
||||
|
||||
logger = get_logger("llm_service")
|
||||
|
||||
|
||||
class LLMServiceClient:
|
||||
"""面向上层模块的 LLM 服务对象式门面。
|
||||
|
||||
@@ -38,7 +43,7 @@ class LLMServiceClient:
|
||||
- `generate_response_with_messages`
|
||||
- `generate_response_for_image`
|
||||
- `transcribe_audio`
|
||||
- `embed_text`
|
||||
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`)
|
||||
"""
|
||||
|
||||
def __init__(self, task_name: str, request_type: str = "") -> None:
|
||||
@@ -48,7 +53,7 @@ class LLMServiceClient:
|
||||
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
|
||||
request_type: 当前请求的业务类型标识。
|
||||
"""
|
||||
self.task_name = resolve_task_name(task_name)
|
||||
self.task_name = _resolve_task_name(task_name)
|
||||
self.request_type = request_type
|
||||
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
||||
|
||||
@@ -169,41 +174,29 @@ class LLMServiceClient:
|
||||
"""
|
||||
return await self._orchestrator.generate_response_for_voice(voice_base64)
|
||||
|
||||
async def embed_text(self, embedding_input: str) -> EmbeddingResult:
    """Compatibility entry point for legacy text-embedding callers.

    Builds an `EmbeddingServiceClient` from this client's `task_name` and
    `request_type` and delegates the request to it. New code is encouraged
    to use `EmbeddingServiceClient` directly.

    Args:
        embedding_input: Text to encode.

    Returns:
        EmbeddingResult: Embedding result object.
    """
    embedding_client = EmbeddingServiceClient(
        task_name=self.task_name,
        request_type=self.request_type,
    )
    return await embedding_client.embed_text(embedding_input)
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, Any]:
    """Return all available model task configurations.

    Thin compatibility wrapper that delegates to `get_available_models` from
    `src.services.service_task_resolver` (imported in this module as
    `_get_available_models`).

    Returns:
        Dict[str, Any]: Configuration mapping keyed by model task name.
    """
    return _get_available_models()
|
||||
|
||||
|
||||
def resolve_task_name(task_name: str = "") -> str:
|
||||
@@ -214,75 +207,24 @@ def resolve_task_name(task_name: str = "") -> str:
|
||||
|
||||
Returns:
|
||||
str: 解析得到的任务配置名。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当前没有任何可用模型配置。
|
||||
ValueError: 指定名称不存在时抛出。
|
||||
"""
|
||||
models = get_available_models()
|
||||
if not models:
|
||||
raise RuntimeError("没有可用的模型配置")
|
||||
normalized_task_name = task_name.strip()
|
||||
if not normalized_task_name:
|
||||
return next(iter(models.keys()))
|
||||
if normalized_task_name not in models:
|
||||
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
|
||||
return normalized_task_name
|
||||
return _resolve_task_name(task_name)
|
||||
|
||||
|
||||
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """Resolve a usable task name from a legacy `TaskConfig`-style argument.

    Kept for callers that still pass a `model_config` object. The resolution
    itself lives in `src.services.service_task_resolver` (imported in this
    module as `_resolve_task_name_from_model_config`); this wrapper only
    forwards its arguments.

    Args:
        model_config: Task configuration object held by the legacy caller.
        preferred_task_name: Optional candidate task name.

    Returns:
        str: Task name usable as `LLMServiceRequest.task_name`.

    Raises:
        RuntimeError: If no model configuration is currently available.
    """
    return _resolve_task_name_from_model_config(
        model_config=model_config,
        preferred_task_name=preferred_task_name,
    )
|
||||
|
||||
|
||||
def _normalize_role(role_name: str) -> RoleType:
|
||||
|
||||
108
src/services/service_task_resolver.py
Normal file
108
src/services/service_task_resolver.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""服务层模型任务解析工具。"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
|
||||
logger = get_logger("service_task_resolver")
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
    """Collect every model task configuration that is currently available.

    Walks the attributes of `model_task_config` (skipping names that start
    with a double underscore) and keeps those whose value is a non-callable
    `TaskConfig` instance. Attributes that fail to load are logged at debug
    level and skipped; any other failure is logged as an error and yields an
    empty mapping.

    Returns:
        Dict[str, TaskConfig]: Available task configurations keyed by task name.
    """
    try:
        task_config_holder = config_manager.get_model_config().model_task_config
        collected: Dict[str, TaskConfig] = {}
        for name in dir(task_config_holder):
            if name.startswith("__"):
                continue
            try:
                candidate = getattr(task_config_holder, name)
            except Exception as exc:
                logger.debug(f"获取模型任务配置属性 {name} 失败: {exc}")
                continue
            # Guard clause: only plain TaskConfig values count as tasks.
            if callable(candidate) or not isinstance(candidate, TaskConfig):
                continue
            collected[name] = candidate
        return collected
    except Exception as exc:
        logger.error(f"获取可用模型配置失败: {exc}")
        return {}
|
||||
|
||||
|
||||
def resolve_task_name(task_name: str = "") -> str:
    """Resolve a task name to one that exists in the available model tasks.

    Args:
        task_name: Requested task name. Empty, whitespace-only or ``None``
            selects the first available task.

    Returns:
        str: The resolved model task name.

    Raises:
        RuntimeError: If no model configuration is available at all.
        ValueError: If the requested task name does not exist.
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    # Normalise the same way resolve_task_name_from_model_config does so that
    # a None argument means "no preference" instead of raising AttributeError.
    normalized_task_name = str(task_name or "").strip()
    if not normalized_task_name:
        return next(iter(models))
    if normalized_task_name not in models:
        raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
    return normalized_task_name
|
||||
|
||||
|
||||
def _normalized_model_list(task_config: Any) -> list:
    """Return the stripped, non-empty model names from `task_config.model_list`."""
    raw_models = getattr(task_config, "model_list", None) or []
    return [name for name in (str(item).strip() for item in raw_models) if name]


def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """Resolve a task name from a legacy model configuration object.

    Resolution order:
    1. `preferred_task_name`, when it names an available task;
    2. the available task whose config object is `model_config` itself;
    3. the first available task whose normalised `model_list` equals the
       requested one;
    4. the first available task containing any requested model (tried in
       requested order), logged as an approximate mapping;
    5. the default task returned by `resolve_task_name("")`.

    Args:
        model_config: Task configuration object held by the legacy caller.
        preferred_task_name: Optional candidate task name.

    Returns:
        str: The resolved model task name.

    Raises:
        RuntimeError: If no model configuration is currently available.
    """
    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    normalized_preferred = str(preferred_task_name or "").strip()
    if normalized_preferred and normalized_preferred in models:
        return normalized_preferred

    for task_name, task_cfg in models.items():
        if task_cfg is model_config:
            return task_name

    requested_model_list = _normalized_model_list(model_config)
    if requested_model_list:
        # Normalise each task's model list once instead of once per
        # (requested model, task) pair; a None model_list counts as empty.
        candidates_by_task = {
            task_name: _normalized_model_list(task_cfg) for task_name, task_cfg in models.items()
        }

        for task_name, candidate_list in candidates_by_task.items():
            if candidate_list == requested_model_list:
                return task_name

        for requested_model in requested_model_list:
            for task_name, candidate_list in candidates_by_task.items():
                if requested_model in candidate_list:
                    logger.info(
                        "旧版 model_config 未命中任务配置,"
                        f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
                    )
                    return task_name

    if normalized_preferred:
        # A preference was given but could not be honoured; make the fallback visible.
        logger.warning(f"无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}")
    return resolve_task_name("")
|
||||
Reference in New Issue
Block a user