diff --git a/pytests/test_plugin_config_runtime.py b/pytests/test_plugin_config_runtime.py index 51eb4a80..160936a1 100644 --- a/pytests/test_plugin_config_runtime.py +++ b/pytests/test_plugin_config_runtime.py @@ -169,6 +169,33 @@ def test_runner_apply_plugin_config_generates_config_file(tmp_path: Path) -> Non assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}} +def test_runner_apply_plugin_config_preserves_existing_comments(tmp_path: Path) -> None: + """Runner 补齐配置时应尽量保留现有 config.toml 注释。""" + + plugin = _DemoConfigPlugin() + runner = PluginRunner( + host_address="ipc://unused", + session_token="session-token", + plugin_dirs=[], + ) + meta = SimpleNamespace(plugin_id="demo.plugin", plugin_dir=str(tmp_path), instance=plugin) + config_path = tmp_path / "config.toml" + config_path.write_text( + '# 插件配置头注释\n[plugin]\nenabled = false # 启用开关注释\n', + encoding="utf-8", + ) + + runner._apply_plugin_config(cast(Any, meta)) + + config_text = config_path.read_text(encoding="utf-8") + assert "# 插件配置头注释" in config_text + assert "# 启用开关注释" in config_text + + with config_path.open("rb") as handle: + saved_config = tomllib.load(handle) + assert saved_config == {"plugin": {"enabled": False, "retry_count": 3}} + + def test_component_query_service_returns_plugin_config_schema(monkeypatch: Any) -> None: """组件查询服务应支持按插件 ID 返回配置 Schema。""" diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 026c72ee..b56d8813 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -1,31 +1,32 @@ -from dataclasses import dataclass -import json -import os -import math -import asyncio from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Dict, List, Tuple +from dataclasses import dataclass +from typing import Callable, Dict, List, Tuple +import json +import math +import os + +import faiss import numpy as np import pandas as pd -# import tqdm -import faiss - -from .utils.hash import get_sha256 -from .global_logger import logger from rich.traceback import install from rich.progress import ( - Progress, BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TextColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, - MofNCompleteColumn, - SpinnerColumn, - TextColumn, ) + from src.config.config import global_config +from src.services.embedding_service import EmbeddingServiceClient + +from .global_logger import logger +from .utils.hash import get_sha256 install(extra_lines=3) @@ -133,19 +134,20 @@ class EmbeddingStore: return [f"{namespace}-{get_sha256(t)}" for t in texts] def _get_embedding(self, s: str) -> List[float]: - """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" - # 创建新的事件循环并在完成后立即关闭 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + """以同步方式获取单条字符串的嵌入向量。 + Args: + s: 待编码的文本内容。 + + Returns: + List[float]: 嵌入向量;失败时返回空列表。 + """ try: - # 创建新的服务层实例 - from src.services.llm_service import LLMServiceClient - - llm = LLMServiceClient(task_name="embedding", request_type="embedding") - - # 使用新的事件循环运行异步方法 - embedding_result = loop.run_until_complete(llm.embed_text(s)) + embedding_client = EmbeddingServiceClient( + task_name="embedding", + request_type="embedding", + ) + embedding_result = embedding_client.embed_text_sync(s) embedding = embedding_result.embedding if embedding and len(embedding) > 0: @@ -157,17 +159,15 @@ class EmbeddingStore: except Exception as e: logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") return [] - finally: - # 确保事件循环被正确关闭 - try: - loop.close() - except Exception: - pass def _get_embeddings_batch_threaded( - self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None + self, + strs: List[str], + chunk_size: int = 10, + max_workers: int = 10, + progress_callback: Callable[[int], None] | None = None, ) -> List[Tuple[str, List[float]]]: - """使用多线程批量获取嵌入向量 + """使用多线程批量获取嵌入向量。 Args: strs: 要获取嵌入的字符串列表 @@ -190,53 +190,42 @@ class EmbeddingStore: # 结果存储,使用字典按索引存储以保证顺序 results = {} - def process_chunk(chunk_data): - """处理单个数据块的函数""" - start_idx, chunk_strs = chunk_data - chunk_results = [] + def process_chunk(chunk_data: Tuple[int, List[str]]) -> List[Tuple[int, str, List[float]]]: + """处理单个数据块。 - # 为每个线程创建独立的服务层实例 - from src.services.llm_service import LLMServiceClient + Args: + chunk_data: 数据块起始索引与字符串列表。 + + Returns: + List[Tuple[int, str, List[float]]]: 带原始索引的处理结果。 + """ + start_idx, chunk_strs = chunk_data + chunk_results: List[Tuple[int, str, List[float]]] = [] try: - # 创建线程专用的服务层实例 - llm = LLMServiceClient(task_name="embedding", request_type="embedding") - - for i, s in enumerate(chunk_strs): - try: - # 在线程中创建独立的事件循环 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - embedding_result = loop.run_until_complete(llm.embed_text(s)) - embedding = embedding_result.embedding - finally: - loop.close() - - if embedding and len(embedding) > 0: - chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 - else: - logger.error(f"获取嵌入失败: {s}") - chunk_results.append((start_idx + i, s, [])) - - # 每完成一个嵌入立即更新进度 - if progress_callback: - progress_callback(1) - - except Exception as e: - logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") + embedding_client = EmbeddingServiceClient( + task_name="embedding", + request_type="embedding", + ) + embedding_results = embedding_client.embed_texts_sync( + chunk_strs, + max_concurrent=1, + ) + for i, (s, embedding_result) in enumerate(zip(chunk_strs, embedding_results, strict=False)): + embedding = embedding_result.embedding + if embedding and len(embedding) > 0: + chunk_results.append((start_idx + i, s, embedding)) + else: + logger.error(f"获取嵌入失败: {s}") chunk_results.append((start_idx + i, s, [])) - # 即使失败也要更新进度 - if progress_callback: - progress_callback(1) + if progress_callback: + progress_callback(1) except Exception as e: - logger.error(f"创建LLM实例失败: {e}") - # 如果创建LLM实例失败,返回空结果 + logger.error(f"创建 EmbeddingService 实例失败: {e}") for i, s in enumerate(chunk_strs): chunk_results.append((start_idx + i, s, [])) - # 即使失败也要更新进度 if progress_callback: progress_callback(1) diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index aa14e790..a1739db8 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -14,8 +14,8 @@ from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger from src.config.config import global_config -from src.services.llm_service import LLMServiceClient from src.person_info.person_info import Person +from src.services.embedding_service import EmbeddingServiceClient from .typo_generator import ChineseTypoGenerator @@ -233,12 +233,19 @@ def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, fl return is_mentioned, is_at, reply_probability -async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: - """获取文本的embedding向量""" - # 每次都创建新的服务层实例以避免事件循环冲突 - llm = LLMServiceClient(task_name="embedding", request_type=request_type) +async def get_embedding(text: str, request_type: str = "embedding") -> Optional[List[float]]: + """获取文本的嵌入向量。 + + Args: + text: 待编码的文本内容。 + request_type: 当前请求的业务类型标识。 + + Returns: + Optional[List[float]]: 成功时返回嵌入向量,失败时返回 `None`。 + """ + embedding_client = EmbeddingServiceClient(task_name="embedding", request_type=request_type) try: - embedding_result = await llm.embed_text(text) + embedding_result = await embedding_client.embed_text(text) embedding = embedding_result.embedding except Exception as e: logger.error(f"获取embedding失败: {str(e)}") diff --git a/src/common/data_models/embedding_service_data_models.py b/src/common/data_models/embedding_service_data_models.py new file mode 100644 index 00000000..9ee2a158 --- /dev/null +++ b/src/common/data_models/embedding_service_data_models.py @@ -0,0 +1,19 @@ +"""Embedding 服务层共享数据模型。""" + +from dataclasses import dataclass, field +from typing import List + +from src.common.data_models import BaseDataModel + + +@dataclass(slots=True) +class EmbeddingResult(BaseDataModel): + """Embedding 服务层统一响应对象。""" + + embedding: List[float] = field(default_factory=list) + model_name: str = field(default_factory=str) + + +__all__ = [ + "EmbeddingResult", +] diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index 7a76de0f..157bdecf 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -458,6 +458,62 @@ class PluginRunner: logger.warning(f"插件配置归一化失败,将回退为原始配置: {exc}") return normalized_config, False + @staticmethod + def _merge_plugin_config_document(target: Any, source: Any) -> None: + """递归更新现有 TOML 文档,尽量保留原注释与格式。 + + 这里采用“更新已有键、补充缺失键”的策略,而不是直接整体重写, + 这样插件启动时因补齐默认配置触发落盘时,可以尽量保留用户手写的注释。 + + Args: + target: 现有的 TOML 文档或表对象。 + source: 最新的配置字典。 + """ + + if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict): + return + + for key, value in source.items(): + if key in target: + target_value = target[key] + if isinstance(value, dict) and isinstance(target_value, dict): + PluginRunner._merge_plugin_config_document(target_value, value) + else: + try: + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + target[key] = value + else: + try: + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + target[key] = value + + @staticmethod + def _has_extra_config_keys(existing_config: Any, latest_config: Any) -> bool: + """判断现有配置中是否包含新配置不存在的键。 + + 如果插件归一化后的结果删除了某些旧键,就需要回退到完整重写, + 否则仅做增量合并会把旧键残留在文件里。 + + Args: + existing_config: 现有配置字典。 + latest_config: 最新配置字典。 + + Returns: + bool: 是否存在需要通过整文件重写才能删除的旧键。 + """ + + if not isinstance(existing_config, dict) or not isinstance(latest_config, dict): + return False + + for key, existing_value in existing_config.items(): + if key not in latest_config: + return True + if PluginRunner._has_extra_config_keys(existing_value, latest_config[key]): + return True + return False + @staticmethod def _is_plugin_enabled(config_data: Optional[Mapping[str, Any]]) -> bool: """根据配置内容判断插件是否应被视为启用。 @@ -496,6 +552,19 @@ class PluginRunner: config_path = Path(plugin_dir) / "config.toml" config_path.parent.mkdir(parents=True, exist_ok=True) + if config_path.exists(): + try: + with config_path.open("r", encoding="utf-8") as handle: + existing_document = tomlkit.load(handle) + existing_config = existing_document.unwrap() + if not PluginRunner._has_extra_config_keys(existing_config, config_data): + PluginRunner._merge_plugin_config_document(existing_document, config_data) + with config_path.open("w", encoding="utf-8") as handle: + handle.write(tomlkit.dumps(existing_document)) + return + except Exception as exc: + logger.warning(f"保留插件配置注释失败,将回退为整文件重写: {config_path}: {exc}") + with config_path.open("w", encoding="utf-8") as handle: handle.write(tomlkit.dumps(config_data)) diff --git a/src/services/embedding_service.py b/src/services/embedding_service.py new file mode 100644 index 00000000..1a806cd0 --- /dev/null +++ b/src/services/embedding_service.py @@ -0,0 +1,160 @@ +"""Embedding 服务层。 + +该模块负责在宿主侧收口统一的文本嵌入请求,并将其转发到 +`src.llm_models` 中的底层嵌入调度器。 +""" + +from __future__ import annotations + +from typing import Any, Coroutine, List, TypeVar + +import asyncio + +from src.common.data_models.embedding_service_data_models import EmbeddingResult +from src.common.logger import get_logger +from src.llm_models.utils_model import LLMOrchestrator +from src.services.service_task_resolver import resolve_task_name + +logger = get_logger("embedding_service") + +_CoroutineReturnT = TypeVar("_CoroutineReturnT") + + +class EmbeddingServiceClient: + """面向上层模块的 Embedding 服务对象式门面。""" + + def __init__(self, task_name: str = "embedding", request_type: str = "") -> None: + """初始化 Embedding 服务门面。 + + Args: + task_name: 任务配置名称,对应 `model_task_config` 下的字段名。 + request_type: 当前请求的业务类型标识。 + """ + self.task_name = resolve_task_name(task_name) + self.request_type = request_type + self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type) + + async def embed_text(self, embedding_input: str) -> EmbeddingResult: + """生成单条文本的嵌入向量。 + + Args: + embedding_input: 待编码的文本内容。 + + Returns: + EmbeddingResult: 统一嵌入结果对象。 + """ + raw_result = await self._orchestrator.get_embedding(embedding_input) + return EmbeddingResult( + embedding=list(raw_result.embedding), + model_name=raw_result.model_name, + ) + + async def embed_texts( + self, + embedding_inputs: List[str], + max_concurrent: int | None = None, + ) -> List[EmbeddingResult]: + """批量生成文本嵌入向量。 + + Args: + embedding_inputs: 待编码的文本列表。 + max_concurrent: 最大并发数;未提供时按串行执行。 + + Returns: + List[EmbeddingResult]: 与输入顺序一致的嵌入结果列表。 + """ + if not embedding_inputs: + return [] + + safe_max_concurrent = max(1, int(max_concurrent or 1)) + if safe_max_concurrent == 1: + results: List[EmbeddingResult] = [] + for embedding_input in embedding_inputs: + results.append(await self.embed_text(embedding_input)) + return results + + semaphore = asyncio.Semaphore(safe_max_concurrent) + + async def _embed_one(index: int, embedding_input: str) -> tuple[int, EmbeddingResult]: + """执行单条嵌入并保留原始顺序索引。 + + Args: + index: 原始输入索引。 + embedding_input: 待编码的文本内容。 + + Returns: + tuple[int, EmbeddingResult]: 输入索引与对应嵌入结果。 + """ + async with semaphore: + result = await self.embed_text(embedding_input) + return index, result + + ordered_results = await asyncio.gather( + *[_embed_one(index, embedding_input) for index, embedding_input in enumerate(embedding_inputs)] + ) + ordered_results.sort(key=lambda item: item[0]) + return [result for _, result in ordered_results] + + def embed_text_sync(self, embedding_input: str) -> EmbeddingResult: + """以同步方式生成单条文本的嵌入向量。 + + Args: + embedding_input: 待编码的文本内容。 + + Returns: + EmbeddingResult: 统一嵌入结果对象。 + """ + return self._run_coroutine_sync(self.embed_text(embedding_input)) + + def embed_texts_sync( + self, + embedding_inputs: List[str], + max_concurrent: int | None = None, + ) -> List[EmbeddingResult]: + """以同步方式批量生成文本嵌入向量。 + + Args: + embedding_inputs: 待编码的文本列表。 + max_concurrent: 最大并发数;未提供时按串行执行。 + + Returns: + List[EmbeddingResult]: 与输入顺序一致的嵌入结果列表。 + """ + return self._run_coroutine_sync( + self.embed_texts( + embedding_inputs, + max_concurrent=max_concurrent, + ) + ) + + @staticmethod + def _run_coroutine_sync(coroutine: Coroutine[Any, Any, _CoroutineReturnT]) -> _CoroutineReturnT: + """在独立事件循环中执行协程。 + + Args: + coroutine: 需要同步执行的协程对象。 + + Returns: + _CoroutineReturnT: 协程返回值。 + + Raises: + RuntimeError: 当前线程已有运行中的事件循环时抛出。 + """ + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError("当前线程存在运行中的事件循环,请改用异步 Embedding 接口") + + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + return loop.run_until_complete(coroutine) + finally: + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + except Exception as exc: + logger.debug(f"关闭 EmbeddingService 临时异步生成器失败: {exc}") + asyncio.set_event_loop(None) + loop.close() diff --git a/src/services/llm_service.py b/src/services/llm_service.py index 9ddac4d0..4b9972f3 100644 --- a/src/services/llm_service.py +++ b/src/services/llm_service.py @@ -8,9 +8,9 @@ from typing import Any, Dict, List, Tuple import json +from src.common.data_models.embedding_service_data_models import EmbeddingResult from src.common.data_models.llm_service_data_models import ( LLMAudioTranscriptionResult, - LLMEmbeddingResult, LLMGenerationOptions, LLMImageOptions, LLMResponseResult, @@ -21,15 +21,20 @@ from src.common.data_models.llm_service_data_models import ( PromptMessage, ) from src.common.logger import get_logger -from src.config.config import config_manager -from src.config.model_configs import TaskConfig from src.llm_models.model_client.base_client import BaseClient from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.utils_model import LLMOrchestrator +from src.services.embedding_service import EmbeddingServiceClient +from src.services.service_task_resolver import ( + get_available_models as _get_available_models, + resolve_task_name as _resolve_task_name, + resolve_task_name_from_model_config as _resolve_task_name_from_model_config, +) logger = get_logger("llm_service") + class LLMServiceClient: """面向上层模块的 LLM 服务对象式门面。 @@ -38,7 +43,7 @@ class LLMServiceClient: - `generate_response_with_messages` - `generate_response_for_image` - `transcribe_audio` - - `embed_text` + - `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`) """ def __init__(self, task_name: str, request_type: str = "") -> None: @@ -48,7 +53,7 @@ class LLMServiceClient: task_name: 任务配置名称,对应 `model_task_config` 下的字段名。 request_type: 当前请求的业务类型标识。 """ - self.task_name = resolve_task_name(task_name) + self.task_name = _resolve_task_name(task_name) self.request_type = request_type self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type) @@ -169,41 +174,29 @@ class LLMServiceClient: """ return await self._orchestrator.generate_response_for_voice(voice_base64) - async def embed_text(self, embedding_input: str) -> LLMEmbeddingResult: - """生成文本嵌入向量。 + async def embed_text(self, embedding_input: str) -> EmbeddingResult: + """兼容旧调用的文本嵌入入口。 Args: embedding_input: 待编码的文本。 Returns: - LLMEmbeddingResult: 向量生成结果对象。 + EmbeddingResult: 向量生成结果对象。 """ - return await self._orchestrator.get_embedding(embedding_input) + embedding_client = EmbeddingServiceClient( + task_name=self.task_name, + request_type=self.request_type, + ) + return await embedding_client.embed_text(embedding_input) -def get_available_models() -> Dict[str, TaskConfig]: +def get_available_models() -> Dict[str, Any]: """获取所有可用模型配置。 Returns: - Dict[str, TaskConfig]: 以模型任务名为键的配置映射。 + Dict[str, Any]: 以模型任务名为键的配置映射。 """ - try: - models = config_manager.get_model_config().model_task_config - available_models: Dict[str, TaskConfig] = {} - for attr_name in dir(models): - if attr_name.startswith("__"): - continue - try: - attr_value = getattr(models, attr_name) - except Exception as exc: - logger.debug(f"[LLMService] 获取属性 {attr_name} 失败: {exc}") - continue - if not callable(attr_value) and isinstance(attr_value, TaskConfig): - available_models[attr_name] = attr_value - return available_models - except Exception as exc: - logger.error(f"[LLMService] 获取可用模型失败: {exc}") - return {} + return _get_available_models() def resolve_task_name(task_name: str = "") -> str: @@ -214,75 +207,24 @@ def resolve_task_name(task_name: str = "") -> str: Returns: str: 解析得到的任务配置名。 - - Raises: - RuntimeError: 当前没有任何可用模型配置。 - ValueError: 指定名称不存在时抛出。 """ - models = get_available_models() - if not models: - raise RuntimeError("没有可用的模型配置") - normalized_task_name = task_name.strip() - if not normalized_task_name: - return next(iter(models.keys())) - if normalized_task_name not in models: - raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置") - return normalized_task_name + return _resolve_task_name(task_name) def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str: """根据旧版 `TaskConfig` 风格参数解析可用任务名。 - 该方法用于兼容仍以 `model_config` 传参的调用方: - 1. 优先使用显式给出的 `preferred_task_name`; - 2. 其次匹配对象同一性; - 3. 再尝试按 `model_list` 精确匹配; - 4. 最后按 `model_list` 中首个命中的模型进行近似映射。 - Args: model_config: 旧调用方持有的任务配置对象。 preferred_task_name: 候选任务名(可选)。 Returns: str: 可用于 `LLMServiceRequest.task_name` 的任务名。 - - Raises: - RuntimeError: 当前没有可用模型配置。 - ValueError: 无法解析任何可用任务名时抛出。 """ - models = get_available_models() - if not models: - raise RuntimeError("没有可用的模型配置") - - normalized_preferred = str(preferred_task_name or "").strip() - if normalized_preferred and normalized_preferred in models: - return normalized_preferred - - for task_name, task_cfg in models.items(): - if task_cfg is model_config: - return task_name - - requested_model_list_raw = getattr(model_config, "model_list", []) - requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()] - if requested_model_list: - for task_name, task_cfg in models.items(): - candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()] - if candidate_list == requested_model_list: - return task_name - - for requested_model in requested_model_list: - for task_name, task_cfg in models.items(): - candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()] - if requested_model in candidate_list: - logger.info( - "[LLMService] 旧版 model_config 未命中任务配置," - f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`" - ) - return task_name - - if normalized_preferred: - logger.warning(f"[LLMService] 无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}") - return resolve_task_name("") + return _resolve_task_name_from_model_config( + model_config=model_config, + preferred_task_name=preferred_task_name, + ) def _normalize_role(role_name: str) -> RoleType: diff --git a/src/services/service_task_resolver.py b/src/services/service_task_resolver.py new file mode 100644 index 00000000..b8c129b4 --- /dev/null +++ b/src/services/service_task_resolver.py @@ -0,0 +1,108 @@ +"""服务层模型任务解析工具。""" + +from typing import Any, Dict + +from src.common.logger import get_logger +from src.config.config import config_manager +from src.config.model_configs import TaskConfig + +logger = get_logger("service_task_resolver") + + +def get_available_models() -> Dict[str, TaskConfig]: + """获取当前所有可用的模型任务配置。 + + Returns: + Dict[str, TaskConfig]: 以任务名为键的可用任务配置映射。 + """ + try: + models = config_manager.get_model_config().model_task_config + available_models: Dict[str, TaskConfig] = {} + for attr_name in dir(models): + if attr_name.startswith("__"): + continue + try: + attr_value = getattr(models, attr_name) + except Exception as exc: + logger.debug(f"获取模型任务配置属性 {attr_name} 失败: {exc}") + continue + if not callable(attr_value) and isinstance(attr_value, TaskConfig): + available_models[attr_name] = attr_value + return available_models + except Exception as exc: + logger.error(f"获取可用模型配置失败: {exc}") + return {} + + +def resolve_task_name(task_name: str = "") -> str: + """根据任务名解析实际可用的模型任务名称。 + + Args: + task_name: 目标任务名;为空时返回首个可用任务。 + + Returns: + str: 解析后的模型任务名。 + + Raises: + RuntimeError: 当前没有任何可用模型配置时抛出。 + ValueError: 指定任务名不存在时抛出。 + """ + models = get_available_models() + if not models: + raise RuntimeError("没有可用的模型配置") + + normalized_task_name = task_name.strip() + if not normalized_task_name: + return next(iter(models.keys())) + if normalized_task_name not in models: + raise ValueError(f"未找到名为 `{normalized_task_name}` 的模型配置") + return normalized_task_name + + +def resolve_task_name_from_model_config(model_config: Any, preferred_task_name: str = "") -> str: + """根据旧版模型配置对象解析任务名。 + + Args: + model_config: 旧调用方持有的任务配置对象。 + preferred_task_name: 候选任务名。 + + Returns: + str: 解析后的模型任务名。 + + Raises: + RuntimeError: 当前没有任何可用模型配置时抛出。 + ValueError: 无法解析任何可用任务名时抛出。 + """ + models = get_available_models() + if not models: + raise RuntimeError("没有可用的模型配置") + + normalized_preferred = str(preferred_task_name or "").strip() + if normalized_preferred and normalized_preferred in models: + return normalized_preferred + + for task_name, task_cfg in models.items(): + if task_cfg is model_config: + return task_name + + requested_model_list_raw = getattr(model_config, "model_list", []) + requested_model_list = [str(item).strip() for item in (requested_model_list_raw or []) if str(item).strip()] + if requested_model_list: + for task_name, task_cfg in models.items(): + candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()] + if candidate_list == requested_model_list: + return task_name + + for requested_model in requested_model_list: + for task_name, task_cfg in models.items(): + candidate_list = [str(item).strip() for item in getattr(task_cfg, "model_list", []) if str(item).strip()] + if requested_model in candidate_list: + logger.info( + "旧版 model_config 未命中任务配置," + f"按模型 `{requested_model}` 近似映射到任务 `{task_name}`" + ) + return task_name + + if normalized_preferred: + logger.warning(f"无法映射旧版 model_config,回退默认任务: preferred={normalized_preferred}") + return resolve_task_name("")