"""Hot-reload tests for ``LLMRequest`` task-config tracking and ``llm_api.get_available_models``."""

from types import SimpleNamespace
from importlib import util
from pathlib import Path

from src.config.config import config_manager
from src.config.model_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest


def _load_llm_api_module():
    """Execute ``src/plugin_system/apis/llm_api.py`` from its file path and return the module.

    The module is loaded under the name ``test_llm_api_module`` via an
    explicit spec rather than through a regular package import.
    """
    repo_root = Path(__file__).parent.parent.parent
    source_path = repo_root / "src" / "plugin_system" / "apis" / "llm_api.py"
    module_spec = util.spec_from_file_location("test_llm_api_module", source_path)
    assert module_spec is not None and module_spec.loader is not None
    loaded = util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded


def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"):
    """Build a stand-in model config exposing ``task_config`` under ``attr_name``.

    The result carries ``model_task_config`` plus empty ``models`` and
    ``api_providers`` lists.
    """
    tasks = SimpleNamespace()
    setattr(tasks, attr_name, task_config)
    return SimpleNamespace(model_task_config=tasks, models=[], api_providers=[])


def test_llm_request_resolve_task_config_by_signature(monkeypatch):
    """A request built from a distinct TaskConfig with identical settings resolves to ``utils``."""

    def build_task() -> TaskConfig:
        # A fresh object per call: the two configs match in value, not in identity.
        return TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)

    old_task = build_task()
    current_task = build_task()

    def fake_get_model_config():
        return _make_model_config(current_task, "utils")

    monkeypatch.setattr(config_manager, "get_model_config", fake_get_model_config)

    request = LLMRequest(model_set=old_task, request_type="test")

    assert request._task_config_name == "utils"


def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch):
    """After the published ``replyer`` config changes, a refresh updates model list and usage keys."""
    old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
    initial_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
    updated_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0)

    # Mutable holder so the stub can start returning a different config mid-test.
    live = SimpleNamespace(task=initial_task)

    def get_model_config_stub():
        return _make_model_config(live.task, "replyer")

    monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)

    request = LLMRequest(model_set=old_task, request_type="test")
    assert request._task_config_name == "replyer"

    live.task = updated_task
    request._refresh_task_config()

    expected_models = ["gpt-b", "gpt-c"]
    assert request.model_for_task.model_list == expected_models
    assert list(request.model_usage.keys()) == expected_models


def test_llm_api_get_available_models_reads_latest_config(monkeypatch):
    """``get_available_models`` reflects whatever ``config_manager`` returns at call time."""
    llm_api = _load_llm_api_module()

    first_utils = TaskConfig(model_list=["gpt-a"])
    second_utils = TaskConfig(model_list=["gpt-z"])

    # Mutable holder so the stub can switch the ``utils`` task between calls.
    live = SimpleNamespace(task=first_utils)

    def get_model_config_stub():
        tasks = SimpleNamespace(utils=live.task, planner=TaskConfig(model_list=["gpt-p"]))
        return SimpleNamespace(model_task_config=tasks)

    monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)

    before = llm_api.get_available_models()
    assert before["utils"].model_list == ["gpt-a"]

    live.task = second_utils
    after = llm_api.get_available_models()
    assert after["utils"].model_list == ["gpt-z"]
# Provider ``client_type`` -> relative module (inside this package) to import for it.
_CLIENT_MODULE_BY_TYPE: dict[str, str] = {
    "openai": ".openai_client",
    "gemini": ".gemini_client",
}

# Client types whose module has already been imported through this loader.
_LOADED_CLIENT_TYPES: set[str] = set()


def ensure_client_type_loaded(client_type: str) -> None:
    """Import the client module mapped to ``client_type`` if not done yet.

    Types already recorded as loaded, and types with no entry in
    ``_CLIENT_MODULE_BY_TYPE``, are ignored. A type is recorded as loaded
    only after its import returns, so a failed import is retried next call.

    Args:
        client_type: The ``client_type`` value of an API provider.
    """
    if client_type in _LOADED_CLIENT_TYPES:
        relative_module = None
    else:
        relative_module = _CLIENT_MODULE_BY_TYPE.get(client_type)

    if relative_module:
        # NOTE(review): importing presumably registers the client class with
        # the client registry as a side effect -- confirm in the client modules.
        import_module(relative_module, package=__name__)
        _LOADED_CLIENT_TYPES.add(client_type)


def ensure_configured_clients_loaded() -> None:
    """Load the client module for every provider in the current model config.

    The config is fetched from ``config_manager`` on each call, so providers
    present in a newer config are picked up.
    """
    current_providers = config_manager.get_model_config().api_providers
    for provider in current_providers:
        ensure_client_type_loaded(provider.client_type)


# Load the clients for the providers configured at import time.
ensure_configured_clients_loaded()
# NOTE(review): reconstructed from a whitespace-collapsed diff. In their source
# files these definitions are methods -- ``clear_client_instance_cache`` on
# ``ClientRegistry``, the others on ``LLMRequest`` -- and appear dedented here.


def clear_client_instance_cache(self) -> None:
    """Drop every cached client instance and log that a config reload caused it.

    ``ClientRegistry.__init__`` registers this method with
    ``config_manager.register_reload_callback``.
    """
    self.client_instance_cache.clear()
    logger.info("检测到配置重载,已清空LLM客户端实例缓存")


@staticmethod
def _build_task_config_signature(model_set: TaskConfig) -> tuple:
    """Return a hashable tuple of the settings used to compare task configs by value.

    Args:
        model_set: The task config to summarise.

    Returns:
        ``(model_list as tuple, selection_strategy, temperature, max_tokens,
        slow_threshold)``.
    """
    return (
        tuple(model_set.model_list),
        model_set.selection_strategy,
        model_set.temperature,
        model_set.max_tokens,
        model_set.slow_threshold,
    )


@staticmethod
def _iter_task_config_items(model_task_config: Any) -> list[tuple[str, TaskConfig]]:
    """Collect ``(attribute name, TaskConfig)`` pairs found on ``model_task_config``.

    Args:
        model_task_config: Object whose attributes hold the per-task configs.

    Returns:
        One pair per attribute whose value is a ``TaskConfig``; names that
        start with a double underscore are ignored.
    """
    cls = type(model_task_config)
    # Use the class's declared ``model_fields`` when it has them (presumably a
    # pydantic model -- confirm); otherwise fall back to ``dir()``, which also
    # covers plain namespace objects.
    candidates = cls.model_fields if hasattr(cls, "model_fields") else dir(model_task_config)

    items: list[tuple[str, TaskConfig]] = []
    for name in candidates:
        if name.startswith("__"):
            continue
        value = getattr(model_task_config, name, None)
        if isinstance(value, TaskConfig):
            items.append((name, value))
    return items


def _resolve_task_config_by_signature(self, model_set: TaskConfig) -> Optional[str]:
    """Find the task whose current config has the same signature as ``model_set``.

    Args:
        model_set: The task config to look up by value.

    Returns:
        The name of the first matching task attribute, or ``None`` if no
        current task config has the same signature.
    """
    target_signature = self._build_task_config_signature(model_set)
    model_task_config = config_manager.get_model_config().model_task_config
    return next(
        (
            attr
            for attr, value in self._iter_task_config_items(model_task_config)
            if self._build_task_config_signature(value) == target_signature
        ),
        None,
    )


def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
    """Work out which task attribute of the current model config ``model_set`` is.

    An identity match is tried first. If none is found, the first task whose
    signature equals that of ``model_set`` is used, which lets a copy of a
    config still be associated with its task.

    Args:
        model_set: The task config the request was created with.

    Returns:
        The task attribute name, or ``None`` when the config cannot be read
        or nothing matches.
    """
    try:
        model_task_config = config_manager.get_model_config().model_task_config
    except Exception:
        # Best effort: without a readable config the request is simply not tracked.
        return None

    for attr, value in self._iter_task_config_items(model_task_config):
        if value is model_set:
            return attr

    try:
        return self._resolve_task_config_by_signature(model_set)
    except Exception:
        return None
def _refresh_task_config(self) -> TaskConfig:
    """Adopt the latest task config and return the config now in effect.

    When ``_get_latest_task_config`` returns a different object, it replaces
    ``model_for_task`` and ``_task_config_signature`` is recomputed. The
    usage table is rebuilt for the new model list, keeping the counters of
    models that are still present and starting new ones at ``(0, 0, 0)``.

    Returns:
        ``self.model_for_task`` after any update.
    """
    newest = self._get_latest_task_config()
    if newest is self.model_for_task:
        return self.model_for_task

    self.model_for_task = newest
    self._task_config_signature = self._build_task_config_signature(newest)

    # NOTE(review): this resync is nested under the identity check above; the
    # nesting was inferred from a diff that lost its indentation -- confirm.
    if list(self.model_usage.keys()) != newest.model_list:
        previous_usage = self.model_usage
        self.model_usage = {name: previous_usage.get(name, (0, 0, 0)) for name in newest.model_list}
    return self.model_for_task