refactor(llm): enable hot-reload for model config and client runtime
- make LLM task config resolution dynamic in LLMRequest
- load model clients on demand from latest config
- clear client instance cache on config reload
- remove stale module-level model_config usage in llm_api
- add hot-reload tests for LLM/config watcher flow
This commit is contained in:
76
pytests/config_test/test_llm_request_hot_reload.py
Normal file
76
pytests/config_test/test_llm_request_hot_reload.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from types import SimpleNamespace
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
def _load_llm_api_module():
    """Import ``llm_api.py`` directly from its file path, bypassing package imports.

    Loading by path gives each call a fresh module object, so module-level
    state cannot leak between tests that exercise hot-reload behaviour.
    """
    api_path = Path(__file__).parent.parent.parent / "src" / "plugin_system" / "apis" / "llm_api.py"
    module_spec = util.spec_from_file_location("test_llm_api_module", api_path)
    assert module_spec is not None
    assert module_spec.loader is not None
    loaded = util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
|
||||
|
||||
|
||||
def _make_model_config(task_config: TaskConfig, attr_name: str = "utils"):
    """Build a minimal stand-in for the global model config object.

    The returned namespace exposes only what the code under test reads:
    ``model_task_config.<attr_name>``, ``models`` and ``api_providers``.
    """
    task_holder = SimpleNamespace(**{attr_name: task_config})
    return SimpleNamespace(model_task_config=task_holder, models=[], api_providers=[])
|
||||
|
||||
|
||||
def test_llm_request_resolve_task_config_by_signature(monkeypatch):
    """A TaskConfig equal by signature (but not identity) still resolves to its name."""
    shared_kwargs = dict(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
    stale_task = TaskConfig(**shared_kwargs)
    live_task = TaskConfig(**shared_kwargs)

    def fake_get_model_config():
        return _make_model_config(live_task, "utils")

    monkeypatch.setattr(config_manager, "get_model_config", fake_get_model_config)

    request = LLMRequest(model_set=stale_task, request_type="test")

    # Identity matching must fail (different objects), so resolution has to
    # fall back to signature comparison and still find the "utils" slot.
    assert request._task_config_name == "utils"
|
||||
|
||||
|
||||
def test_llm_request_refresh_task_config_updates_runtime_state(monkeypatch):
    """_refresh_task_config adopts a swapped-in TaskConfig and rebuilds usage keys."""
    stale_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
    first_task = TaskConfig(model_list=["gpt-a"], max_tokens=512, temperature=0.3, slow_threshold=15.0)
    swapped_task = TaskConfig(model_list=["gpt-b", "gpt-c"], max_tokens=1024, temperature=0.5, slow_threshold=20.0)

    # Mutable holder so the stub can serve a different config after "reload".
    active = {"task": first_task}

    monkeypatch.setattr(
        config_manager,
        "get_model_config",
        lambda: _make_model_config(active["task"], "replyer"),
    )

    request = LLMRequest(model_set=stale_task, request_type="test")
    assert request._task_config_name == "replyer"

    # Simulate a hot reload replacing the task config, then refresh.
    active["task"] = swapped_task
    request._refresh_task_config()

    assert request.model_for_task.model_list == ["gpt-b", "gpt-c"]
    assert list(request.model_usage.keys()) == ["gpt-b", "gpt-c"]
|
||||
|
||||
|
||||
def test_llm_api_get_available_models_reads_latest_config(monkeypatch):
    """get_available_models must re-read the config on every call, never cache it."""
    llm_api = _load_llm_api_module()

    state = {"task": TaskConfig(model_list=["gpt-a"])}
    replacement = TaskConfig(model_list=["gpt-z"])

    def fake_get_model_config():
        tasks = SimpleNamespace(utils=state["task"], planner=TaskConfig(model_list=["gpt-p"]))
        return SimpleNamespace(model_task_config=tasks)

    monkeypatch.setattr(config_manager, "get_model_config", fake_get_model_config)

    before = llm_api.get_available_models()
    assert before["utils"].model_list == ["gpt-a"]

    # Swap the underlying task config; a second call must observe the change.
    state["task"] = replacement
    after = llm_api.get_available_models()
    assert after["utils"].model_list == ["gpt-z"]
|
||||
@@ -1,8 +1,28 @@
|
||||
"""Model client package: client modules are imported lazily from the live config."""

from importlib import import_module

from src.config.config import config_manager

# Maps APIProvider.client_type to the relative module that registers that client.
# Modules are imported on demand (see ensure_client_type_loaded) so that a config
# reload can bring in client types that were not configured at process start.
_CLIENT_MODULE_BY_TYPE: dict[str, str] = {
    "openai": ".openai_client",
    "gemini": ".gemini_client",
}

# Client types whose modules have already been imported in this process.
_LOADED_CLIENT_TYPES: set[str] = set()
|
||||
|
||||
|
||||
def ensure_client_type_loaded(client_type: str) -> None:
    """Import the client module backing ``client_type`` exactly once.

    Already-loaded types are a no-op; unknown types are silently ignored
    (there is simply no module to load for them).
    """
    if client_type not in _LOADED_CLIENT_TYPES:
        relative_module = _CLIENT_MODULE_BY_TYPE.get(client_type)
        if relative_module:
            import_module(relative_module, package=__name__)
            _LOADED_CLIENT_TYPES.add(client_type)
|
||||
|
||||
|
||||
def ensure_configured_clients_loaded() -> None:
    """Load the client module for every provider in the latest model config."""
    providers = config_manager.get_model_config().api_providers
    for provider in providers:
        ensure_client_type_loaded(provider.client_type)


# Eagerly load clients for the configuration active at import time; anything a
# later config reload introduces goes through ensure_client_type_loaded on demand.
ensure_configured_clients_loaded()
|
||||
|
||||
@@ -3,11 +3,15 @@ from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||
|
||||
logger = get_logger("model_client_registry")
|
||||
|
||||
|
||||
@dataclass
|
||||
class UsageRecord:
|
||||
@@ -144,6 +148,7 @@ class ClientRegistry:
|
||||
"""APIProvider.type -> BaseClient的映射表"""
|
||||
self.client_instance_cache: dict[str, BaseClient] = {}
|
||||
"""APIProvider.name -> BaseClient的映射表"""
|
||||
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
@@ -169,6 +174,10 @@ class ClientRegistry:
|
||||
Returns:
|
||||
BaseClient: 注册的API客户端实例
|
||||
"""
|
||||
from . import ensure_client_type_loaded
|
||||
|
||||
ensure_client_type_loaded(api_provider.client_type)
|
||||
|
||||
# 如果强制创建新实例,直接创建不使用缓存
|
||||
if force_new:
|
||||
if client_class := self.client_registry.get(api_provider.client_type):
|
||||
@@ -184,5 +193,9 @@ class ClientRegistry:
|
||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
def clear_client_instance_cache(self) -> None:
    """Drop every cached client instance so clients are rebuilt from the reloaded config."""
    cached = self.client_instance_cache
    cached.clear()
    logger.info("检测到配置重载,已清空LLM客户端实例缓存")
|
||||
|
||||
|
||||
# Module-level shared ClientRegistry instance used by the rest of the package.
client_registry = ClientRegistry()
|
||||
|
||||
@@ -16,6 +16,7 @@ from .payload_content.message import MessageBuilder, Message
|
||||
from .payload_content.resp_format import RespFormat, RespFormatType
|
||||
from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
|
||||
from .model_client.base_client import BaseClient, APIResponse, client_registry
|
||||
from .model_client import ensure_configured_clients_loaded
|
||||
from .utils import compress_messages, llm_usage_recorder
|
||||
from .exceptions import (
|
||||
NetworkConnectionError,
|
||||
@@ -44,23 +45,62 @@ class LLMRequest:
|
||||
self.task_name = request_type
|
||||
self.model_for_task = model_set
|
||||
self.request_type = request_type
|
||||
self._task_config_signature = self._build_task_config_signature(model_set)
|
||||
self._task_config_name = self._resolve_task_config_name(model_set)
|
||||
self.model_usage: Dict[str, Tuple[int, int, int]] = {
|
||||
model: (0, 0, 0) for model in self.model_for_task.model_list
|
||||
}
|
||||
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
|
||||
|
||||
@staticmethod
def _build_task_config_signature(model_set: TaskConfig) -> tuple:
    """Return a hashable fingerprint of the fields identifying a task config.

    Two TaskConfig objects with equal signatures are treated as the same
    logical task when re-resolving the config after a hot reload.
    """
    scalar_fields = (
        model_set.selection_strategy,
        model_set.temperature,
        model_set.max_tokens,
        model_set.slow_threshold,
    )
    return (tuple(model_set.model_list),) + scalar_fields
|
||||
|
||||
@staticmethod
def _iter_task_config_items(model_task_config: Any) -> list[tuple[str, TaskConfig]]:
    """List every (attribute_name, TaskConfig) pair on the task-config container.

    Prefers the class's declared ``model_fields`` when present — presumably a
    pydantic model, so only real config fields are inspected — and otherwise
    falls back to ``dir()`` over the instance.
    """
    container_cls = type(model_task_config)
    if hasattr(container_cls, "model_fields"):
        candidates = [n for n in container_cls.model_fields.keys() if not n.startswith("__")]
    else:
        candidates = [n for n in dir(model_task_config) if not n.startswith("__")]

    return [
        (name, value)
        for name in candidates
        if isinstance(value := getattr(model_task_config, name, None), TaskConfig)
    ]
|
||||
|
||||
def _resolve_task_config_by_signature(self, model_set: TaskConfig) -> Optional[str]:
    """Find the config attribute whose TaskConfig matches ``model_set`` field-for-field.

    Used when identity matching fails (e.g. a hot reload replaced the
    TaskConfig objects). Returns None when nothing matches.
    """
    wanted = self._build_task_config_signature(model_set)
    container = config_manager.get_model_config().model_task_config
    for name, candidate in self._iter_task_config_items(container):
        if self._build_task_config_signature(candidate) == wanted:
            return name
    return None
|
||||
|
||||
def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
    """Resolve which model_task_config attribute ``model_set`` corresponds to.

    Tries identity first (cheap and exact), then falls back to signature
    matching so a TaskConfig recreated by a config reload still resolves.
    Returns None when the config cannot be read or nothing matches.

    NOTE(review): the pre-refactor dir()-based identity loop was left
    interleaved with this implementation (diff residue, including a dangling
    trailing ``return None``); this is the coherent post-refactor version.
    """
    try:
        model_task_config = config_manager.get_model_config().model_task_config
    except Exception:
        # Config system unavailable (e.g. in isolated tests) — name unknown.
        return None
    for attr, value in self._iter_task_config_items(model_task_config):
        if value is model_set:
            return attr
    try:
        return self._resolve_task_config_by_signature(model_set)
    except Exception:
        return None
|
||||
|
||||
def _get_latest_task_config(self) -> TaskConfig:
|
||||
@@ -72,12 +112,22 @@ class LLMRequest:
|
||||
return value
|
||||
except Exception:
|
||||
return self.model_for_task
|
||||
try:
|
||||
if resolved_name := self._resolve_task_config_by_signature(self.model_for_task):
|
||||
self._task_config_name = resolved_name
|
||||
model_task_config = config_manager.get_model_config().model_task_config
|
||||
value = getattr(model_task_config, resolved_name, None)
|
||||
if isinstance(value, TaskConfig):
|
||||
return value
|
||||
except Exception:
|
||||
return self.model_for_task
|
||||
return self.model_for_task
|
||||
|
||||
def _refresh_task_config(self) -> TaskConfig:
    """Sync this request with the latest TaskConfig and rebuild usage tracking.

    Returns the TaskConfig now in effect for this request.
    """
    current = self._get_latest_task_config()
    if current is not self.model_for_task:
        self.model_for_task = current
        self._task_config_signature = self._build_task_config_signature(current)
    if list(self.model_usage.keys()) != current.model_list:
        # Keep accumulated counters for models that survive the reload; new
        # models start at (total_tokens, penalty, usage_penalty) = (0, 0, 0).
        refreshed_usage = {}
        for name in current.model_list:
            refreshed_usage[name] = self.model_usage.get(name, (0, 0, 0))
        self.model_usage = refreshed_usage
    return self.model_for_task
|
||||
@@ -417,6 +467,8 @@ class LLMRequest:
|
||||
if not available_models:
|
||||
raise RuntimeError("没有可用的模型可供选择。所有模型均已尝试失败。")
|
||||
|
||||
ensure_configured_clients_loaded()
|
||||
|
||||
strategy = self.model_for_task.selection_strategy.lower()
|
||||
|
||||
if strategy == "random":
|
||||
|
||||
@@ -13,7 +13,7 @@ from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.model_client.base_client import BaseClient
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import TaskConfig
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
@@ -31,7 +31,7 @@ def get_available_models() -> Dict[str, TaskConfig]:
|
||||
"""
|
||||
try:
|
||||
# 自动获取所有属性并转换为字典形式
|
||||
models = model_config.model_task_config
|
||||
models = config_manager.get_model_config().model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
|
||||
Reference in New Issue
Block a user