refactor(llm): enable hot-reload for model config and client runtime

make LLM task config resolution dynamic in LLMRequest
load model clients on demand from latest config
clear client instance cache on config reload
remove stale module-level model_config usage in llm_api
add hot-reload tests for LLM/config watcher flow
DrSmoothl
2026-03-04 21:56:50 +08:00
parent b3a81754e6
commit 2a33fd1121
5 changed files with 174 additions and 13 deletions
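
The shared idea across all five files: no module keeps a parsed model config object alive across reloads. Every consumer calls config_manager.get_model_config() at the moment of use, and the one place that does cache config-derived state (the client instance cache) registers a callback so the cache is dropped when the config changes. A minimal sketch of that contract, assuming a simple ConfigManager; only register_reload_callback and get_model_config appear in the diffs below, the watcher hook is hypothetical:

from typing import Any, Callable, List

class ConfigManager:
    def __init__(self) -> None:
        self._model_config: Any = None
        self._reload_callbacks: List[Callable[[], None]] = []

    def register_reload_callback(self, callback: Callable[[], None]) -> None:
        # Consumers that cache config-derived state subscribe here so they
        # can invalidate that state when the config file is reloaded.
        self._reload_callbacks.append(callback)

    def get_model_config(self) -> Any:
        # Always hands out the latest parsed config; callers must not keep it.
        return self._model_config

    def _on_file_changed(self, new_config: Any) -> None:
        # Hypothetical watcher hook: swap in the new config, then notify.
        self._model_config = new_config
        for callback in self._reload_callbacks:
            callback()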

View File

@@ -0,0 +1,76 @@
+from types import SimpleNamespace
+from importlib import util
+from pathlib import Path
+
+from src.config.config import config_manager
+from src.config.model_configs import TaskConfig
+from src.llm_models.utils_model import LLMRequest
+
+
+def _load_llm_api_module():
+    file_path = Path(__file__).parent.parent.parent / "src" / "plugin_system" / "apis" / "llm_api.py"
+    spec = util.spec_from_file_location("test_llm_api_module", file_path)
+    assert spec is not None and spec.loader is not None
+    module = util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
+
+
+def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"):
+    model_task_config = SimpleNamespace(**{attr_name: task_config})
+    return SimpleNamespace(model_task_config=model_task_config, models=[], api_providers=[])
+
+
+def test_llm_request_resolve_task_config_by_signature(monkeypatch):
+    old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
+    current_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
+    monkeypatch.setattr(config_manager, "get_model_config", lambda: _make_model_config(current_task, "utils"))
+
+    req = LLMRequest(model_set=old_task, request_type="test")
+
+    assert req._task_config_name == "utils"
+
+
+def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch):
+    old_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
+    initial_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
+    updated_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0)
+    current = {"task": initial_task}
+
+    def get_model_config_stub():
+        return _make_model_config(current["task"], "replyer")
+
+    monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
+
+    req = LLMRequest(model_set=old_task, request_type="test")
+    assert req._task_config_name == "replyer"
+
+    current["task"] = updated_task
+    req._refresh_task_config()
+
+    assert req.model_for_task.model_list == ["gpt-b", "gpt-c"]
+    assert list(req.model_usage.keys()) == ["gpt-b", "gpt-c"]
+
+
+def test_llm_api_get_available_models_reads_latest_config(monkeypatch):
+    llm_api = _load_llm_api_module()
+    first_utils = TaskConfig(model_list=["gpt-a"])
+    second_utils = TaskConfig(model_list=["gpt-z"])
+    state = {"task": first_utils}
+
+    def get_model_config_stub():
+        model_task_config = SimpleNamespace(utils=state["task"], planner=TaskConfig(model_list=["gpt-p"]))
+        return SimpleNamespace(model_task_config=model_task_config)
+
+    monkeypatch.setattr(config_manager, "get_model_config", get_model_config_stub)
+
+    first = llm_api.get_available_models()
+    assert first["utils"].model_list == ["gpt-a"]
+
+    state["task"] = second_utils
+    second = llm_api.get_available_models()
+
+    assert second["utils"].model_list == ["gpt-z"]
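
Note on the first test: old_task and current_task are constructed separately, so they are distinct objects with identical fields; the identity check (value is model_set) cannot match, which forces the signature fallback this commit adds. A runnable sketch of that distinction, assuming TaskConfig defaults its remaining fields (selection_strategy included), as the tests do:

from src.config.model_configs import TaskConfig

def signature(t: TaskConfig) -> tuple:
    # Mirrors LLMRequest._build_task_config_signature from the diff below.
    return (tuple(t.model_list), t.selection_strategy, t.temperature, t.max_tokens, t.slow_threshold)

a = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
b = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)

assert a is not b                     # identity matching would miss here
assert signature(a) == signature(b)   # field-wise signature matching succeeds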

View File

@@ -1,8 +1,28 @@
-from src.config.config import model_config
+from importlib import import_module
 
-used_client_types = {provider.client_type for provider in model_config.api_providers}
+from src.config.config import config_manager
 
-if "openai" in used_client_types:
-    from . import openai_client  # noqa: F401
-if "gemini" in used_client_types:
-    from . import gemini_client  # noqa: F401
+_CLIENT_MODULE_BY_TYPE: dict[str, str] = {
+    "openai": ".openai_client",
+    "gemini": ".gemini_client",
+}
+
+_LOADED_CLIENT_TYPES: set[str] = set()
+
+
+def ensure_client_type_loaded(client_type: str) -> None:
+    if client_type in _LOADED_CLIENT_TYPES:
+        return
+    module_name = _CLIENT_MODULE_BY_TYPE.get(client_type)
+    if not module_name:
+        return
+    import_module(module_name, package=__name__)
+    _LOADED_CLIENT_TYPES.add(client_type)
+
+
+def ensure_configured_clients_loaded() -> None:
+    for provider in config_manager.get_model_config().api_providers:
+        ensure_client_type_loaded(provider.client_type)
+
+
+ensure_configured_clients_loaded()
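
With this change, client modules are imported only when a configured provider actually needs them, instead of eagerly at package import from a config snapshot that could go stale. Supporting a new backend becomes a mapping entry plus a module; a hedged sketch (the "anthropic" type and .anthropic_client module are hypothetical, not part of this commit):

# Hypothetical extension: one table entry, imported on first use.
_CLIENT_MODULE_BY_TYPE["anthropic"] = ".anthropic_client"

ensure_client_type_loaded("anthropic")  # imports .anthropic_client once
ensure_client_type_loaded("anthropic")  # no-op, already in _LOADED_CLIENT_TYPES

Note that unknown client types are skipped silently (module_name is None), so a typo in client_type surfaces later as an unregistered client class rather than failing here.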

View File

@@ -3,11 +3,15 @@ from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from typing import Callable, Any, Optional
 
 from src.common.logger import get_logger
+from src.config.config import config_manager
 from src.config.model_configs import ModelInfo, APIProvider
 from ..payload_content.message import Message
 from ..payload_content.resp_format import RespFormat
 from ..payload_content.tool_option import ToolOption, ToolCall
 
+logger = get_logger("model_client_registry")
+
+
 @dataclass
 class UsageRecord:
@@ -144,6 +148,7 @@ class ClientRegistry:
"""APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str):
"""
@@ -169,6 +174,10 @@ class ClientRegistry:
         Returns:
             BaseClient: the registered API client instance
         """
+        from . import ensure_client_type_loaded
+
+        ensure_client_type_loaded(api_provider.client_type)
+
         # If a new instance is forced, create it directly without touching the cache
         if force_new:
             if client_class := self.client_registry.get(api_provider.client_type):
@@ -184,5 +193,9 @@ class ClientRegistry:
            raise KeyError(f"No client registered for type '{api_provider.client_type}'")
         return self.client_instance_cache[api_provider.name]
 
+    def clear_client_instance_cache(self) -> None:
+        self.client_instance_cache.clear()
+        logger.info("Config reload detected; cleared the LLM client instance cache")
+
 
 client_registry = ClientRegistry()
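
Together with the lazy loader above, the registry now survives a reload cleanly: the reload callback drops every cached BaseClient, and the next lookup rebuilds one from the latest provider settings, loading the client module first if the type is new to this process. A sketch of that flow; the accessor is called get_client here only for illustration, since the diff does not show the method name, and the provider object is a stand-in:

client = client_registry.get_client(provider)   # built and cached by provider name
# ... model config edited on disk, config_manager reloads it ...
# -> clear_client_instance_cache() fires, dropping the stale client
client = client_registry.get_client(provider)   # rebuilt from the latest config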

View File

@@ -16,6 +16,7 @@ from .payload_content.message import MessageBuilder, Message
 from .payload_content.resp_format import RespFormat, RespFormatType
 from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
 from .model_client.base_client import BaseClient, APIResponse, client_registry
+from .model_client import ensure_configured_clients_loaded
 from .utils import compress_messages, llm_usage_recorder
 from .exceptions import (
     NetworkConnectionError,
@@ -44,23 +45,62 @@ class LLMRequest:
         self.task_name = request_type
         self.model_for_task = model_set
         self.request_type = request_type
+        self._task_config_signature = self._build_task_config_signature(model_set)
         self._task_config_name = self._resolve_task_config_name(model_set)
         self.model_usage: Dict[str, Tuple[int, int, int]] = {
             model: (0, 0, 0) for model in self.model_for_task.model_list
         }
         """Model usage records for load balancing, stored as (total_tokens, penalty, usage_penalty); the penalty values let a model be deprioritized while it is underperforming or currently in use"""
 
+    @staticmethod
+    def _build_task_config_signature(model_set: TaskConfig) -> tuple:
+        return (
+            tuple(model_set.model_list),
+            model_set.selection_strategy,
+            model_set.temperature,
+            model_set.max_tokens,
+            model_set.slow_threshold,
+        )
+
+    @staticmethod
+    def _iter_task_config_items(model_task_config: Any) -> list[tuple[str, TaskConfig]]:
+        cls = type(model_task_config)
+        if hasattr(cls, "model_fields"):
+            attrs = [name for name in cls.model_fields.keys() if not name.startswith("__")]
+        else:
+            attrs = [name for name in dir(model_task_config) if not name.startswith("__")]
+        items: list[tuple[str, TaskConfig]] = []
+        for attr in attrs:
+            value = getattr(model_task_config, attr, None)
+            if isinstance(value, TaskConfig):
+                items.append((attr, value))
+        return items
+
+    def _resolve_task_config_by_signature(self, model_set: TaskConfig) -> Optional[str]:
+        target_signature = self._build_task_config_signature(model_set)
+        model_task_config = config_manager.get_model_config().model_task_config
+        return next(
+            (
+                attr
+                for attr, value in self._iter_task_config_items(model_task_config)
+                if self._build_task_config_signature(value) == target_signature
+            ),
+            None,
+        )
+
     def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
         try:
             model_task_config = config_manager.get_model_config().model_task_config
         except Exception:
             return None
-        for attr in dir(model_task_config):
-            if attr.startswith("__"):
-                continue
-            value = getattr(model_task_config, attr, None)
-            if isinstance(value, TaskConfig) and value is model_set:
+        for attr, value in self._iter_task_config_items(model_task_config):
+            if value is model_set:
                 return attr
+        try:
+            return self._resolve_task_config_by_signature(model_set)
+        except Exception:
+            return None
         return None
def _get_latest_task_config(self) -> TaskConfig:
@@ -72,12 +112,22 @@ class LLMRequest:
                 return value
         except Exception:
             return self.model_for_task
+        try:
+            if resolved_name := self._resolve_task_config_by_signature(self.model_for_task):
+                self._task_config_name = resolved_name
+                model_task_config = config_manager.get_model_config().model_task_config
+                value = getattr(model_task_config, resolved_name, None)
+                if isinstance(value, TaskConfig):
+                    return value
+        except Exception:
+            return self.model_for_task
         return self.model_for_task
 
     def _refresh_task_config(self) -> TaskConfig:
         latest = self._get_latest_task_config()
         if latest is not self.model_for_task:
             self.model_for_task = latest
+            self._task_config_signature = self._build_task_config_signature(latest)
         if list(self.model_usage.keys()) != latest.model_list:
             self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list}
         return self.model_for_task
@@ -417,6 +467,8 @@ class LLMRequest:
         if not available_models:
             raise RuntimeError("No usable model to select: every model has been tried and failed.")
 
+        ensure_configured_clients_loaded()
+
         strategy = self.model_for_task.selection_strategy.lower()
         if strategy == "random":

View File

@@ -13,7 +13,7 @@ from src.llm_models.payload_content.tool_option import ToolCall
 from src.llm_models.payload_content.message import Message
 from src.llm_models.model_client.base_client import BaseClient
 from src.llm_models.utils_model import LLMRequest
-from src.config.config import model_config
+from src.config.config import config_manager
 from src.config.model_configs import TaskConfig
 
 logger = get_logger("llm_api")
@@ -31,7 +31,7 @@ def get_available_models() -> Dict[str, TaskConfig]:
"""
try:
# 自动获取所有属性并转换为字典形式
models = model_config.model_task_config
models = config_manager.get_model_config().model_task_config
attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs:
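
Since get_available_models now dereferences config_manager.get_model_config() on every call instead of a module-level model_config snapshot, plugins observe config edits as soon as the watcher reloads the file. A usage sketch (the plugin call site is hypothetical; the import path matches the file loaded by the tests above):

from src.plugin_system.apis import llm_api

models = llm_api.get_available_models()
utils_task = models.get("utils")
if utils_task is not None:
    print(utils_task.model_list)  # reflects the most recent reload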