440 lines
12 KiB
Python
440 lines
12 KiB
Python
from enum import Enum
|
||
from typing import Any
|
||
|
||
from src.common.i18n import t
|
||
from .config_base import ConfigBase, Field
|
||
|
||
|
||
class OpenAICompatibleAuthType(str, Enum):
|
||
"""OpenAI 兼容接口的鉴权方式。"""
|
||
|
||
BEARER = "bearer"
|
||
HEADER = "header"
|
||
QUERY = "query"
|
||
NONE = "none"
|
||
|
||
|
||
class ReasoningParseMode(str, Enum):
|
||
"""推理内容解析策略。"""
|
||
|
||
AUTO = "auto"
|
||
NATIVE = "native"
|
||
THINK_TAG = "think_tag"
|
||
NONE = "none"
|
||
|
||
|
||
class ToolArgumentParseMode(str, Enum):
|
||
"""工具调用参数的解析策略。"""
|
||
|
||
AUTO = "auto"
|
||
STRICT = "strict"
|
||
REPAIR = "repair"
|
||
DOUBLE_DECODE = "double_decode"
|
||
|
||
|
||
class APIProvider(ConfigBase):
|
||
"""API提供商配置类"""
|
||
|
||
name: str = Field(
|
||
default="",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "tag",
|
||
},
|
||
)
|
||
"""API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)"""
|
||
|
||
base_url: str = Field(
|
||
default="",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "link",
|
||
},
|
||
)
|
||
"""API服务商的BaseURL"""
|
||
|
||
api_key: str = Field(
|
||
default_factory=str,
|
||
repr=False,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "key",
|
||
},
|
||
)
|
||
"""API密钥。对于不需要鉴权的兼容端点,可将 `auth_type` 设为 `none`。"""
|
||
|
||
client_type: str = Field(
|
||
default="openai",
|
||
json_schema_extra={
|
||
"x-widget": "select",
|
||
"x-icon": "settings",
|
||
},
|
||
)
|
||
"""客户端类型 (可选: openai/google, 默认为openai)"""
|
||
|
||
auth_type: str = Field(
|
||
default=OpenAICompatibleAuthType.BEARER.value,
|
||
json_schema_extra={
|
||
"x-widget": "select",
|
||
"x-icon": "shield",
|
||
},
|
||
)
|
||
"""OpenAI 兼容接口的鉴权方式。可选值:`bearer`、`header`、`query`、`none`。"""
|
||
|
||
auth_header_name: str = Field(
|
||
default="Authorization",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "header",
|
||
},
|
||
)
|
||
"""当 `auth_type` 为 `header` 时使用的请求头名称。"""
|
||
|
||
auth_header_prefix: str = Field(
|
||
default="Bearer",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "shield-check",
|
||
},
|
||
)
|
||
"""当 `auth_type` 为 `header` 时使用的请求头前缀。留空表示直接发送原始密钥。"""
|
||
|
||
auth_query_name: str = Field(
|
||
default="api_key",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "link",
|
||
},
|
||
)
|
||
"""当 `auth_type` 为 `query` 时使用的查询参数名称。"""
|
||
|
||
default_headers: dict[str, str] = Field(
|
||
default_factory=dict,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "header",
|
||
},
|
||
)
|
||
"""所有请求默认附带的 HTTP Header。"""
|
||
|
||
default_query: dict[str, str] = Field(
|
||
default_factory=dict,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "list-filter",
|
||
},
|
||
)
|
||
"""所有请求默认附带的查询参数。"""
|
||
|
||
organization: str | None = Field(
|
||
default=None,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "building-2",
|
||
},
|
||
)
|
||
"""OpenAI 官方接口可选的 `organization`。"""
|
||
|
||
project: str | None = Field(
|
||
default=None,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "folder-kanban",
|
||
},
|
||
)
|
||
"""OpenAI 官方接口可选的 `project`。"""
|
||
|
||
model_list_endpoint: str = Field(
|
||
default="/models",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "list",
|
||
},
|
||
)
|
||
"""模型列表端点路径。适用于 OpenAI 兼容接口的探测与管理。"""
|
||
|
||
reasoning_parse_mode: str = Field(
|
||
default=ReasoningParseMode.AUTO.value,
|
||
json_schema_extra={
|
||
"x-widget": "select",
|
||
"x-icon": "brain",
|
||
},
|
||
)
|
||
"""推理内容解析模式。可选值:`auto`、`native`、`think_tag`、`none`。"""
|
||
|
||
tool_argument_parse_mode: str = Field(
|
||
default=ToolArgumentParseMode.AUTO.value,
|
||
json_schema_extra={
|
||
"x-widget": "select",
|
||
"x-icon": "braces",
|
||
},
|
||
)
|
||
"""工具参数解析模式。可选值:`auto`、`strict`、`repair`、`double_decode`。"""
|
||
|
||
max_retry: int = Field(
|
||
default=2,
|
||
ge=0,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "repeat",
|
||
},
|
||
)
|
||
"""最大重试次数 (单个模型API调用失败, 最多重试的次数)"""
|
||
|
||
timeout: int = Field(
|
||
default=10,
|
||
ge=1,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "clock",
|
||
"step": 1,
|
||
},
|
||
)
|
||
"""API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)"""
|
||
|
||
retry_interval: int = Field(
|
||
default=10,
|
||
ge=1,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "timer",
|
||
"step": 1,
|
||
},
|
||
)
|
||
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
|
||
|
||
def model_post_init(self, context: Any = None) -> None:
|
||
"""执行 API 提供商配置的后置校验。
|
||
|
||
Args:
|
||
context: Pydantic 传入的上下文对象。
|
||
|
||
Raises:
|
||
ValueError: 当配置项缺失或组合不合法时抛出。
|
||
"""
|
||
if self.auth_type != OpenAICompatibleAuthType.NONE and not self.api_key:
|
||
raise ValueError(t("config.api_key_empty"))
|
||
if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url
|
||
raise ValueError(t("config.api_base_url_empty"))
|
||
if not self.name:
|
||
raise ValueError(t("config.api_provider_name_empty"))
|
||
if self.auth_type == OpenAICompatibleAuthType.HEADER and not self.auth_header_name.strip():
|
||
raise ValueError("当 auth_type=header 时,auth_header_name 不能为空")
|
||
if self.auth_type == OpenAICompatibleAuthType.QUERY and not self.auth_query_name.strip():
|
||
raise ValueError("当 auth_type=query 时,auth_query_name 不能为空")
|
||
super().model_post_init(context)
|
||
|
||
|
||
class ModelInfo(ConfigBase):
|
||
"""单个模型信息配置类"""
|
||
|
||
_validate_any: bool = False
|
||
suppress_any_warning: bool = True
|
||
|
||
model_identifier: str = Field(
|
||
default="",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "package",
|
||
},
|
||
)
|
||
"""模型标识符 (API服务商提供的模型标识符)"""
|
||
|
||
name: str = Field(
|
||
default="",
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "tag",
|
||
},
|
||
)
|
||
"""模型名称 (可随意命名, 在models中需使用这个命名)"""
|
||
|
||
api_provider: str = Field(
|
||
default="",
|
||
json_schema_extra={
|
||
"x-widget": "select",
|
||
"x-icon": "link",
|
||
},
|
||
)
|
||
"""API服务商名称 (对应在api_providers中配置的服务商名称)"""
|
||
|
||
price_in: float = Field(
|
||
default=0.0,
|
||
ge=0,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "dollar-sign",
|
||
"step": 0.001,
|
||
},
|
||
)
|
||
"""输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
|
||
|
||
price_out: float = Field(
|
||
default=0.0,
|
||
ge=0,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "dollar-sign",
|
||
"step": 0.001,
|
||
},
|
||
)
|
||
"""输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
|
||
|
||
temperature: float | None = Field(
|
||
default=None,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "thermometer",
|
||
},
|
||
)
|
||
"""模型级别温度(可选),会覆盖任务配置中的温度"""
|
||
|
||
max_tokens: int | None = Field(
|
||
default=None,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "layers",
|
||
},
|
||
)
|
||
"""模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
|
||
|
||
force_stream_mode: bool = Field(
|
||
default=False,
|
||
json_schema_extra={
|
||
"x-widget": "switch",
|
||
"x-icon": "zap",
|
||
},
|
||
)
|
||
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
|
||
|
||
extra_params: dict[str, Any] = Field(
|
||
default_factory=dict,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "sliders",
|
||
},
|
||
)
|
||
"""额外参数 (用于API调用时的额外配置)"""
|
||
|
||
def model_post_init(self, context: Any = None):
|
||
if not self.model_identifier:
|
||
raise ValueError(t("config.model_identifier_empty_generic"))
|
||
if not self.name:
|
||
raise ValueError(t("config.model_name_empty"))
|
||
if not self.api_provider:
|
||
raise ValueError(t("config.model_api_provider_empty"))
|
||
return super().model_post_init(context)
|
||
|
||
|
||
class TaskConfig(ConfigBase):
|
||
"""任务配置类"""
|
||
|
||
model_list: list[str] = Field(
|
||
default_factory=list,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "list",
|
||
},
|
||
)
|
||
"""使用的模型列表, 每个元素对应上面的模型名称(name)"""
|
||
|
||
max_tokens: int = Field(
|
||
default=1024,
|
||
ge=1,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "layers",
|
||
"step": 1,
|
||
},
|
||
)
|
||
"""任务最大输出token数"""
|
||
|
||
temperature: float = Field(
|
||
default=0.3,
|
||
ge=0,
|
||
le=2,
|
||
json_schema_extra={
|
||
"x-widget": "slider",
|
||
"x-icon": "thermometer",
|
||
"step": 0.1,
|
||
},
|
||
)
|
||
"""模型温度"""
|
||
|
||
slow_threshold: float = Field(
|
||
default=15.0,
|
||
ge=0,
|
||
json_schema_extra={
|
||
"x-widget": "input",
|
||
"x-icon": "alert-circle",
|
||
"step": 0.1,
|
||
},
|
||
)
|
||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||
|
||
selection_strategy: str = Field(
|
||
default="balance",
|
||
json_schema_extra={
|
||
"x-widget": "select",
|
||
"x-icon": "shuffle",
|
||
},
|
||
)
|
||
"""模型选择策略:balance(负载均衡)或 random(随机选择)"""
|
||
|
||
|
||
class ModelTaskConfig(ConfigBase):
|
||
"""模型配置类"""
|
||
|
||
utils: TaskConfig = Field(
|
||
default_factory=TaskConfig,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "wrench",
|
||
},
|
||
)
|
||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||
|
||
replyer: TaskConfig = Field(
|
||
default_factory=TaskConfig,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "message-square",
|
||
},
|
||
)
|
||
"""首要回复模型配置, 还用于表达器和表达方式学习"""
|
||
|
||
planner: TaskConfig = Field(
|
||
default_factory=TaskConfig,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "map",
|
||
},
|
||
)
|
||
"""规划模型配置"""
|
||
|
||
vlm: TaskConfig = Field(
|
||
default_factory=TaskConfig,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "image",
|
||
},
|
||
)
|
||
"""视觉模型配置"""
|
||
|
||
voice: TaskConfig = Field(
|
||
default_factory=TaskConfig,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "volume-2",
|
||
},
|
||
)
|
||
"""语音识别模型配置"""
|
||
|
||
embedding: TaskConfig = Field(
|
||
default_factory=TaskConfig,
|
||
json_schema_extra={
|
||
"x-widget": "custom",
|
||
"x-icon": "database",
|
||
},
|
||
)
|
||
"""嵌入模型配置""" |