from enum import Enum from typing import Any from src.common.i18n import t from .config_base import ConfigBase, Field class OpenAICompatibleAuthType(str, Enum): """OpenAI 兼容接口的鉴权方式。""" BEARER = "bearer" HEADER = "header" QUERY = "query" NONE = "none" class ReasoningParseMode(str, Enum): """推理内容解析策略。""" AUTO = "auto" NATIVE = "native" THINK_TAG = "think_tag" NONE = "none" class ToolArgumentParseMode(str, Enum): """工具调用参数的解析策略。""" AUTO = "auto" STRICT = "strict" REPAIR = "repair" DOUBLE_DECODE = "double_decode" class APIProvider(ConfigBase): """API提供商配置类""" name: str = Field( default="", json_schema_extra={ "x-widget": "input", "x-icon": "tag", }, ) """API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)""" base_url: str = Field( default="", json_schema_extra={ "x-widget": "input", "x-icon": "link", }, ) """API服务商的BaseURL""" api_key: str = Field( default_factory=str, repr=False, json_schema_extra={ "x-widget": "input", "x-icon": "key", }, ) """API密钥。对于不需要鉴权的兼容端点,可将 `auth_type` 设为 `none`。""" client_type: str = Field( default="openai", json_schema_extra={ "x-widget": "select", "x-icon": "settings", }, ) """客户端类型 (可选: openai/google, 默认为openai)""" auth_type: str = Field( default=OpenAICompatibleAuthType.BEARER.value, json_schema_extra={ "x-widget": "select", "x-icon": "shield", }, ) """OpenAI 兼容接口的鉴权方式。可选值:`bearer`、`header`、`query`、`none`。""" auth_header_name: str = Field( default="Authorization", json_schema_extra={ "x-widget": "input", "x-icon": "header", }, ) """当 `auth_type` 为 `header` 时使用的请求头名称。""" auth_header_prefix: str = Field( default="Bearer", json_schema_extra={ "x-widget": "input", "x-icon": "shield-check", }, ) """当 `auth_type` 为 `header` 时使用的请求头前缀。留空表示直接发送原始密钥。""" auth_query_name: str = Field( default="api_key", json_schema_extra={ "x-widget": "input", "x-icon": "link", }, ) """当 `auth_type` 为 `query` 时使用的查询参数名称。""" default_headers: dict[str, str] = Field( default_factory=dict, json_schema_extra={ "x-widget": "custom", "x-icon": "header", }, ) """所有请求默认附带的 HTTP Header。""" default_query: dict[str, str] = Field( default_factory=dict, json_schema_extra={ "x-widget": "custom", "x-icon": "list-filter", }, ) """所有请求默认附带的查询参数。""" organization: str | None = Field( default=None, json_schema_extra={ "x-widget": "input", "x-icon": "building-2", }, ) """OpenAI 官方接口可选的 `organization`。""" project: str | None = Field( default=None, json_schema_extra={ "x-widget": "input", "x-icon": "folder-kanban", }, ) """OpenAI 官方接口可选的 `project`。""" model_list_endpoint: str = Field( default="/models", json_schema_extra={ "x-widget": "input", "x-icon": "list", }, ) """模型列表端点路径。适用于 OpenAI 兼容接口的探测与管理。""" reasoning_parse_mode: str = Field( default=ReasoningParseMode.AUTO.value, json_schema_extra={ "x-widget": "select", "x-icon": "brain", }, ) """推理内容解析模式。可选值:`auto`、`native`、`think_tag`、`none`。""" tool_argument_parse_mode: str = Field( default=ToolArgumentParseMode.AUTO.value, json_schema_extra={ "x-widget": "select", "x-icon": "braces", }, ) """工具参数解析模式。可选值:`auto`、`strict`、`repair`、`double_decode`。""" max_retry: int = Field( default=2, ge=0, json_schema_extra={ "x-widget": "input", "x-icon": "repeat", }, ) """最大重试次数 (单个模型API调用失败, 最多重试的次数)""" timeout: int = Field( default=10, ge=1, json_schema_extra={ "x-widget": "input", "x-icon": "clock", "step": 1, }, ) """API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)""" retry_interval: int = Field( default=10, ge=1, json_schema_extra={ "x-widget": "input", "x-icon": "timer", "step": 1, }, ) """重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)""" def model_post_init(self, context: Any = None) -> None: """执行 API 提供商配置的后置校验。 Args: context: Pydantic 传入的上下文对象。 Raises: ValueError: 当配置项缺失或组合不合法时抛出。 """ if self.auth_type != OpenAICompatibleAuthType.NONE and not self.api_key: raise ValueError(t("config.api_key_empty")) if not self.base_url and self.client_type != "gemini": # TODO: 允许gemini使用base_url raise ValueError(t("config.api_base_url_empty")) if not self.name: raise ValueError(t("config.api_provider_name_empty")) if self.auth_type == OpenAICompatibleAuthType.HEADER and not self.auth_header_name.strip(): raise ValueError("当 auth_type=header 时,auth_header_name 不能为空") if self.auth_type == OpenAICompatibleAuthType.QUERY and not self.auth_query_name.strip(): raise ValueError("当 auth_type=query 时,auth_query_name 不能为空") super().model_post_init(context) class ModelInfo(ConfigBase): """单个模型信息配置类""" _validate_any: bool = False suppress_any_warning: bool = True model_identifier: str = Field( default="", json_schema_extra={ "x-widget": "input", "x-icon": "package", }, ) """模型标识符 (API服务商提供的模型标识符)""" name: str = Field( default="", json_schema_extra={ "x-widget": "input", "x-icon": "tag", }, ) """模型名称 (可随意命名, 在models中需使用这个命名)""" api_provider: str = Field( default="", json_schema_extra={ "x-widget": "select", "x-icon": "link", }, ) """API服务商名称 (对应在api_providers中配置的服务商名称)""" price_in: float = Field( default=0.0, ge=0, json_schema_extra={ "x-widget": "input", "x-icon": "dollar-sign", "step": 0.001, }, ) """输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" price_out: float = Field( default=0.0, ge=0, json_schema_extra={ "x-widget": "input", "x-icon": "dollar-sign", "step": 0.001, }, ) """输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" temperature: float | None = Field( default=None, json_schema_extra={ "x-widget": "input", "x-icon": "thermometer", }, ) """模型级别温度(可选),会覆盖任务配置中的温度""" max_tokens: int | None = Field( default=None, json_schema_extra={ "x-widget": "input", "x-icon": "layers", }, ) """模型级别最大token数(可选),会覆盖任务配置中的max_tokens""" force_stream_mode: bool = Field( default=False, json_schema_extra={ "x-widget": "switch", "x-icon": "zap", }, ) """强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)""" extra_params: dict[str, Any] = Field( default_factory=dict, json_schema_extra={ "x-widget": "custom", "x-icon": "sliders", }, ) """额外参数 (用于API调用时的额外配置)""" def model_post_init(self, context: Any = None): if not self.model_identifier: raise ValueError(t("config.model_identifier_empty_generic")) if not self.name: raise ValueError(t("config.model_name_empty")) if not self.api_provider: raise ValueError(t("config.model_api_provider_empty")) return super().model_post_init(context) class TaskConfig(ConfigBase): """任务配置类""" model_list: list[str] = Field( default_factory=list, json_schema_extra={ "x-widget": "custom", "x-icon": "list", }, ) """使用的模型列表, 每个元素对应上面的模型名称(name)""" max_tokens: int = Field( default=1024, ge=1, json_schema_extra={ "x-widget": "input", "x-icon": "layers", "step": 1, }, ) """任务最大输出token数""" temperature: float = Field( default=0.3, ge=0, le=2, json_schema_extra={ "x-widget": "slider", "x-icon": "thermometer", "step": 0.1, }, ) """模型温度""" slow_threshold: float = Field( default=15.0, ge=0, json_schema_extra={ "x-widget": "input", "x-icon": "alert-circle", "step": 0.1, }, ) """慢请求阈值(秒),超过此值会输出警告日志""" selection_strategy: str = Field( default="balance", json_schema_extra={ "x-widget": "select", "x-icon": "shuffle", }, ) """模型选择策略:balance(负载均衡)或 random(随机选择)""" class ModelTaskConfig(ConfigBase): """模型配置类""" utils: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ "x-widget": "custom", "x-icon": "wrench", }, ) """组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型""" replyer: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ "x-widget": "custom", "x-icon": "message-square", }, ) """首要回复模型配置, 还用于表达器和表达方式学习""" planner: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ "x-widget": "custom", "x-icon": "map", }, ) """规划模型配置""" vlm: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ "x-widget": "custom", "x-icon": "image", }, ) """视觉模型配置""" voice: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ "x-widget": "custom", "x-icon": "volume-2", }, ) """语音识别模型配置""" embedding: TaskConfig = Field( default_factory=TaskConfig, json_schema_extra={ "x-widget": "custom", "x-icon": "database", }, ) """嵌入模型配置"""