feat: 添加嵌入服务层和任务解析工具,重构文本嵌入逻辑

This commit is contained in:
DrSmoothl
2026-04-03 23:35:16 +08:00
parent 8bfe3e7036
commit a2431e677e
8 changed files with 483 additions and 162 deletions

View File

@@ -169,6 +169,33 @@ def test_runner_apply_plugin_config_generates_config_file(tmp_path: Path) -> Non
assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}}
def test_runner_apply_plugin_config_preserves_existing_comments(tmp_path: Path) -> None:
"""Runner 补齐配置时应尽量保留现有 config.toml 注释。"""
plugin = _DemoConfigPlugin()
runner = PluginRunner(
host_address="ipc://unused",
session_token="session-token",
plugin_dirs=[],
)
meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin)
config_path = tmp_path / "config.toml"
config_path.write_text(
'# 插件配置头注释\n[plugin]\nenabled = false # 启用开关注释\n',
encoding="utf-8",
)
runner._apply_plugin_config(cast(Any, meta))
config_text = config_path.read_text(encoding="utf-8")
assert "# 插件配置头注释" in config_text
assert "# 启用开关注释" in config_text
with config_path.open("rb") as handle:
saved_config = tomllib.load(handle)
assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}}
def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None:
"""组件查询服务应支持按插件 ID 返回配置 Schema。"""

View File

@@ -1,31 +1,32 @@
from dataclasses import dataclass
import json
import os
import math
import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple
import json
import math
import os
import faiss
import numpy as np
import pandas as pd
# import tqdm
import faiss
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from src.config.config import global_config
from src.services.embedding_service import EmbeddingServiceClient
from .global_logger import logger
from .utils.hash import get_sha256
install(extra_lines=3)
@@ -133,19 +134,20 @@ class EmbeddingStore:
return [f"{namespace}-{get_sha256(t)}" for t in texts]
def _get_embedding(self, s: str) -> List[float]:
"""获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题"""
# 创建新的事件循环并在完成后立即关闭
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
"""以同步方式获取单条字符串的嵌入向量。
Args:
s: 待编码的文本内容。
Returns:
List[float]: 嵌入向量;失败时返回空列表。
"""
try:
# 创建新的服务层实例
from src.services.llm_service import LLMServiceClient
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
# 使用新的事件循环运行异步方法
embedding_result = loop.run_until_complete(llm.embed_text(s))
embedding_client = EmbeddingServiceClient(
task_name="embedding",
request_type="embedding",
)
embedding_result = embedding_client.embed_text_sync(s)
embedding = embedding_result.embedding
if embedding and len(embedding) > 0:
@@ -157,17 +159,15 @@ class EmbeddingStore:
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
return []
finally:
# 确保事件循环被正确关闭
try:
loop.close()
except Exception:
pass
def _get_embeddings_batch_threaded(
self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
self,
strs: List[str],
chunk_size: int = 10,
max_workers: int = 10,
progress_callback: Callable[[int], None] | None = None,
) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
"""使用多线程批量获取嵌入向量
Args:
strs: 要获取嵌入的字符串列表
@@ -190,53 +190,42 @@ class EmbeddingStore:
# 结果存储,使用字典按索引存储以保证顺序
results = {}
def process_chunk(chunk_data):
"""处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data
chunk_results = []
def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]:
"""处理单个数据块
# 为每个线程创建独立的服务层实例
from src.services.llm_service import LLMServiceClient
Args:
chunk_data: 数据块起始索引与字符串列表。
Returns:
List[Tuple[int, str, List[float]]]: 带原始索引的处理结果。
"""
start_idx, chunk_strs = chunk_data
chunk_results: List[Tuple[int, str, List[float]]] = []
try:
# 创建线程专用的服务层实例
llm = LLMServiceClient(task_name="embedding", request_type="embedding")
for i, s in enumerate(chunk_strs):
try:
# 在线程中创建独立的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
embedding_result = loop.run_until_complete(llm.embed_text(s))
embedding = embedding_result.embedding
finally:
loop.close()
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
embedding_client = EmbeddingServiceClient(
task_name="embedding",
request_type="embedding",
)
embedding_results = embedding_client.embed_texts_sync(
chunk_strs,
max_concurrent=1,
)
for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)):
embedding = embedding_result.embedding
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding))
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果
logger.error(f"创建 EmbeddingService 实例失败: {e}")
for i, s in enumerate(chunk_strs):
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)

View File

@@ -14,8 +14,8 @@ from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.message import SessionMessage
from src.common.logger import get_logger
from src.config.config import global_config
from src.services.llm_service import LLMServiceClient
from src.person_info.person_info import Person
from src.services.embedding_service import EmbeddingServiceClient
from .typo_generator import ChineseTypoGenerator
@@ -233,12 +233,19 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl
return is_mentioned, is_at, reply_probability
async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量"""
# 每次都创建新的服务层实例以避免事件循环冲突
llm = LLMServiceClient(task_name="embedding", request_type=request_type)
async def get_embedding(text: str, request_type: str = "embedding") -> Optional[List[float]]:
"""获取文本的嵌入向量
Args:
text: 待编码的文本内容。
request_type: 当前请求的业务类型标识。
Returns:
Optional[List[float]]: 成功时返回嵌入向量,失败时返回 `None`。
"""
embedding_client = EmbeddingServiceClient(task_name="embedding", request_type=request_type)
try:
embedding_result = await llm.embed_text(text)
embedding_result = await embedding_client.embed_text(text)
embedding = embedding_result.embedding
except Exception as e:
logger.error(f"获取embedding失败: {str(e)}")

View File

@@ -0,0 +1,19 @@
"""Embedding 服务层共享数据模型。"""
from dataclasses import dataclass, field
from typing import List
from src.common.data_models import BaseDataModel
@dataclass(slots=True)
class EmbeddingResult(BaseDataModel):
"""Embedding 服务层统一响应对象。"""
embedding: List[float] = field(default_factory=list)
model_name: str = field(default_factory=str)
__all__ = [
"EmbeddingResult",
]

View File

@@ -458,6 +458,62 @@ class PluginRunner:
logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}")
return normalized_config, False
@staticmethod
def _merge_plugin_config_document(target: Any, source: Any) -> None:
"""递归更新现有 TOML 文档,尽量保留原注释与格式。
这里采用“更新已有键、补充缺失键”的策略,而不是直接整体重写,
这样插件启动时因补齐默认配置触发落盘时,可以尽量保留用户手写的注释。
Args:
target: 现有的 TOML 文档或表对象。
source: 最新的配置字典。
"""
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
return
for key, value in source.items():
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, dict):
PluginRunner._merge_plugin_config_document(target_value, value)
else:
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
else:
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
@staticmethod
def _has_extra_config_keys(existing_config: Any, latest_config: Any) -> bool:
"""判断现有配置中是否包含新配置不存在的键。
如果插件归一化后的结果删除了某些旧键,就需要回退到完整重写,
否则仅做增量合并会把旧键残留在文件里。
Args:
existing_config: 现有配置字典。
latest_config: 最新配置字典。
Returns:
bool: 是否存在需要通过整文件重写才能删除的旧键。
"""
if not isinstance(existing_config, dict) or not isinstance(latest_config, dict):
return False
for key, existing_value in existing_config.items():
if key not in latest_config:
return True
if PluginRunner._has_extra_config_keys(existing_value, latest_config[key]):
return True
return False
@staticmethod
def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool:
"""根据配置内容判断插件是否应被视为启用。
@@ -496,6 +552,19 @@ class PluginRunner:
config_path = Path(plugin_dir) / "config.toml"
config_path.parent.mkdir(parents=True, exist_ok=True)
if config_path.exists():
try:
with config_path.open("r", encoding="utf-8") as handle:
existing_document = tomlkit.load(handle)
existing_config = existing_document.unwrap()
if not PluginRunner._has_extra_config_keys(existing_config, config_data):
PluginRunner._merge_plugin_config_document(existing_document, config_data)
with config_path.open("w", encoding="utf-8") as handle:
handle.write(tomlkit.dumps(existing_document))
return
except Exception as exc:
logger.warning(f"保留插件配置注释失败,将回退为整文件重写: {config_path}: {exc}")
with config_path.open("w", encoding="utf-8") as handle:
handle.write(tomlkit.dumps(config_data))

View File

@@ -0,0 +1,160 @@
"""Embedding 服务层。
该模块负责在宿主侧收口统一的文本嵌入请求,并将其转发到
`src.llm_models` 中的底层嵌入调度器。
"""
from __future__ import annotations
from typing import Any, Coroutine, List, TypeVar
import asyncio
from src.common.data_models.embedding_service_data_models import EmbeddingResult
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMOrchestrator
from src.services.service_task_resolver import resolve_task_name
logger = get_logger("embedding_service")
_CoroutineReturnT = TypeVar("_CoroutineReturnT")
class EmbeddingServiceClient:
"""面向上层模块的 Embedding 服务对象式门面。"""
def __init__(self, task_name: str = "embedding", request_type: str = "") -> None:
"""初始化 Embedding 服务门面。
Args:
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
request_type: 当前请求的业务类型标识。
"""
self.task_name = resolve_task_name(task_name)
self.request_type = request_type
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
async def embed_text(self, embedding_input: str) -> EmbeddingResult:
"""生成单条文本的嵌入向量。
Args:
embedding_input: 待编码的文本内容。
Returns:
EmbeddingResult: 统一嵌入结果对象。
"""
raw_result = await self._orchestrator.get_embedding(embedding_input)
return EmbeddingResult(
embedding=list(raw_result.embedding),
model_name=raw_result.model_name,
)
async def embed_texts(
self,
embedding_inputs: List[str],
max_concurrent: int | None = None,
) -> List[EmbeddingResult]:
"""批量生成文本嵌入向量。
Args:
embedding_inputs: 待编码的文本列表。
max_concurrent: 最大并发数;未提供时按串行执行。
Returns:
List[EmbeddingResult]: 与输入顺序一致的嵌入结果列表。
"""
if not embedding_inputs:
return []
safe_max_concurrent = max(1, int(max_concurrent or 1))
if safe_max_concurrent == 1:
results: List[EmbeddingResult] = []
for embedding_input in embedding_inputs:
results.append(await self.embed_text(embedding_input))
return results
semaphore = asyncio.Semaphore(safe_max_concurrent)
async def _embed_one(index: int, embedding_input: str) -> tuple[int, EmbeddingResult]:
"""执行单条嵌入并保留原始顺序索引。
Args:
index: 原始输入索引。
embedding_input: 待编码的文本内容。
Returns:
tuple[int, EmbeddingResult]: 输入索引与对应嵌入结果。
"""
async with semaphore:
result = await self.embed_text(embedding_input)
return index, result
ordered_results = await asyncio.gather(
*[_embed_one(index, embedding_input) for index, embedding_input in enumerate(embedding_inputs)]
)
ordered_results.sort(key=lambda item: item[0])
return [result for _, result in ordered_results]
def embed_text_sync(self, embedding_input: str) -> EmbeddingResult:
"""以同步方式生成单条文本的嵌入向量。
Args:
embedding_input: 待编码的文本内容。
Returns:
EmbeddingResult: 统一嵌入结果对象。
"""
return self._run_coroutine_sync(self.embed_text(embedding_input))
def embed_texts_sync(
self,
embedding_inputs: List[str],
max_concurrent: int | None = None,
) -> List[EmbeddingResult]:
"""以同步方式批量生成文本嵌入向量。
Args:
embedding_inputs: 待编码的文本列表。
max_concurrent: 最大并发数;未提供时按串行执行。
Returns:
List[EmbeddingResult]: 与输入顺序一致的嵌入结果列表。
"""
return self._run_coroutine_sync(
self.embed_texts(
embedding_inputs,
max_concurrent=max_concurrent,
)
)
@staticmethod
def _run_coroutine_sync(coroutine: Coroutine[Any, Any, _CoroutineReturnT]) -> _CoroutineReturnT:
"""在独立事件循环中执行协程。
Args:
coroutine: 需要同步执行的协程对象。
Returns:
_CoroutineReturnT: 协程返回值。
Raises:
RuntimeError: 当前线程已有运行中的事件循环时抛出。
"""
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("当前线程存在运行中的事件循环,请改用异步 Embedding 接口")
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(coroutine)
finally:
try:
loop.run_until_complete(loop.shutdown_asyncgens())
except Exception as exc:
logger.debug(f"关闭 EmbeddingService 临时异步生成器失败: {exc}")
asyncio.set_event_loop(None)
loop.close()

View File

@@ -8,9 +8,9 @@ from typing import Any, Dict, List, Tuple
import json
from src.common.data_models.embedding_service_data_models import EmbeddingResult
from src.common.data_models.llm_service_data_models import (
LLMAudioTranscriptionResult,
LLMEmbeddingResult,
LLMGenerationOptions,
LLMImageOptions,
LLMResponseResult,
@@ -21,15 +21,20 @@ from src.common.data_models.llm_service_data_models import (
PromptMessage,
)
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.model_client.base_client import BaseClient
from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMOrchestrator
from src.services.embedding_service import EmbeddingServiceClient
from src.services.service_task_resolver import (
get_available_models as _get_available_models,
resolve_task_name as _resolve_task_name,
resolve_task_name_from_model_config as _resolve_task_name_from_model_config,
)
logger = get_logger("llm_service")
class LLMServiceClient:
"""面向上层模块的 LLM 服务对象式门面。
@@ -38,7 +43,7 @@ class LLMServiceClient:
- `generate_response_with_messages`
- `generate_response_for_image`
- `transcribe_audio`
- `embed_text`
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`
"""
def __init__(self, task_name: str, request_type: str = "") -> None:
@@ -48,7 +53,7 @@ class LLMServiceClient:
task_name: 任务配置名称,对应 `model_task_config` 下的字段名。
request_type: 当前请求的业务类型标识。
"""
self.task_name = resolve_task_name(task_name)
self.task_name = _resolve_task_name(task_name)
self.request_type = request_type
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
@@ -169,41 +174,29 @@ class LLMServiceClient:
"""
return await self._orchestrator.generate_response_for_voice(voice_base64)
async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult:
"""生成文本嵌入向量
async def embed_text(self, embedding_input: str) -> EmbeddingResult:
"""兼容旧调用的文本嵌入入口
Args:
embedding_input: 待编码的文本。
Returns:
LLMEmbeddingResult: 向量生成结果对象。
EmbeddingResult: 向量生成结果对象。
"""
return await self._orchestrator.get_embedding(embedding_input)
embedding_client = EmbeddingServiceClient(
task_name=self.task_name,
request_type=self.request_type,
)
return await embedding_client.embed_text(embedding_input)
def get_available_models() -> Dict[str, TaskConfig]:
def get_available_models() -> Dict[str, Any]:
"""获取所有可用模型配置。
Returns:
Dict[str, TaskConfig]: 以模型任务名为键的配置映射。
Dict[str, Any]: 以模型任务名为键的配置映射。
"""
try:
models = config_manager.get_model_config().model_task_config
available_models: Dict[str, TaskConfig] = {}
for attr_name in dir(models):
if attr_name.startswith("__"):
continue
try:
attr_value = getattr(models, attr_name)
except Exception as exc:
logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}")
continue
if not callable(attr_value) and isinstance(attr_value, TaskConfig):
available_models[attr_name] = attr_value
return available_models
except Exception as exc:
logger.error(f"[LLMService] 获取可用模型失败: {exc}")
return {}
return _get_available_models()
def resolve_task_name(task_name: str = "") -> str:
@@ -214,75 +207,24 @@ def resolve_task_name(task_name: str = "") -> str:
Returns:
str: 解析得到的任务配置名。
Raises:
RuntimeError: 当前没有任何可用模型配置。
ValueError: 指定名称不存在时抛出。
"""
models = get_available_models()
if not models:
raise RuntimeError("没有可用的模型配置")
normalized_task_name = task_name.strip()
if not normalized_task_name:
return next(iter(models.keys()))
if normalized_task_name not in models:
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
return normalized_task_name
return _resolve_task_name(task_name)
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
"""根据旧版 `TaskConfig` 风格参数解析可用任务名。
该方法用于兼容仍以 `model_config` 传参的调用方:
1. 优先使用显式给出的 `preferred_task_name`
2. 其次匹配对象同一性;
3. 再尝试按 `model_list` 精确匹配;
4. 最后按 `model_list` 中首个命中的模型进行近似映射。
Args:
model_config: 旧调用方持有的任务配置对象。
preferred_task_name: 候选任务名(可选)。
Returns:
str: 可用于 `LLMServiceRequest.task_name` 的任务名。
Raises:
RuntimeError: 当前没有可用模型配置。
ValueError: 无法解析任何可用任务名时抛出。
"""
models = get_available_models()
if not models:
raise RuntimeError("没有可用的模型配置")
normalized_preferred = str(preferred_task_name or "").strip()
if normalized_preferred and normalized_preferred in models:
return normalized_preferred
for task_name, task_cfg in models.items():
if task_cfg is model_config:
return task_name
requested_model_list_raw = getattr(model_config, "model_list", [])
requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()]
if requested_model_list:
for task_name, task_cfg in models.items():
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
if candidate_list == requested_model_list:
return task_name
for requested_model in requested_model_list:
for task_name, task_cfg in models.items():
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
if requested_model in candidate_list:
logger.info(
"[LLMService] 旧版 model_config 未命中任务配置,"
f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
)
return task_name
if normalized_preferred:
logger.warning(f"[LLMService] 无法映射旧版 model_config回退默认任务: preferred={normalized_preferred}")
return resolve_task_name("")
return _resolve_task_name_from_model_config(
model_config=model_config,
preferred_task_name=preferred_task_name,
)
def _normalize_role(role_name: str) -> RoleType:

View File

@@ -0,0 +1,108 @@
"""服务层模型任务解析工具。"""
from typing import Any, Dict
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import TaskConfig
logger = get_logger("service_task_resolver")
def get_available_models() -> Dict[str, TaskConfig]:
"""获取当前所有可用的模型任务配置。
Returns:
Dict[str, TaskConfig]: 以任务名为键的可用任务配置映射。
"""
try:
models = config_manager.get_model_config().model_task_config
available_models: Dict[str, TaskConfig] = {}
for attr_name in dir(models):
if attr_name.startswith("__"):
continue
try:
attr_value = getattr(models, attr_name)
except Exception as exc:
logger.debug(f"获取模型任务配置属性 {attr_name} 失败: {exc}")
continue
if not callable(attr_value) and isinstance(attr_value, TaskConfig):
available_models[attr_name] = attr_value
return available_models
except Exception as exc:
logger.error(f"获取可用模型配置失败: {exc}")
return {}
def resolve_task_name(task_name: str = "") -> str:
"""根据任务名解析实际可用的模型任务名称。
Args:
task_name: 目标任务名;为空时返回首个可用任务。
Returns:
str: 解析后的模型任务名。
Raises:
RuntimeError: 当前没有任何可用模型配置时抛出。
ValueError: 指定任务名不存在时抛出。
"""
models = get_available_models()
if not models:
raise RuntimeError("没有可用的模型配置")
normalized_task_name = task_name.strip()
if not normalized_task_name:
return next(iter(models.keys()))
if normalized_task_name not in models:
raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置")
return normalized_task_name
def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str:
"""根据旧版模型配置对象解析任务名。
Args:
model_config: 旧调用方持有的任务配置对象。
preferred_task_name: 候选任务名。
Returns:
str: 解析后的模型任务名。
Raises:
RuntimeError: 当前没有任何可用模型配置时抛出。
ValueError: 无法解析任何可用任务名时抛出。
"""
models = get_available_models()
if not models:
raise RuntimeError("没有可用的模型配置")
normalized_preferred = str(preferred_task_name or "").strip()
if normalized_preferred and normalized_preferred in models:
return normalized_preferred
for task_name, task_cfg in models.items():
if task_cfg is model_config:
return task_name
requested_model_list_raw = getattr(model_config, "model_list", [])
requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()]
if requested_model_list:
for task_name, task_cfg in models.items():
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
if candidate_list == requested_model_list:
return task_name
for requested_model in requested_model_list:
for task_name, task_cfg in models.items():
candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()]
if requested_model in candidate_list:
logger.info(
"旧版 model_config 未命中任务配置,"
f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`"
)
return task_name
if normalized_preferred:
logger.warning(f"无法映射旧版 model_config回退默认任务: preferred={normalized_preferred}")
return resolve_task_name("")