添加文件监视器地基模块,重构模型请求模块使用新版本的配置热重载模块,新增watchfiles依赖
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
|
||||
import traceback
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
|
||||
from .payload_content.message import MessageBuilder, Message
|
||||
from .payload_content.resp_format import RespFormat
|
||||
@@ -43,11 +43,44 @@ class LLMRequest:
|
||||
self.task_name = request_type
|
||||
self.model_for_task = model_set
|
||||
self.request_type = request_type
|
||||
self._task_config_name = self._resolve_task_config_name(model_set)
|
||||
self.model_usage: Dict[str, Tuple[int, int, int]] = {
|
||||
model: (0, 0, 0) for model in self.model_for_task.model_list
|
||||
}
|
||||
"""模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty, usage_penalty),惩罚值是为了能在某个模型请求不给力或正在被使用的时候进行调整"""
|
||||
|
||||
def _resolve_task_config_name(self, model_set: TaskConfig) -> Optional[str]:
    """Find the attribute name under which ``model_set`` lives in the current task config.

    The lookup walks every non-dunder attribute of
    ``config_manager.get_model_config().model_task_config`` and returns the
    first one whose value is the very same ``TaskConfig`` object as
    ``model_set`` (identity comparison, not equality).

    Args:
        model_set: The task configuration this request was created with.

    Returns:
        The matching attribute name, or ``None`` when the model config cannot
        be obtained or no attribute holds ``model_set``.
    """
    try:
        task_configs = config_manager.get_model_config().model_task_config
    except Exception:
        # Best effort: without a readable config there is no name to resolve.
        return None

    for name in dir(task_configs):
        if name.startswith("__"):
            continue
        candidate = getattr(task_configs, name, None)
        # Identity decides the match; the type check keeps non-TaskConfig values out.
        if candidate is not model_set:
            continue
        if isinstance(candidate, TaskConfig):
            return name
    return None
|
||||
|
||||
def _get_latest_task_config(self) -> TaskConfig:
|
||||
if self._task_config_name:
|
||||
try:
|
||||
model_task_config = config_manager.get_model_config().model_task_config
|
||||
value = getattr(model_task_config, self._task_config_name, None)
|
||||
if isinstance(value, TaskConfig):
|
||||
return value
|
||||
except Exception:
|
||||
return self.model_for_task
|
||||
return self.model_for_task
|
||||
|
||||
def _refresh_task_config(self) -> TaskConfig:
|
||||
latest = self._get_latest_task_config()
|
||||
if latest is not self.model_for_task:
|
||||
self.model_for_task = latest
|
||||
if list(self.model_usage.keys()) != latest.model_list:
|
||||
self.model_usage = {model: self.model_usage.get(model, (0, 0, 0)) for model in latest.model_list}
|
||||
return self.model_for_task
|
||||
|
||||
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
|
||||
"""检查请求是否过慢并输出警告日志
|
||||
|
||||
@@ -80,6 +113,7 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
|
||||
def message_factory(client: BaseClient) -> List[Message]:
|
||||
@@ -123,6 +157,7 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Optional[str]): 生成的文本描述或None
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
response, _ = await self._execute_request(
|
||||
request_type=RequestType.AUDIO,
|
||||
audio_base64=voice_base64,
|
||||
@@ -148,6 +183,7 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
|
||||
def message_factory(client: BaseClient) -> List[Message]:
|
||||
@@ -204,6 +240,7 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
@@ -246,6 +283,7 @@ class LLMRequest:
|
||||
Returns:
|
||||
(Tuple[List[float], str]): (嵌入向量,使用的模型名称)
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
start_time = time.time()
|
||||
response, model_info = await self._execute_request(
|
||||
request_type=RequestType.EMBEDDING,
|
||||
@@ -269,6 +307,7 @@ class LLMRequest:
|
||||
"""
|
||||
根据配置的策略选择模型:balance(负载均衡)或 random(随机选择)
|
||||
"""
|
||||
self._refresh_task_config()
|
||||
available_models = {
|
||||
model: scores
|
||||
for model, scores in self.model_usage.items()
|
||||
@@ -314,8 +353,8 @@ class LLMRequest:
|
||||
message_list: List[Message],
|
||||
tool_options: list[ToolOption] | None,
|
||||
response_format: RespFormat | None,
|
||||
stream_response_handler: Optional[Callable],
|
||||
async_response_parser: Optional[Callable],
|
||||
stream_response_handler: Optional[Callable[..., Any]],
|
||||
async_response_parser: Optional[Callable[..., Any]],
|
||||
temperature: Optional[float],
|
||||
max_tokens: Optional[int],
|
||||
embedding_input: str | None,
|
||||
@@ -466,8 +505,8 @@ class LLMRequest:
|
||||
message_factory: Optional[Callable[[BaseClient], List[Message]]] = None,
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[Callable] = None,
|
||||
async_response_parser: Optional[Callable] = None,
|
||||
stream_response_handler: Optional[Callable[..., Any]] = None,
|
||||
async_response_parser: Optional[Callable[..., Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
embedding_input: str | None = None,
|
||||
@@ -595,7 +634,7 @@ class TempMethodsLLMUtils:
|
||||
Raises:
|
||||
ValueError: 未找到指定模型
|
||||
"""
|
||||
for model in model_config.models:
|
||||
for model in config_manager.get_model_config().models:
|
||||
if model.name == model_name:
|
||||
return model
|
||||
raise ValueError(f"未找到名为 '{model_name}' 的模型")
|
||||
@@ -614,7 +653,7 @@ class TempMethodsLLMUtils:
|
||||
Raises:
|
||||
ValueError: 未找到指定提供商
|
||||
"""
|
||||
for provider in model_config.api_providers:
|
||||
for provider in config_manager.get_model_config().api_providers:
|
||||
if provider.name == provider_name:
|
||||
return provider
|
||||
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
||||
|
||||
Reference in New Issue
Block a user