feat: 添加嵌入服务层和任务解析工具,重构文本嵌入逻辑
This commit is contained in:
@@ -169,6 +169,33 @@ def test_runner_apply_plugin_config_generates_config_file(tmp_path: Path) -> Non
|
|||||||
assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}}
|
assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_runner_apply_plugin_config_preserves_existing_comments(tmp_path: Path) -> None:
    """The runner should keep existing config.toml comments when it fills in defaults."""

    demo_plugin = _DemoConfigPlugin()
    plugin_runner = PluginRunner(
        host_address="ipc://unused",
        session_token="session-token",
        plugin_dirs=[],
    )
    plugin_meta = SimpleNamespace(
        plugin_id="demo.plugin",
        plugin_dir=str(tmp_path),
        instance=demo_plugin,
    )

    # Seed a config file that carries both a header comment and an inline comment.
    toml_file = tmp_path / "config.toml"
    toml_file.write_text(
        '# 插件配置头注释\n[plugin]\nenabled = false # 启用开关注释\n',
        encoding="utf-8",
    )

    plugin_runner._apply_plugin_config(cast(Any, plugin_meta))

    # Both comments must still be present after the runner rewrote the file.
    written_text = toml_file.read_text(encoding="utf-8")
    for expected_comment in ("# 插件配置头注释", "# 启用开关注释"):
        assert expected_comment in written_text

    # The parsed result must contain the original value plus the filled-in default.
    reloaded = tomllib.loads(written_text)
    assert reloaded == {"plugin": {"enabled": False, "retry_count": 3}}
|
||||||
|
|
||||||
|
|
||||||
def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None:
|
def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None:
|
||||||
"""组件查询服务应支持按插件 ID 返回配置 Schema。"""
|
"""组件查询服务应支持按插件 ID 返回配置 Schema。"""
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,32 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import math
|
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import Dict, List, Tuple
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Dict, List, Tuple
|
||||||
|
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
# import tqdm
|
|
||||||
import faiss
|
|
||||||
|
|
||||||
from .utils.hash import get_sha256
|
|
||||||
from .global_logger import logger
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from rich.progress import (
|
from rich.progress import (
|
||||||
Progress,
|
|
||||||
BarColumn,
|
BarColumn,
|
||||||
|
MofNCompleteColumn,
|
||||||
|
Progress,
|
||||||
|
SpinnerColumn,
|
||||||
|
TextColumn,
|
||||||
TimeElapsedColumn,
|
TimeElapsedColumn,
|
||||||
TimeRemainingColumn,
|
TimeRemainingColumn,
|
||||||
TaskProgressColumn,
|
TaskProgressColumn,
|
||||||
MofNCompleteColumn,
|
|
||||||
SpinnerColumn,
|
|
||||||
TextColumn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.services.embedding_service import EmbeddingServiceClient
|
||||||
|
|
||||||
|
from .global_logger import logger
|
||||||
|
from .utils.hash import get_sha256
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
@@ -133,19 +134,20 @@ class EmbeddingStore:
|
|||||||
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
return [f"{namespace}-{get_sha256(t)}" for t in texts]
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
|
"""以同步方式获取单条字符串的嵌入向量。
|
||||||
# 创建新的事件循环并在完成后立即关闭
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s: 待编码的文本内容。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[float]: 嵌入向量;失败时返回空列表。
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# 创建新的服务层实例
|
embedding_client = EmbeddingServiceClient(
|
||||||
from src.services.llm_service import LLMServiceClient
|
task_name="embedding",
|
||||||
|
request_type="embedding",
|
||||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
)
|
||||||
|
embedding_result = embedding_client.embed_text_sync(s)
|
||||||
# 使用新的事件循环运行异步方法
|
|
||||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
|
||||||
embedding = embedding_result.embedding
|
embedding = embedding_result.embedding
|
||||||
|
|
||||||
if embedding and len(embedding) > 0:
|
if embedding and len(embedding) > 0:
|
||||||
@@ -157,17 +159,15 @@ class EmbeddingStore:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
||||||
return []
|
return []
|
||||||
finally:
|
|
||||||
# 确保事件循环被正确关闭
|
|
||||||
try:
|
|
||||||
loop.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_embeddings_batch_threaded(
|
def _get_embeddings_batch_threaded(
|
||||||
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
|
self,
|
||||||
|
strs: List[str],
|
||||||
|
chunk_size: int = 10,
|
||||||
|
max_workers: int = 10,
|
||||||
|
progress_callback: Callable[[int], None] | None = None,
|
||||||
) -> List[Tuple[str, List[float]]]:
|
) -> List[Tuple[str, List[float]]]:
|
||||||
"""使用多线程批量获取嵌入向量
|
"""使用多线程批量获取嵌入向量。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
strs: 要获取嵌入的字符串列表
|
strs: 要获取嵌入的字符串列表
|
||||||
@@ -190,53 +190,42 @@ class EmbeddingStore:
|
|||||||
# 结果存储,使用字典按索引存储以保证顺序
|
# 结果存储,使用字典按索引存储以保证顺序
|
||||||
results = {}
|
results = {}
|
||||||
|
|
||||||
def process_chunk(chunk_data):
|
def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]:
|
||||||
"""处理单个数据块的函数"""
|
"""处理单个数据块。
|
||||||
start_idx, chunk_strs = chunk_data
|
|
||||||
chunk_results = []
|
|
||||||
|
|
||||||
# 为每个线程创建独立的服务层实例
|
Args:
|
||||||
from src.services.llm_service import LLMServiceClient
|
chunk_data: 数据块起始索引与字符串列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Tuple[int, str, List[float]]]: 带原始索引的处理结果。
|
||||||
|
"""
|
||||||
|
start_idx, chunk_strs = chunk_data
|
||||||
|
chunk_results: List[Tuple[int, str, List[float]]] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建线程专用的服务层实例
|
embedding_client = EmbeddingServiceClient(
|
||||||
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
|
task_name="embedding",
|
||||||
|
request_type="embedding",
|
||||||
for i, s in enumerate(chunk_strs):
|
)
|
||||||
try:
|
embedding_results = embedding_client.embed_texts_sync(
|
||||||
# 在线程中创建独立的事件循环
|
chunk_strs,
|
||||||
loop = asyncio.new_event_loop()
|
max_concurrent=1,
|
||||||
asyncio.set_event_loop(loop)
|
)
|
||||||
try:
|
for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)):
|
||||||
embedding_result = loop.run_until_complete(llm.embed_text(s))
|
embedding = embedding_result.embedding
|
||||||
embedding = embedding_result.embedding
|
if embedding and len(embedding) > 0:
|
||||||
finally:
|
chunk_results.append((start_idx + i, s, embedding))
|
||||||
loop.close()
|
else:
|
||||||
|
logger.error(f"获取嵌入失败: {s}")
|
||||||
if embedding and len(embedding) > 0:
|
|
||||||
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
|
|
||||||
else:
|
|
||||||
logger.error(f"获取嵌入失败: {s}")
|
|
||||||
chunk_results.append((start_idx + i, s, []))
|
|
||||||
|
|
||||||
# 每完成一个嵌入立即更新进度
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(1)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
|
|
||||||
chunk_results.append((start_idx + i, s, []))
|
chunk_results.append((start_idx + i, s, []))
|
||||||
|
|
||||||
# 即使失败也要更新进度
|
if progress_callback:
|
||||||
if progress_callback:
|
progress_callback(1)
|
||||||
progress_callback(1)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建LLM实例失败: {e}")
|
logger.error(f"创建 EmbeddingService 实例失败: {e}")
|
||||||
# 如果创建LLM实例失败,返回空结果
|
|
||||||
for i, s in enumerate(chunk_strs):
|
for i, s in enumerate(chunk_strs):
|
||||||
chunk_results.append((start_idx + i, s, []))
|
chunk_results.append((start_idx + i, s, []))
|
||||||
# 即使失败也要更新进度
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(1)
|
progress_callback(1)
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
|||||||
from src.chat.message_receive.message import SessionMessage
|
from src.chat.message_receive.message import SessionMessage
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.services.llm_service import LLMServiceClient
|
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
|
from src.services.embedding_service import EmbeddingServiceClient
|
||||||
|
|
||||||
from .typo_generator import ChineseTypoGenerator
|
from .typo_generator import ChineseTypoGenerator
|
||||||
|
|
||||||
@@ -233,12 +233,19 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
|
|||||||
return is_mentioned, is_at, reply_probability
|
return is_mentioned, is_at, reply_probability
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
|
async def get_embedding(text: str, request_type: str = "embedding") -> Optional[List[float]]:
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的嵌入向量。
|
||||||
# 每次都创建新的服务层实例以避免事件循环冲突
|
|
||||||
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
|
Args:
|
||||||
|
text: 待编码的文本内容。
|
||||||
|
request_type: 当前请求的业务类型标识。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[List[float]]: 成功时返回嵌入向量,失败时返回 `None`。
|
||||||
|
"""
|
||||||
|
embedding_client = EmbeddingServiceClient(task_name="embedding", request_type=request_type)
|
||||||
try:
|
try:
|
||||||
embedding_result = await llm.embed_text(text)
|
embedding_result = await embedding_client.embed_text(text)
|
||||||
embedding = embedding_result.embedding
|
embedding = embedding_result.embedding
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取embedding失败: {str(e)}")
|
logger.error(f"获取embedding失败: {str(e)}")
|
||||||
|
|||||||
19
src/common/data_models/embedding_service_data_models.py
Normal file
19
src/common/data_models/embedding_service_data_models.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Embedding 服务层共享数据模型。"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from src.common.data_models import BaseDataModel
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class EmbeddingResult(BaseDataModel):
    """Unified response object returned by the embedding service layer."""

    # Embedding vector; defaults to a fresh empty list per instance.
    embedding: List[float] = field(default_factory=list)
    # Model name attached to the result; defaults to the empty string.
    model_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EmbeddingResult",
|
||||||
|
]
|
||||||
@@ -458,6 +458,62 @@ class PluginRunner:
|
|||||||
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
|
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
|
||||||
return normalized_config, False
|
return normalized_config, False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_plugin_config_document(target: Any, source: Any) -> None:
|
||||||
|
"""递归更新现有 TOML 文档,尽量保留原注释与格式。
|
||||||
|
|
||||||
|
这里采用“更新已有键、补充缺失键”的策略,而不是直接整体重写,
|
||||||
|
这样插件启动时因补齐默认配置触发落盘时,可以尽量保留用户手写的注释。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: 现有的 TOML 文档或表对象。
|
||||||
|
source: 最新的配置字典。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
for key, value in source.items():
|
||||||
|
if key in target:
|
||||||
|
target_value = target[key]
|
||||||
|
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||||
|
PluginRunner._merge_plugin_config_document(target_value, value)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
target[key] = tomlkit.item(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
target[key] = value
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
target[key] = tomlkit.item(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
target[key] = value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _has_extra_config_keys(existing_config: Any, latest_config: Any) -> bool:
|
||||||
|
"""判断现有配置中是否包含新配置不存在的键。
|
||||||
|
|
||||||
|
如果插件归一化后的结果删除了某些旧键,就需要回退到完整重写,
|
||||||
|
否则仅做增量合并会把旧键残留在文件里。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
existing_config: 现有配置字典。
|
||||||
|
latest_config: 最新配置字典。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否存在需要通过整文件重写才能删除的旧键。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(existing_config, dict) or not isinstance(latest_config, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for key, existing_value in existing_config.items():
|
||||||
|
if key not in latest_config:
|
||||||
|
return True
|
||||||
|
if PluginRunner._has_extra_config_keys(existing_value, latest_config[key]):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
|
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
|
||||||
"""根据配置内容判断插件是否应被视为启用。
|
"""根据配置内容判断插件是否应被视为启用。
|
||||||
@@ -496,6 +552,19 @@ class PluginRunner:
|
|||||||
|
|
||||||
config_path = Path(plugin_dir) / "config.toml"
|
config_path = Path(plugin_dir) / "config.toml"
|
||||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with config_path.open("r", encoding="utf-8") as handle:
|
||||||
|
existing_document = tomlkit.load(handle)
|
||||||
|
existing_config = existing_document.unwrap()
|
||||||
|
if not PluginRunner._has_extra_config_keys(existing_config, config_data):
|
||||||
|
PluginRunner._merge_plugin_config_document(existing_document, config_data)
|
||||||
|
with config_path.open("w", encoding="utf-8") as handle:
|
||||||
|
handle.write(tomlkit.dumps(existing_document))
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"保留插件配置注释失败,将回退为整文件重写: {config_path}: {exc}")
|
||||||
|
|
||||||
with config_path.open("w", encoding="utf-8") as handle:
|
with config_path.open("w", encoding="utf-8") as handle:
|
||||||
handle.write(tomlkit.dumps(config_data))
|
handle.write(tomlkit.dumps(config_data))
|
||||||
|
|
||||||
|
|||||||
160
src/services/embedding_service.py
Normal file
160
src/services/embedding_service.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""Embedding 服务层。
|
||||||
|
|
||||||
|
该模块负责在宿主侧收口统一的文本嵌入请求,并将其转发到
|
||||||
|
`src.llm_models` 中的底层嵌入调度器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Coroutine, List, TypeVar
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.llm_models.utils_model import LLMOrchestrator
|
||||||
|
from src.services.service_task_resolver import resolve_task_name
|
||||||
|
|
||||||
|
logger = get_logger("embedding_service")
|
||||||
|
|
||||||
|
_CoroutineReturnT = TypeVar("_CoroutineReturnT")
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingServiceClient:
    """Object-style facade over the embedding service for upper-layer modules."""

    def __init__(self, task_name: str = "embedding", request_type: str = "") -> None:
        """Initialise the embedding service facade.

        Args:
            task_name: Task configuration name; it is passed through
                ``resolve_task_name`` before use.
            request_type: Business type identifier of the current request.
        """
        self.task_name = resolve_task_name(task_name)
        self.request_type = request_type
        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)

    async def embed_text(self, embedding_input: str) -> EmbeddingResult:
        """Generate the embedding vector of a single text.

        Args:
            embedding_input: The text to encode.

        Returns:
            EmbeddingResult: The unified embedding result object.
        """
        raw_result = await self._orchestrator.get_embedding(embedding_input)
        return EmbeddingResult(
            embedding=list(raw_result.embedding),
            model_name=raw_result.model_name,
        )

    async def embed_texts(
        self,
        embedding_inputs: List[str],
        max_concurrent: int | None = None,
    ) -> List[EmbeddingResult]:
        """Generate embedding vectors for a batch of texts.

        Args:
            embedding_inputs: The texts to encode.
            max_concurrent: Maximum number of concurrent requests; when it is
                missing, zero or below one the texts are processed serially.

        Returns:
            List[EmbeddingResult]: Results in the same order as the inputs.
        """
        if not embedding_inputs:
            return []

        safe_max_concurrent = max(1, int(max_concurrent or 1))
        if safe_max_concurrent == 1:
            # Serial path: one request at a time, in input order.
            results: List[EmbeddingResult] = []
            for embedding_input in embedding_inputs:
                results.append(await self.embed_text(embedding_input))
            return results

        semaphore = asyncio.Semaphore(safe_max_concurrent)

        async def _embed_one(embedding_input: str) -> EmbeddingResult:
            """Embed one text while holding a semaphore slot."""
            async with semaphore:
                return await self.embed_text(embedding_input)

        # asyncio.gather returns results in the order of the awaitables passed
        # in, so no explicit index bookkeeping or sorting is required.
        return list(
            await asyncio.gather(*(_embed_one(embedding_input) for embedding_input in embedding_inputs))
        )

    def embed_text_sync(self, embedding_input: str) -> EmbeddingResult:
        """Synchronously generate the embedding vector of a single text.

        Args:
            embedding_input: The text to encode.

        Returns:
            EmbeddingResult: The unified embedding result object.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
        """
        return self._run_coroutine_sync(self.embed_text(embedding_input))

    def embed_texts_sync(
        self,
        embedding_inputs: List[str],
        max_concurrent: int | None = None,
    ) -> List[EmbeddingResult]:
        """Synchronously generate embedding vectors for a batch of texts.

        Args:
            embedding_inputs: The texts to encode.
            max_concurrent: Maximum number of concurrent requests; when it is
                missing the texts are processed serially.

        Returns:
            List[EmbeddingResult]: Results in the same order as the inputs.

        Raises:
            RuntimeError: If an event loop is already running in this thread.
        """
        return self._run_coroutine_sync(
            self.embed_texts(
                embedding_inputs,
                max_concurrent=max_concurrent,
            )
        )

    @staticmethod
    def _run_coroutine_sync(coroutine: Coroutine[Any, Any, _CoroutineReturnT]) -> _CoroutineReturnT:
        """Run a coroutine to completion on a dedicated, temporary event loop.

        Args:
            coroutine: The coroutine object to execute synchronously.

        Returns:
            _CoroutineReturnT: The value returned by the coroutine.

        Raises:
            RuntimeError: If an event loop is already running in the current
                thread. The coroutine is closed before raising so it does not
                linger un-awaited.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No running loop in this thread: safe to create our own.
            pass
        else:
            # The coroutine will never be scheduled; close it explicitly to
            # avoid a "coroutine ... was never awaited" RuntimeWarning.
            coroutine.close()
            raise RuntimeError("当前线程存在运行中的事件循环,请改用异步 Embedding 接口")

        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coroutine)
        finally:
            # Best-effort cleanup: a failure here must not mask the real result
            # or exception of the coroutine.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as exc:
                logger.debug(f"关闭 EmbeddingService 临时异步生成器失败: {exc}")
            asyncio.set_event_loop(None)
            loop.close()
|
||||||
@@ -8,9 +8,9 @@ from typing import Any, Dict, List, Tuple
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||||||
from src.common.data_models.llm_service_data_models import (
|
from src.common.data_models.llm_service_data_models import (
|
||||||
LLMAudioTranscriptionResult,
|
LLMAudioTranscriptionResult,
|
||||||
LLMEmbeddingResult,
|
|
||||||
LLMGenerationOptions,
|
LLMGenerationOptions,
|
||||||
LLMImageOptions,
|
LLMImageOptions,
|
||||||
LLMResponseResult,
|
LLMResponseResult,
|
||||||
@@ -21,15 +21,20 @@ from src.common.data_models.llm_service_data_models import (
|
|||||||
PromptMessage,
|
PromptMessage,
|
||||||
)
|
)
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import config_manager
|
|
||||||
from src.config.model_configs import TaskConfig
|
|
||||||
from src.llm_models.model_client.base_client import BaseClient
|
from src.llm_models.model_client.base_client import BaseClient
|
||||||
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
from src.llm_models.utils_model import LLMOrchestrator
|
from src.llm_models.utils_model import LLMOrchestrator
|
||||||
|
from src.services.embedding_service import EmbeddingServiceClient
|
||||||
|
from src.services.service_task_resolver import (
|
||||||
|
get_available_models as _get_available_models,
|
||||||
|
resolve_task_name as _resolve_task_name,
|
||||||
|
resolve_task_name_from_model_config as _resolve_task_name_from_model_config,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger("llm_service")
|
logger = get_logger("llm_service")
|
||||||
|
|
||||||
|
|
||||||
class LLMServiceClient:
|
class LLMServiceClient:
|
||||||
"""面向上层模块的 LLM 服务对象式门面。
|
"""面向上层模块的 LLM 服务对象式门面。
|
||||||
|
|
||||||
@@ -38,7 +43,7 @@ class LLMServiceClient:
|
|||||||
- `generate_response_with_messages`
|
- `generate_response_with_messages`
|
||||||
- `generate_response_for_image`
|
- `generate_response_for_image`
|
||||||
- `transcribe_audio`
|
- `transcribe_audio`
|
||||||
- `embed_text`
|
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, task_name: str, request_type: str = "") -> None:
|
def __init__(self, task_name: str, request_type: str = "") -> None:
|
||||||
@@ -48,7 +53,7 @@ class LLMServiceClient:
|
|||||||
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
|
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
|
||||||
request_type: 当前请求的业务类型标识。
|
request_type: 当前请求的业务类型标识。
|
||||||
"""
|
"""
|
||||||
self.task_name = resolve_task_name(task_name)
|
self.task_name = _resolve_task_name(task_name)
|
||||||
self.request_type = request_type
|
self.request_type = request_type
|
||||||
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
||||||
|
|
||||||
@@ -169,41 +174,29 @@ class LLMServiceClient:
|
|||||||
"""
|
"""
|
||||||
return await self._orchestrator.generate_response_for_voice(voice_base64)
|
return await self._orchestrator.generate_response_for_voice(voice_base64)
|
||||||
|
|
||||||
async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
|
async def embed_text(self, embedding_input: str) -> EmbeddingResult:
|
||||||
"""生成文本嵌入向量。
|
"""兼容旧调用的文本嵌入入口。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding_input: 待编码的文本。
|
embedding_input: 待编码的文本。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMEmbeddingResult: 向量生成结果对象。
|
EmbeddingResult: 向量生成结果对象。
|
||||||
"""
|
"""
|
||||||
return await self._orchestrator.get_embedding(embedding_input)
|
embedding_client = EmbeddingServiceClient(
|
||||||
|
task_name=self.task_name,
|
||||||
|
request_type=self.request_type,
|
||||||
|
)
|
||||||
|
return await embedding_client.embed_text(embedding_input)
|
||||||
|
|
||||||
|
|
||||||
def get_available_models() -> Dict[str, TaskConfig]:
|
def get_available_models() -> Dict[str, Any]:
|
||||||
"""获取所有可用模型配置。
|
"""获取所有可用模型配置。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, TaskConfig]: 以模型任务名为键的配置映射。
|
Dict[str, Any]: 以模型任务名为键的配置映射。
|
||||||
"""
|
"""
|
||||||
try:
|
return _get_available_models()
|
||||||
models = config_manager.get_model_config().model_task_config
|
|
||||||
available_models: Dict[str, TaskConfig] = {}
|
|
||||||
for attr_name in dir(models):
|
|
||||||
if attr_name.startswith("__"):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
attr_value = getattr(models, attr_name)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
|
|
||||||
continue
|
|
||||||
if not callable(attr_value) and isinstance(attr_value, TaskConfig):
|
|
||||||
available_models[attr_name] = attr_value
|
|
||||||
return available_models
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(f"[LLMService] 获取可用模型失败: {exc}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_task_name(task_name: str = "") -> str:
|
def resolve_task_name(task_name: str = "") -> str:
|
||||||
@@ -214,75 +207,24 @@ def resolve_task_name(task_name: str = "") -> str:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 解析得到的任务配置名。
|
str: 解析得到的任务配置名。
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 当前没有任何可用模型配置。
|
|
||||||
ValueError: 指定名称不存在时抛出。
|
|
||||||
"""
|
"""
|
||||||
models = get_available_models()
|
return _resolve_task_name(task_name)
|
||||||
if not models:
|
|
||||||
raise RuntimeError("没有可用的模型配置")
|
|
||||||
normalized_task_name = task_name.strip()
|
|
||||||
if not normalized_task_name:
|
|
||||||
return next(iter(models.keys()))
|
|
||||||
if normalized_task_name not in models:
|
|
||||||
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
|
|
||||||
return normalized_task_name
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
|
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
|
||||||
"""根据旧版 `TaskConfig` 风格参数解析可用任务名。
|
"""根据旧版 `TaskConfig` 风格参数解析可用任务名。
|
||||||
|
|
||||||
该方法用于兼容仍以 `model_config` 传参的调用方:
|
|
||||||
1. 优先使用显式给出的 `preferred_task_name`;
|
|
||||||
2. 其次匹配对象同一性;
|
|
||||||
3. 再尝试按 `model_list` 精确匹配;
|
|
||||||
4. 最后按 `model_list` 中首个命中的模型进行近似映射。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_config: 旧调用方持有的任务配置对象。
|
model_config: 旧调用方持有的任务配置对象。
|
||||||
preferred_task_name: 候选任务名(可选)。
|
preferred_task_name: 候选任务名(可选)。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 可用于 `LLMServiceRequest.task_name` 的任务名。
|
str: 可用于 `LLMServiceRequest.task_name` 的任务名。
|
||||||
|
|
||||||
Raises:
|
|
||||||
RuntimeError: 当前没有可用模型配置。
|
|
||||||
ValueError: 无法解析任何可用任务名时抛出。
|
|
||||||
"""
|
"""
|
||||||
models = get_available_models()
|
return _resolve_task_name_from_model_config(
|
||||||
if not models:
|
model_config=model_config,
|
||||||
raise RuntimeError("没有可用的模型配置")
|
preferred_task_name=preferred_task_name,
|
||||||
|
)
|
||||||
normalized_preferred = str(preferred_task_name or "").strip()
|
|
||||||
if normalized_preferred and normalized_preferred in models:
|
|
||||||
return normalized_preferred
|
|
||||||
|
|
||||||
for task_name, task_cfg in models.items():
|
|
||||||
if task_cfg is model_config:
|
|
||||||
return task_name
|
|
||||||
|
|
||||||
requested_model_list_raw = getattr(model_config, "model_list", [])
|
|
||||||
requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()]
|
|
||||||
if requested_model_list:
|
|
||||||
for task_name, task_cfg in models.items():
|
|
||||||
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
|
|
||||||
if candidate_list == requested_model_list:
|
|
||||||
return task_name
|
|
||||||
|
|
||||||
for requested_model in requested_model_list:
|
|
||||||
for task_name, task_cfg in models.items():
|
|
||||||
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
|
|
||||||
if requested_model in candidate_list:
|
|
||||||
logger.info(
|
|
||||||
"[LLMService] 旧版 model_config 未命中任务配置,"
|
|
||||||
f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
|
|
||||||
)
|
|
||||||
return task_name
|
|
||||||
|
|
||||||
if normalized_preferred:
|
|
||||||
logger.warning(f"[LLMService] 无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}")
|
|
||||||
return resolve_task_name("")
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_role(role_name: str) -> RoleType:
|
def _normalize_role(role_name: str) -> RoleType:
|
||||||
|
|||||||
108
src/services/service_task_resolver.py
Normal file
108
src/services/service_task_resolver.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""服务层模型任务解析工具。"""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import config_manager
|
||||||
|
from src.config.model_configs import TaskConfig
|
||||||
|
|
||||||
|
logger = get_logger("service_task_resolver")
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_models() -> Dict[str, TaskConfig]:
    """Collect all model task configurations that are currently available.

    Every non-dunder attribute of the ``model_task_config`` object is
    inspected; attributes that are ``TaskConfig`` instances are kept.

    Returns:
        Dict[str, TaskConfig]: Mapping from task name to task configuration.
        An empty dict is returned when the lookup as a whole fails.
    """
    try:
        task_configs = config_manager.get_model_config().model_task_config
        collected: Dict[str, TaskConfig] = {}
        for attr_name in (name for name in dir(task_configs) if not name.startswith("__")):
            try:
                candidate = getattr(task_configs, attr_name)
            except Exception as exc:
                # A single unreadable attribute must not abort the whole scan.
                logger.debug(f"获取模型任务配置属性 {attr_name} 失败: {exc}")
                continue
            if callable(candidate) or not isinstance(candidate, TaskConfig):
                continue
            collected[attr_name] = candidate
        return collected
    except Exception as exc:
        logger.error(f"获取可用模型配置失败: {exc}")
        return {}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_task_name(task_name: str = "") -> str:
    """Resolve a task name to a model task name that is actually available.

    Args:
        task_name: Requested task name; surrounding whitespace is ignored.
            When empty, the first available task is returned.

    Returns:
        str: The resolved model task name.

    Raises:
        RuntimeError: If no model configuration is available at all.
        ValueError: If the requested task name does not exist.
    """
    available = get_available_models()
    if not available:
        raise RuntimeError("没有可用的模型配置")

    normalized_task_name = task_name.strip()
    if normalized_task_name == "":
        # No explicit request: fall back to the first available task.
        return next(iter(available))
    if normalized_task_name in available:
        return normalized_task_name
    raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
    """Resolve a task name from a legacy model configuration object.

    Resolution order:
      1. ``preferred_task_name`` when it names an available task;
      2. the task whose configuration object *is* ``model_config``;
      3. the task whose ``model_list`` equals the requested ``model_list``;
      4. the first task whose ``model_list`` contains one of the requested
         models (approximate mapping, logged at info level);
      5. the default task returned by ``resolve_task_name("")``.

    Args:
        model_config: Task configuration object held by a legacy caller.
        preferred_task_name: Candidate task name (optional).

    Returns:
        str: The resolved model task name.

    Raises:
        RuntimeError: If no model configuration is available at all.
    """

    def _clean_model_list(raw_list: Any) -> list:
        """Return the stripped, non-empty string form of each entry; ``None`` yields ``[]``."""
        return [str(item).strip() for item in (raw_list or []) if str(item).strip()]

    models = get_available_models()
    if not models:
        raise RuntimeError("没有可用的模型配置")

    normalized_preferred = str(preferred_task_name or "").strip()
    if normalized_preferred and normalized_preferred in models:
        return normalized_preferred

    for task_name, task_cfg in models.items():
        if task_cfg is model_config:
            return task_name

    requested_model_list = _clean_model_list(getattr(model_config, "model_list", []))
    if requested_model_list:
        # Normalise every candidate list once instead of once per requested
        # model; a missing/None ``model_list`` is treated as empty, matching
        # how the requested list is handled.
        candidate_lists = {
            task_name: _clean_model_list(getattr(task_cfg, "model_list", []))
            for task_name, task_cfg in models.items()
        }

        for task_name, candidate_list in candidate_lists.items():
            if candidate_list == requested_model_list:
                return task_name

        for requested_model in requested_model_list:
            for task_name, candidate_list in candidate_lists.items():
                if requested_model in candidate_list:
                    logger.info(
                        "旧版 model_config 未命中任务配置,"
                        f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
                    )
                    return task_name

    if normalized_preferred:
        logger.warning(f"无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}")
    return resolve_task_name("")
|
||||||
Reference in New Issue
Block a user