添加文件监视器地基模块,重构模型请求模块使用新版本的配置热重载模块,新增watchfiles依赖

This commit is contained in:
DrSmoothl
2026-02-14 21:17:24 +08:00
parent daad0ba2f0
commit dc36542403
7 changed files with 210 additions and 22 deletions

View File

@@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
import traceback
from src.common.logger import get_logger
from src.config.config import model_config
from src.config.config import config_manager
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat
@@ -43,11 +43,44 @@ class LLMRequest:
self.task_name = request_type
self.model_for_task = model_set
self.request_type = request_type
self._task_config_name = self._resolve_task_config_name(model_set)
self.model_usage: Dict[str, Tuple[int, int, int]] = {
model: (0, 0, 0) for model in self.model_for_task.model_list
}
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
    """Resolve the attribute name under model_task_config that holds *model_set*.

    The name (not the object) is cached so that after a config hot-reload the
    task's TaskConfig can be looked up again by attribute name.

    Args:
        model_set: The TaskConfig instance this request was constructed with.

    Returns:
        The matching attribute name, or None when the config cannot be read
        or no attribute is the identical object.
    """
    try:
        task_configs = config_manager.get_model_config().model_task_config
    except Exception:
        # Config may be unavailable (e.g. during startup); degrade gracefully.
        return None
    candidates = (
        name
        for name in dir(task_configs)
        if not name.startswith("__")
        and isinstance(getattr(task_configs, name, None), TaskConfig)
        # Identity (not equality) check: we want the exact same object.
        and getattr(task_configs, name, None) is model_set
    )
    return next(candidates, None)
def _get_latest_task_config(self) -> TaskConfig:
    """Fetch the freshest TaskConfig for this task from the live config.

    Falls back to the cached ``self.model_for_task`` when the task-config
    attribute name was never resolved, the lookup raises, or the attribute
    is no longer a TaskConfig after a reload.
    """
    if not self._task_config_name:
        return self.model_for_task
    try:
        candidate = getattr(
            config_manager.get_model_config().model_task_config,
            self._task_config_name,
            None,
        )
    except Exception:
        # Hot-reload lookup failed; keep serving the last known-good config.
        return self.model_for_task
    return candidate if isinstance(candidate, TaskConfig) else self.model_for_task
def _refresh_task_config(self) -> TaskConfig:
    """Synchronize this request with a possibly hot-reloaded TaskConfig.

    Swaps ``self.model_for_task`` to the latest object and, when the model
    list changed, rebuilds ``self.model_usage`` so surviving models keep
    their (total_tokens, penalty, usage_penalty) counters while newly added
    models start at (0, 0, 0).

    Returns:
        The TaskConfig now in effect for this request.
    """
    latest = self._get_latest_task_config()
    if latest is self.model_for_task:
        # Identity unchanged means no reload happened; nothing to migrate.
        return self.model_for_task
    self.model_for_task = latest
    if list(self.model_usage.keys()) != latest.model_list:
        migrated = {}
        for model_name in latest.model_list:
            # Preserve existing usage stats across the reload where possible.
            migrated[model_name] = self.model_usage.get(model_name, (0, 0, 0))
        self.model_usage = migrated
    return self.model_for_task
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
"""检查请求是否过慢并输出警告日志
@@ -80,6 +113,7 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
self._refresh_task_config()
start_time = time.time()
def message_factory(client: BaseClient) -> List[Message]:
@@ -123,6 +157,7 @@ class LLMRequest:
Returns:
(Optional[str]): 生成的文本描述或None
"""
self._refresh_task_config()
response, _ = await self._execute_request(
request_type=RequestType.AUDIO,
audio_base64=voice_base64,
@@ -148,6 +183,7 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
self._refresh_task_config()
start_time = time.time()
def message_factory(client: BaseClient) -> List[Message]:
@@ -204,6 +240,7 @@ class LLMRequest:
Returns:
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
self._refresh_task_config()
start_time = time.time()
tool_built = self._build_tool_options(tools)
@@ -246,6 +283,7 @@ class LLMRequest:
Returns:
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
"""
self._refresh_task_config()
start_time = time.time()
response, model_info = await self._execute_request(
request_type=RequestType.EMBEDDING,
@@ -269,6 +307,7 @@ class LLMRequest:
"""
根据配置的策略选择模型balance负载均衡或 random随机选择
"""
self._refresh_task_config()
available_models = {
model: scores
for model, scores in self.model_usage.items()
@@ -314,8 +353,8 @@ class LLMRequest:
message_list: List[Message],
tool_options: list[ToolOption] | None,
response_format: RespFormat | None,
stream_response_handler: Optional[Callable],
async_response_parser: Optional[Callable],
stream_response_handler: Optional[Callable[..., Any]],
async_response_parser: Optional[Callable[..., Any]],
temperature: Optional[float],
max_tokens: Optional[int],
embedding_input: str | None,
@@ -466,8 +505,8 @@ class LLMRequest:
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
tool_options: list[ToolOption] | None = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[Callable] = None,
async_response_parser: Optional[Callable] = None,
stream_response_handler: Optional[Callable[..., Any]] = None,
async_response_parser: Optional[Callable[..., Any]] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
embedding_input: str | None = None,
@@ -595,7 +634,7 @@ class TempMethodsLLMUtils:
Raises:
ValueError: 未找到指定模型
"""
for model in model_config.models:
for model in config_manager.get_model_config().models:
if model.name == model_name:
return model
raise ValueError(f"未找到名为 '{model_name}' 的模型")
@@ -614,7 +653,7 @@ class TempMethodsLLMUtils:
Raises:
ValueError: 未找到指定提供商
"""
for provider in model_config.api_providers:
for provider in config_manager.get_model_config().api_providers:
if provider.name == provider_name:
return provider
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")